"""Tests for seed constants, inline constants, and non-commutative validation.

Tests verify:
- Seed syntax (const N |> &dest) produces seed tokens, not IRAM slots
- Inline const syntax (&foo <| add 7) is equivalent to (&foo <| add, 7)
- Non-commutative ops with IRAM const require explicit port on incoming edges
- Seed tokens match target instruction arity (DyadToken for dyadic,
  MonadToken for monadic)
- Triggerable constants (&c <| const, 7) still work as before
"""

import simpy

from asm import assemble, run_pipeline
from asm.codegen import generate_direct
from asm.errors import ErrorSeverity
from asm.ir import (
    IRGraph,
    IRNode,
    IREdge,
    SystemConfig,
    SourceLoc,
    ResolvedDest,
    collect_all_nodes,
)
from asm.opcodes import is_dyadic
from cm_inst import OutputStyle, ArithOp, LogicOp, RoutingOp, Port
from emu import build_topology
from tokens import DyadToken, MonadToken


def _noncommutative_warnings(graph):
    """Return the non-commutative-port warnings recorded on ``graph.errors``."""
    return [
        e for e in graph.errors
        if e.severity == ErrorSeverity.WARNING and 'Non-commutative' in e.message
    ]


class TestSeedSyntax:
    """Tests for the `const N |> &dest` seed syntax."""

    def test_seed_no_iram_slot(self):
        """Seed constant nodes don't consume IRAM slots."""
        source = """\
@system pe=1, sm=1
&target|pe0 <| pass
const 42 |> &target
"""
        graph = run_pipeline(source)
        errors = list(graph.errors)
        assert not errors, f"Pipeline errors: {errors}"

        # Find the seed node and the target node
        all_nodes = collect_all_nodes(graph)
        seed_nodes = [n for n in all_nodes.values() if n.seed]
        assert len(seed_nodes) == 1
        assert seed_nodes[0].iram_offset is None

        target_nodes = [
            n for n in all_nodes.values() if not n.seed and n.pe is not None
        ]
        assert len(target_nodes) == 1
        assert target_nodes[0].iram_offset is not None

    def test_seed_generates_monad_for_monadic_target(self):
        """Seed into a monadic target produces MonadToken."""
        source = """\
@system pe=1, sm=1
&target|pe0 <| pass
const 42 |> &target
"""
        result = assemble(source)
        seeds = result.seed_tokens
        assert len(seeds) == 1
        assert isinstance(seeds[0], MonadToken)
        assert seeds[0].data == 42
        assert seeds[0].target == 0

    def test_seed_generates_dyad_for_dyadic_target(self):
        """Seed into a dyadic target produces DyadToken."""
        source = """\
@system pe=1, sm=1
&target|pe0 <| add
const 5 |> &target:L
const 3 |> &target:R
"""
        result = assemble(source)
        seeds = result.seed_tokens
        assert len(seeds) == 2
        dyad_seeds = [s for s in seeds if isinstance(s, DyadToken)]
        assert len(dyad_seeds) == 2
        # Check each value lands on the port it was written to; comparing the
        # port set and the data set separately would miss an L/R swap.
        port_to_data = {s.port: s.data for s in dyad_seeds}
        assert port_to_data == {Port.L: 5, Port.R: 3}

    def test_seed_e2e_add_two_seeds(self):
        """End-to-end: two seed constants into ADD, routed to second PE."""
        source = """\
@system pe=2, sm=1
&sum|pe0 <| add
&sink|pe1 <| pass
&sum |> &sink
const 10 |> &sum:L
const 32 |> &sum:R
"""
        result = assemble(source)
        env = simpy.Environment()
        system = build_topology(env, result.pe_configs, result.sm_configs)
        for seed in result.seed_tokens:
            system.inject(seed)
        env.run(until=100)
        # ADD(10, 32) = 42, routed to PE1; PE0 output_log has the emitted token
        assert len(system.pes[0].output_log) >= 1
        out = system.pes[0].output_log[0]
        assert out.data == 42


