OR-1 dataflow CPU sketch
"""Tests for seed constants, inline constants, and non-commutative validation.

Tests verify:
- Seed syntax (const N |> &dest) produces seed tokens, not IRAM slots
- Inline const syntax (&foo <| add 7) is equivalent to (&foo <| add, 7)
- Non-commutative ops with an IRAM const need an explicit port on incoming
  edges; omitting it produces a warning (assembly still succeeds)
- Seed tokens match target instruction arity (DyadToken for dyadic, MonadToken for monadic)
- Triggerable constants (&c <| const, 7) still work as before
"""
10
11import simpy
12
13from asm import assemble, run_pipeline
14from asm.codegen import generate_direct
15from asm.errors import ErrorSeverity
16from asm.ir import (
17 IRGraph, IRNode, IREdge, SystemConfig, SourceLoc, ResolvedDest,
18)
19from asm.opcodes import is_dyadic
20from cm_inst import OutputStyle, ArithOp, LogicOp, RoutingOp, Port
21from emu import build_topology
22from tokens import DyadToken, MonadToken
23
24
class TestSeedSyntax:
    """Tests for the `const N |> &dest` seed syntax."""

    def test_seed_no_iram_slot(self):
        """Seed constant nodes don't consume IRAM slots; their targets do."""
        from asm.ir import collect_all_nodes

        source = """\
@system pe=1, sm=1
&target|pe0 <| pass
const 42 |> &target
"""
        graph = run_pipeline(source)
        errors = list(graph.errors)
        assert not errors, f"Pipeline errors: {errors}"

        all_nodes = collect_all_nodes(graph)

        # The seed node only produces an initial token, so it must not be
        # given an IRAM offset.
        seed_nodes = [n for n in all_nodes.values() if n.seed]
        assert len(seed_nodes) == 1
        assert seed_nodes[0].iram_offset is None

        # The placed `pass` instruction is the only non-seed node with a PE,
        # and it does occupy an IRAM slot.
        target_nodes = [
            n for n in all_nodes.values() if not n.seed and n.pe is not None
        ]
        assert len(target_nodes) == 1
        assert target_nodes[0].iram_offset is not None

    def test_seed_generates_monad_for_monadic_target(self):
        """Seed into a monadic target produces MonadToken."""
        source = """\
@system pe=1, sm=1
&target|pe0 <| pass
const 42 |> &target
"""
        result = assemble(source)
        seeds = result.seed_tokens
        assert len(seeds) == 1
        assert isinstance(seeds[0], MonadToken)
        assert seeds[0].data == 42
        # NOTE(review): presumably `target` is the destination PE id (pe0);
        # confirm against the tokens module.
        assert seeds[0].target == 0

    def test_seed_generates_dyad_for_dyadic_target(self):
        """Seed into a dyadic target produces one DyadToken per port."""
        source = """\
@system pe=1, sm=1
&target|pe0 <| add
const 5 |> &target:L
const 3 |> &target:R
"""
        result = assemble(source)
        seeds = result.seed_tokens
        assert len(seeds) == 2

        dyad_seeds = [s for s in seeds if isinstance(s, DyadToken)]
        assert len(dyad_seeds) == 2

        # Each seed lands on the port named in its `:L` / `:R` suffix.
        assert {s.port for s in dyad_seeds} == {Port.L, Port.R}
        assert {s.data for s in dyad_seeds} == {5, 3}

    def test_seed_e2e_add_two_seeds(self):
        """End-to-end: two seed constants into ADD, result routed to a second PE."""
        source = """\
@system pe=2, sm=1
&sum|pe0 <| add
&sink|pe1 <| pass
&sum |> &sink
const 10 |> &sum:L
const 32 |> &sum:R
"""
        result = assemble(source)
        env = simpy.Environment()
        # Named `system` (not `sys`) so the stdlib module name isn't shadowed.
        system = build_topology(env, result.pe_configs, result.sm_configs)
        for seed in result.seed_tokens:
            system.inject(seed)
        env.run(until=100)

        # ADD(10, 32) = 42 is emitted by PE0 toward &sink on PE1; PE0's
        # output_log records the emitted token.
        output_log = system.pes[0].output_log
        assert output_log, "PE0 emitted no tokens"
        assert output_log[0].data == 42
