OR-1 dataflow CPU sketch
1"""Tests for variadic repetition expansion (Phase 6).
2
3Tests verify:
4- Variadic macros expand correctly (repetition block once per argument)
5- ${_idx} produces correct iteration indices (0-based)
6- Mixed params: non-variadic first, variadic captures remaining args
7- Empty variadic invocation: no error, nothing expanded
8- Single variadic invocation: one iteration
9- Variadic parameter not last: error at lower pass
10- Full pipeline: variadic macro assembles and runs in emulator
11"""
12
13from pathlib import Path
14
15import simpy
16from lark import Lark
17
18from asm import assemble
19from asm.errors import ErrorCategory
20from asm.expand import expand
21from asm.ir import (
22 IREdge,
23 IRGraph,
24 IRMacroCall,
25 IRNode,
26 IRRepetitionBlock,
27 MacroDef,
28 MacroParam,
29 ParamRef,
30 SourceLoc,
31)
32from asm.lower import lower
33from cm_inst import ArithOp
34from emu import build_topology
35
36
37def _get_parser():
38 """Get the dfasm parser."""
39 grammar_path = Path(__file__).parent.parent / "dfasm.lark"
40 return Lark(
41 grammar_path.read_text(),
42 parser="earley",
43 propagate_positions=True,
44 )
45
46
47def parse_and_lower(source: str) -> IRGraph:
48 """Parse source and lower to IRGraph (before expansion)."""
49 parser = _get_parser()
50 tree = parser.parse(source)
51 return lower(tree)
52
53
54def parse_lower_expand(source: str) -> IRGraph:
55 """Parse, lower, and expand."""
56 graph = parse_and_lower(source)
57 return expand(graph)
58
59
60class TestVariadicSimpleExpansion:
61 """Test basic variadic repetition expansion."""
62
63 def test_simple_variadic_expands_three_iterations(self):
64 """Simple variadic: #inject *gates creates 3 pass nodes for 3 args."""
65 source = """
66 @system pe=1, sm=1
67
68 #inject *gates |> {
69 $( &g <| pass ),*
70 }
71
72 #inject &a, &b, &c
73 """
74 graph = parse_lower_expand(source)
75
76 # Should have 3 nodes: #inject_0_rep0.&g, #inject_0_rep1.&g, #inject_0_rep2.&g
77 nodes = list(graph.nodes.keys())
78 assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}: {nodes}"
79
80 # Each should have rep0, rep1, rep2 suffix to distinguish iterations
81 rep0 = [n for n in nodes if "rep0" in n]
82 rep1 = [n for n in nodes if "rep1" in n]
83 rep2 = [n for n in nodes if "rep2" in n]
84 assert len(rep0) == 1, f"Expected rep0 node in {nodes}"
85 assert len(rep1) == 1, f"Expected rep1 node in {nodes}"
86 assert len(rep2) == 1, f"Expected rep2 node in {nodes}"
87
88 def test_variadic_with_multiple_statements_per_iteration(self):
89 """Repetition block with multiple statements per iteration."""
90 source = """
91 @system pe=1, sm=1
92
93 #loop *items |> {
94 $( &item <| pass
95 &item |> &output:L ),*
96 }
97
98 #loop &x, &y
99 """
100 graph = parse_lower_expand(source)
101
102 # 2 invocations * 2 statements = 4 nodes
103 nodes = list(graph.nodes.keys())
104 assert len(nodes) >= 2, f"Expected at least 2 nodes, got {len(nodes)}: {nodes}"
105
106 # Should have rep0 and rep1 in names
107 assert any("rep0" in n for n in nodes), f"Expected rep0 in {nodes}"
108 assert any("rep1" in n for n in nodes), f"Expected rep1 in {nodes}"
109
110
111class TestVariadicIndexVariable:
112 """Test ${_idx} substitution in repetition blocks."""
113
114 def test_idx_variable_expands_to_iteration_index(self):
115 """${_idx} becomes 0, 1, 2 in successive iterations via token pasting.
116
117 Tests that ParamRef with _idx parameter substitutes correctly during
118 variadic expansion. The _idx value is set to the iteration index (0-based)
119 and is available for token pasting concatenation.
120 """
121 # Construct macro body with a ParamRef containing _idx
122 # Node name will be: &node_${_idx} -> ParamRef(param="_idx", prefix="&node_", suffix="")
123 param_ref = ParamRef(param="_idx", prefix="&node_", suffix="")
124 body_node = IRNode(
125 name=param_ref,
126 opcode=ArithOp.ADD,
127 loc=SourceLoc(0, 0),
128 )
129
130 # Create repetition block with the node
131 rep_body = IRGraph(
132 nodes={"node_placeholder": body_node},
133 edges=[],
134 macro_defs=[],
135 macro_calls=[],
136 )
137
138 rep_block = IRRepetitionBlock(
139 body=rep_body,
140 variadic_param="vals",
141 loc=SourceLoc(0, 0),
142 )
143
144 # Create macro definition with variadic parameter
145 macro_def = MacroDef(
146 name="maker",
147 params=(MacroParam(name="vals", variadic=True),),
148 body=IRGraph(
149 nodes={},
150 edges=[],
151 macro_defs=[],
152 macro_calls=[],
153 ),
154 repetition_blocks=[rep_block],
155 loc=SourceLoc(0, 0),
156 )
157
158 # Create macro call with 3 arguments
159 macro_call = IRMacroCall(
160 name="maker",
161 positional_args=(42, 100, 200),
162 named_args=(),
163 loc=SourceLoc(0, 0),
164 )
165
166 # Create graph with macro definition and call
167 graph = IRGraph(
168 nodes={},
169 edges=[],
170 regions=[],
171 data_defs=[],
172 macro_defs=[macro_def],
173 macro_calls=[macro_call],
174 )
175
176 # Expand the graph
177 expanded = expand(graph)
178
179 # After expansion, should have 3 nodes with names:
180 # #maker_0_rep0.&node_0, #maker_0_rep1.&node_1, #maker_0_rep2.&node_2
181 node_names = list(expanded.nodes.keys())
182 assert len(node_names) == 3, (
183 f"Expected 3 nodes, got {len(node_names)}: {node_names}"
184 )
185
186 # Verify that _idx was substituted correctly in node names
187 # Each iteration should have node_0, node_1, node_2 respectively
188 assert any("&node_0" in name for name in node_names), (
189 f"Expected node with &node_0 (iteration 0), got {node_names}"
190 )
191 assert any("&node_1" in name for name in node_names), (
192 f"Expected node with &node_1 (iteration 1), got {node_names}"
193 )
194 assert any("&node_2" in name for name in node_names), (
195 f"Expected node with &node_2 (iteration 2), got {node_names}"
196 )
197
198
199class TestVariadicMixedParams:
200 """Test variadic with non-variadic parameters."""
201
202 def test_mixed_params_non_variadic_first(self):
203 """Macro with dest, *sources: first param is non-variadic."""
204 source = """
205 @system pe=1, sm=1
206
207 #route dest, *sources |> {
208 $( &src <| pass ),*
209 }
210
211 #route &output, &in1, &in2
212 """
213 graph = parse_lower_expand(source)
214
215 # Should have 2 nodes (one per source)
216 nodes = list(graph.nodes.keys())
217 assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}: {nodes}"
218
219 def test_mixed_params_three_args_two_non_variadic(self):
220 """Macro with a, b, *rest: args 3+ go to rest."""
221 source = """
222 @system pe=1, sm=1
223
224 #process a, b, *rest |> {
225 $( &r <| pass ),*
226 }
227
228 #process &x, &y, &z1, &z2, &z3
229 """
230 graph = parse_lower_expand(source)
231
232 # 3 args go to *rest -> 3 iterations
233 nodes = list(graph.nodes.keys())
234 assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}: {nodes}"
235
236
237class TestVariadicEdgeCases:
238 """Test edge cases: empty variadic, single arg, etc."""
239
240 def test_empty_variadic_no_error(self):
241 """Invoke variadic macro with zero args: no error, nothing expanded."""
242 source = """
243 @system pe=1, sm=1
244
245 #optional *args |> {
246 $( &x <| pass ),*
247 }
248
249 #optional
250 """
251 graph = parse_lower_expand(source)
252
253 # No nodes should be created
254 nodes = list(graph.nodes.keys())
255 assert len(nodes) == 0, (
256 f"Expected 0 nodes for empty variadic, got {len(nodes)}: {nodes}"
257 )
258
259 def test_single_variadic_one_iteration(self):
260 """Invoke with one variadic arg: one iteration."""
261 source = """
262 @system pe=1, sm=1
263
264 #single *args |> {
265 $( &item <| pass ),*
266 }
267
268 #single &only
269 """
270 graph = parse_lower_expand(source)
271
272 # One iteration -> one node with rep0
273 nodes = list(graph.nodes.keys())
274 assert len(nodes) == 1
275 assert any("rep0" in n for n in nodes), f"Expected rep0 in {nodes}"
276
277 # Should NOT have rep1
278 assert not any("rep1" in n for n in nodes), f"Should not have rep1 in {nodes}"
279
280
281class TestVariadicGrammarValidation:
282 """Test grammar validation: variadic must be last, etc."""
283
284 def test_variadic_not_last_is_error(self):
285 """Variadic parameter not last: parser/lower should reject."""
286 source = """
287 @system pe=1, sm=1
288
289 #bad *args, b |> {
290 $( &x <| pass ),*
291 }
292 """
293 graph = parse_and_lower(source)
294
295 # Lower pass should catch this error
296 assert any(e.category == ErrorCategory.NAME for e in graph.errors), (
297 f"Expected NAME error for variadic not last, got: {graph.errors}"
298 )
299
300 def test_multiple_variadic_is_error(self):
301 """Multiple variadic parameters: parser/lower should reject."""
302 source = """
303 @system pe=1, sm=1
304
305 #bad *a, *b |> {
306 $( &x <| pass ),*
307 }
308 """
309 graph = parse_and_lower(source)
310
311 # Lower pass should catch this error
312 assert any(e.category == ErrorCategory.NAME for e in graph.errors), (
313 f"Expected NAME error for multiple variadic, got: {graph.errors}"
314 )
315
316
317class TestVariadicIntegration:
318 """Integration tests with full pipeline."""
319
320 def test_variadic_with_edges_between_iterations(self):
321 """Repetition block with edges wiring iterations together."""
322 source = """
323 @system pe=1, sm=1
324
325 #chain *items |> {
326 $( &item <| pass ),*
327 &item |> &output:L
328 }
329
330 #chain &a, &b, &c
331 """
332 graph = parse_lower_expand(source)
333
334 # 3 nodes from repetition, plus potential edges
335 nodes = list(graph.nodes.keys())
336 assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}: {nodes}"
337
338 # Should have edges (one per node in this case)
339 edges = graph.edges
340 assert len(edges) > 0, "Expected at least one edge"
341
342 def test_variadic_can_be_invoked_multiple_times(self):
343 """Same variadic macro invoked twice with different args."""
344 source = """
345 @system pe=1, sm=1
346
347 #expand *items |> {
348 $( &item <| pass ),*
349 }
350
351 #expand &a, &b
352 #expand &x, &y, &z
353 """
354 graph = parse_lower_expand(source)
355
356 # First invocation: 2 nodes
357 # Second invocation: 3 nodes
358 # Total: 5 nodes
359 nodes = list(graph.nodes.keys())
360 assert len(nodes) == 5, f"Expected 5 nodes, got {len(nodes)}: {nodes}"
361
362 # First invocation should have #expand_0_rep0, #expand_0_rep1
363 # Second invocation should have #expand_1_rep0, #expand_1_rep1, #expand_1_rep2
364 expand_0 = [n for n in nodes if "#expand_0" in n]
365 expand_1 = [n for n in nodes if "#expand_1" in n]
366 assert len(expand_0) == 2, (
367 f"Expected 2 nodes from first invocation, got {len(expand_0)}"
368 )
369 assert len(expand_1) == 3, (
370 f"Expected 3 nodes from second invocation, got {len(expand_1)}"
371 )
372
373 def test_variadic_nested_with_other_macros(self):
374 """Variadic macro combined with non-variadic macros."""
375 source = """
376 @system pe=1, sm=1
377
378 #simple |> {
379 &fixed <| pass
380 }
381
382 #expand *items |> {
383 $( &item <| pass ),*
384 }
385
386 #simple
387 #expand &a, &b
388 """
389 graph = parse_lower_expand(source)
390
391 # 1 from #simple + 2 from #expand = 3 nodes
392 nodes = list(graph.nodes.keys())
393 assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}: {nodes}"
394
395 # Should have both macro invocations in names
396 simple_nodes = [n for n in nodes if "#simple" in n]
397 expand_nodes = [n for n in nodes if "#expand" in n]
398 assert len(simple_nodes) == 1, f"Expected 1 #simple node"
399 assert len(expand_nodes) == 2, f"Expected 2 #expand nodes"
400
401
402class TestVariadicPositionalRet:
403 """Test positional @ret wiring in variadic repetition blocks."""
404
405 def test_bare_ret_maps_to_positional_outputs_by_iteration(self):
406 """@ret in iteration N maps to Nth positional output at call site."""
407 source = """
408 @system pe=1, sm=1
409
410 #fan *vals |> {
411 $( &v <| pass
412 &v |> @ret ),*
413 }
414
415 &a <| pass
416 &b <| pass
417 &c <| pass
418 #fan &a, &b, &c |> &x, &y, &z
419 &x <| pass
420 &y <| pass
421 &z <| pass
422 """
423 graph = parse_lower_expand(source)
424 assert not graph.errors, f"Unexpected errors: {graph.errors}"
425
426 # Check edges from expanded nodes to positional outputs
427 ret_edges = [e for e in graph.edges if e.dest in ("&x", "&y", "&z")]
428 dests = [e.dest for e in ret_edges]
429 assert "&x" in dests, f"Expected &x in ret edge dests: {dests}"
430 assert "&y" in dests, f"Expected &y in ret edge dests: {dests}"
431 assert "&z" in dests, f"Expected &z in ret edge dests: {dests}"
432
433 def test_positional_ret_with_two_outputs(self):
434 """Two variadic iterations map to two positional outputs."""
435 source = """
436 @system pe=1, sm=1
437
438 #pair *vals |> {
439 $( &v <| pass
440 &v |> @ret ),*
441 }
442
443 &a <| pass
444 &b <| pass
445 &left <| pass
446 &right <| pass
447 #pair &a, &b |> &left, &right
448 """
449 graph = parse_lower_expand(source)
450 assert not graph.errors, f"Unexpected errors: {graph.errors}"
451
452 ret_edges = [e for e in graph.edges if e.dest in ("&left", "&right")]
453 assert len(ret_edges) == 2, (
454 f"Expected 2 ret edges, got {len(ret_edges)}: {ret_edges}"
455 )
456
457 def test_fewer_outputs_than_iterations_errors(self):
458 """More @ret iterations than positional outputs produces errors."""
459 source = """
460 @system pe=1, sm=1
461
462 #too_many *vals |> {
463 $( &v <| pass
464 &v |> @ret ),*
465 }
466
467 &a <| pass
468 &b <| pass
469 &c <| pass
470 &only_one <| pass
471 #too_many &a, &b, &c |> &only_one
472 """
473 graph = parse_lower_expand(source)
474 macro_errors = [e for e in graph.errors if e.category == ErrorCategory.MACRO]
475 assert len(macro_errors) >= 1, (
476 f"Expected error for unmatched @ret, got: {graph.errors}"
477 )
478
479
480class TestVariadicFullPipeline:
481 """Full pipeline test: variadic macro through assemble and emulator."""
482
483 def test_variadic_macro_assembles_and_runs(self):
484 """Variadic macro with positional @ret wiring through full pipeline.
485
486 Each iteration's @ret maps to the next positional output at the
487 call site, so `#multi_const 3, 4 |> &sum:L, &sum:R` wires
488 iteration 0 → &sum:L and iteration 1 → &sum:R.
489 """
490 source = """
491 @system pe=1, sm=0
492
493 #multi_const *vals |> {
494 $( &c <| const, ${vals}
495 &c |> @ret ),*
496 }
497
498 &sum <| add
499 &out <| pass
500 #multi_const 3, 4 |> &sum:L, &sum:R
501 &sum |> &out:L
502 """
503 result = assemble(source)
504 assert result is not None
505 assert len(result.pe_configs) > 0
506
507 env = simpy.Environment()
508 sys = build_topology(env, result.pe_configs, result.sm_configs)
509 for setup in result.setup_tokens:
510 sys.inject(setup)
511 for seed in result.seed_tokens:
512 sys.inject(seed)
513 env.run(until=500)
514
515 all_values = []
516 for pe in sys.pes.values():
517 all_values.extend(t.data for t in pe.output_log if hasattr(t, "data"))
518 assert 7 in all_values, f"Expected 3+4=7 in outputs, got {all_values}"