class TestInlineConst:
    """Tests for the `&foo <| add 7` inline constant shorthand."""

    def test_inline_const_parsing(self):
        """Inline const shorthand parses correctly."""
        source = """\
@system pe=1, sm=1
&foo|pe0 <| add 7
const 3 |> &foo:L
"""
        result = assemble(source)
        assert len(result.seed_tokens) >= 1

    def test_inline_const_equivalent_to_comma(self):
        """Inline const `add 7` is equivalent to `add, 7`."""
        source_inline = """\
@system pe=1, sm=1
&foo|pe0 <| add 7
const 3 |> &foo:L
"""
        source_comma = """\
@system pe=1, sm=1
&foo|pe0 <| add, 7
const 3 |> &foo:L
"""
        result_inline = assemble(source_inline)
        result_comma = assemble(source_comma)

        # Both should produce same IRAM contents
        assert len(result_inline.pe_configs) == len(result_comma.pe_configs)
        pe_inline = result_inline.pe_configs[0]
        pe_comma = result_comma.pe_configs[0]

        # Same IRAM layout first, so a mismatch reports clearly instead of
        # raising KeyError while indexing the comma-form config.
        assert set(pe_inline.iram) == set(pe_comma.iram)
        for offset, inst_inline in pe_inline.iram.items():
            inst_comma = pe_comma.iram[offset]
            assert inst_inline.opcode == inst_comma.opcode
            assert inst_inline.has_const == inst_comma.has_const

    def test_inline_const_hex(self):
        """Inline const supports hex values."""
        source = """\
@system pe=1, sm=1
&foo|pe0 <| add 0xFF
&out|pe0 <| pass
&foo |> &out
const 1 |> &foo:L
"""
        result = assemble(source)
        pe = result.pe_configs[0]
        # The ADD instruction is the one carrying the inline const.
        const_insts = [inst for inst in pe.iram.values() if inst.has_const]
        assert const_insts, "expected an IRAM instruction carrying the inline const"
        assert const_insts[0].opcode == ArithOp.ADD


class TestNonCommutativeValidation:
    """Tests for non-commutative op + IRAM const port validation."""

    def test_noncommutative_no_port_warning(self):
        """Non-commutative op with const and no explicit port emits warning."""
        source = """\
@system pe=1, sm=1
&diff|pe0 <| sub 3
&src|pe0 <| pass
&src |> &diff
const 1 |> &src
"""
        # Should still assemble (warnings don't block)
        result = assemble(source)
        assert len(result.pe_configs) == 1

        # But run_pipeline should show the warning in graph.errors
        warnings = _noncommutative_warnings(run_pipeline(source))
        assert len(warnings) >= 1
        assert 'SUB' in warnings[0].message

    def test_noncommutative_explicit_port_ok(self):
        """Non-commutative op with const and explicit port is fine."""
        source = """\
@system pe=1, sm=1
&diff|pe0 <| sub 3
&src|pe0 <| pass
&src |> &diff:L
const 1 |> &src
"""
        result = assemble(source)
        # Should assemble without error
        assert len(result.pe_configs) == 1
        # Warnings never block assembly, so also check none was raised.
        assert not _noncommutative_warnings(run_pipeline(source))

    def test_commutative_no_port_ok(self):
        """Commutative op with const and no explicit port is fine."""
        source = """\
@system pe=1, sm=1
&total|pe0 <| add 10
&src|pe0 <| pass
&src |> &total
const 1 |> &src
"""
        result = assemble(source)
        assert len(result.pe_configs) == 1
        # Warnings never block assembly, so also check none was raised.
        assert not _noncommutative_warnings(run_pipeline(source))

    def test_comparison_ops_noncommutative(self):
        """LT, GT, LTE, GTE are non-commutative and emit warnings."""
        for op_name in ['lt', 'gt', 'lte', 'gte']:
            source = f"""\
@system pe=1, sm=1
&cmp|pe0 <| {op_name} 5
&src|pe0 <| pass
&src |> &cmp
const 1 |> &src
"""
            warnings = _noncommutative_warnings(run_pipeline(source))
            assert len(warnings) >= 1, f"{op_name} should emit non-commutative warning"


class TestTriggerableConstant:
    """Verify that triggerable constants (&c <| const, 7) still work."""

    def test_triggerable_const_in_iram(self):
        """Triggerable constant occupies IRAM slot and fires on input."""
        source = """\
@system pe=1, sm=1
&c|pe0 <| const, 7
"""
        result = assemble(source)
        pe = result.pe_configs[0]
        assert len(pe.iram) == 1
        # Should also produce a seed token (no incoming edges)
        assert len(result.seed_tokens) == 1
        assert isinstance(result.seed_tokens[0], MonadToken)
        assert result.seed_tokens[0].data == 7

    def test_triggerable_const_with_incoming_edge_not_seed(self):
        """Triggerable const with incoming edge is NOT a seed."""
        source = """\
@system pe=1, sm=1
&src|pe0 <| pass
&c|pe0 <| const, 42
&src |> &c
const 1 |> &src
"""
        result = assemble(source)
        # &c is fed by &src, so it must not be seeded; the seed that should
        # exist comes from the `const 1 |> &src` seed constant.
        const_seeds = [s for s in result.seed_tokens if s.data == 42]
        assert len(const_seeds) == 0
        assert any(s.data == 1 for s in result.seed_tokens)