109
110
class TestInlineConst:
    """Tests for the `&foo <| add 7` inline constant shorthand."""

    def test_inline_const_parsing(self):
        """Inline const shorthand parses and assembles, producing the seed."""
        source = """\
@system pe=1, sm=1
&foo|pe0 <| add 7
const 3 |> &foo:L
"""
        result = assemble(source)
        assert len(result.seed_tokens) >= 1

    def test_inline_const_equivalent_to_comma(self):
        """Inline const `add 7` assembles to the same IRAM as `add, 7`."""
        source_inline = """\
@system pe=1, sm=1
&foo|pe0 <| add 7
const 3 |> &foo:L
"""
        source_comma = """\
@system pe=1, sm=1
&foo|pe0 <| add, 7
const 3 |> &foo:L
"""
        result_inline = assemble(source_inline)
        result_comma = assemble(source_comma)

        assert len(result_inline.pe_configs) == len(result_comma.pe_configs)
        pe_inline = result_inline.pe_configs[0]
        pe_comma = result_comma.pe_configs[0]

        # Compare the occupied offsets first. Otherwise an offset present only
        # in the inline form raises KeyError instead of failing cleanly, and
        # an offset present only in the comma form goes unnoticed.
        assert set(pe_inline.iram) == set(pe_comma.iram)

        for offset, inst_inline in pe_inline.iram.items():
            inst_comma = pe_comma.iram[offset]
            assert inst_inline.opcode == inst_comma.opcode, (
                f"opcode mismatch at IRAM offset {offset}"
            )
            assert inst_inline.has_const == inst_comma.has_const, (
                f"has_const mismatch at IRAM offset {offset}"
            )

    def test_inline_const_hex(self):
        """Inline const accepts a hex literal and yields an ADD with a const."""
        source = """\
@system pe=1, sm=1
&foo|pe0 <| add 0xFF
&out|pe0 <| pass
&foo |> &out
const 1 |> &foo:L
"""
        result = assemble(source)
        pe = result.pe_configs[0]

        # Only &foo carries an IRAM constant; &out is a plain `pass`.
        const_insts = [inst for inst in pe.iram.values() if inst.has_const]
        assert const_insts, "no IRAM instruction carries a constant"
        assert const_insts[0].opcode == ArithOp.ADD
        # NOTE(review): the parsed constant value (0xFF == 255) is not checked
        # here; add an assertion once the instruction's const field is known.
168
169
class TestNonCommutativeValidation:
    """Tests for non-commutative op + IRAM const port validation."""

    @staticmethod
    def _noncommutative_warnings(graph):
        """Return the non-commutative-port warnings recorded on *graph*."""
        return [
            e for e in graph.errors
            if e.severity == ErrorSeverity.WARNING and 'Non-commutative' in e.message
        ]

    def test_noncommutative_no_port_warning(self):
        """Non-commutative op with const and no explicit port emits warning."""
        source = """\
@system pe=1, sm=1
&diff|pe0 <| sub 3
&src|pe0 <| pass
&src |> &diff
const 1 |> &src
"""
        # Should still assemble (warnings don't block)
        result = assemble(source)
        assert len(result.pe_configs) == 1

        # But run_pipeline should show the warning in graph.errors
        warnings = self._noncommutative_warnings(run_pipeline(source))
        assert len(warnings) >= 1
        assert 'SUB' in warnings[0].message

    def test_noncommutative_explicit_port_ok(self):
        """Non-commutative op with const and explicit port emits no warning."""
        source = """\
@system pe=1, sm=1
&diff|pe0 <| sub 3
&src|pe0 <| pass
&src |> &diff:L
const 1 |> &src
"""
        result = assemble(source)
        assert len(result.pe_configs) == 1

        # assemble() succeeds even when the warning fires (see
        # test_noncommutative_no_port_warning), so inspect the pipeline
        # diagnostics to confirm the explicit port suppresses it.
        warnings = self._noncommutative_warnings(run_pipeline(source))
        assert not warnings, f"unexpected warnings: {warnings}"

    def test_commutative_no_port_ok(self):
        """Commutative op with const and no explicit port emits no warning."""
        source = """\
@system pe=1, sm=1
&total|pe0 <| add 10
&src|pe0 <| pass
&src |> &total
const 1 |> &src
"""
        result = assemble(source)
        assert len(result.pe_configs) == 1

        # As above: a successful assemble() alone cannot rule out a warning.
        warnings = self._noncommutative_warnings(run_pipeline(source))
        assert not warnings, f"unexpected warnings: {warnings}"

    def test_comparison_ops_noncommutative(self):
        """LT, GT, LTE, GTE are non-commutative and emit warnings."""
        for op_name in ['lt', 'gt', 'lte', 'gte']:
            source = f"""\
@system pe=1, sm=1
&cmp|pe0 <| {op_name} 5
&src|pe0 <| pass
&src |> &cmp
const 1 |> &src
"""
            warnings = self._noncommutative_warnings(run_pipeline(source))
            assert len(warnings) >= 1, f"{op_name} should emit non-commutative warning"
236
237
class TestTriggerableConstant:
    """Verify that triggerable constants (&c <| const, 7) still work."""

    def test_triggerable_const_in_iram(self):
        """Triggerable constant occupies an IRAM slot and is seeded when it has no inputs."""
        source = """\
@system pe=1, sm=1
&c|pe0 <| const, 7
"""
        result = assemble(source)
        pe = result.pe_configs[0]
        assert len(pe.iram) == 1

        # With no incoming edges it also gets a seed token so it fires once.
        seeds = result.seed_tokens
        assert len(seeds) == 1
        assert isinstance(seeds[0], MonadToken)
        assert seeds[0].data == 7

    def test_triggerable_const_with_incoming_edge_not_seed(self):
        """Triggerable const with an incoming edge is NOT seeded."""
        source = """\
@system pe=1, sm=1
&src|pe0 <| pass
&c|pe0 <| const, 42
&src |> &c
const 1 |> &src
"""
        result = assemble(source)
        seed_data = [s.data for s in result.seed_tokens]

        # The seed node `const 1 |> &src` must be present; otherwise the
        # absence check below would pass vacuously on an empty seed list.
        assert 1 in seed_data, f"expected the `const 1` seed, got {seed_data}"
        # &c is fed by &src, so it must not receive its own seed carrying 42.
        assert 42 not in seed_data
269 assert len(const_seeds) == 0