"""Tests for variadic repetition expansion (Phase 6).

Tests verify:
- Variadic macros expand correctly (repetition block once per argument)
- ${_idx} produces correct iteration indices (0-based)
- Mixed params: non-variadic first, variadic captures remaining args
- Empty variadic invocation: no error, nothing expanded
- Single variadic invocation: one iteration
- Variadic parameter not last / multiple variadic params: error at lower pass
- Positional @ret: iteration N wires to the Nth positional output at the call site
- Full pipeline: variadic macro assembles and runs in emulator
"""

from functools import lru_cache
from pathlib import Path

import simpy
from lark import Lark

from asm import assemble
from asm.errors import ErrorCategory
from asm.expand import expand
from asm.ir import (
    IREdge,
    IRGraph,
    IRMacroCall,
    IRNode,
    IRRepetitionBlock,
    MacroDef,
    MacroParam,
    ParamRef,
    SourceLoc,
)
from asm.lower import lower
from cm_inst import ArithOp
from emu import build_topology


@lru_cache(maxsize=1)
def _get_parser() -> Lark:
    """Return the dfasm parser, built once and shared by all tests.

    The grammar lives one directory above this test file (``dfasm.lark``).
    Constructing the Earley parser re-reads and re-analyses that grammar, so
    the instance is cached; a ``Lark`` object can be reused for any number of
    ``parse`` calls.
    """
    grammar_path = Path(__file__).parent.parent / "dfasm.lark"
    return Lark(
        grammar_path.read_text(),
        parser="earley",
        propagate_positions=True,
    )


def parse_and_lower(source: str) -> IRGraph:
    """Parse source and lower to IRGraph (before expansion)."""
    parser = _get_parser()
    tree = parser.parse(source)
    return lower(tree)


def parse_lower_expand(source: str) -> IRGraph:
    """Parse, lower, and expand."""
    graph = parse_and_lower(source)
    return expand(graph)


class TestVariadicSimpleExpansion:
    """Test basic variadic repetition expansion."""

    def test_simple_variadic_expands_three_iterations(self):
        """Simple variadic: #inject *gates creates 3 pass nodes for 3 args."""
        source = """
        @system pe=1, sm=1
        #inject *gates |> {
            $(
                &g <| pass
            ),*
        }
        #inject &a, &b, &c
        """
        graph = parse_lower_expand(source)

        # Should have 3 nodes: #inject_0_rep0.&g, #inject_0_rep1.&g, #inject_0_rep2.&g
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}: {nodes}"

        # Each should have rep0, rep1, rep2 suffix to distinguish iterations
        rep0 = [n for n in nodes if "rep0" in n]
        rep1 = [n for n in nodes if "rep1" in n]
        rep2 = [n for n in nodes if "rep2" in n]
        assert len(rep0) == 1, f"Expected rep0 node in {nodes}"
        assert len(rep1) == 1, f"Expected rep1 node in {nodes}"
        assert len(rep2) == 1, f"Expected rep2 node in {nodes}"

    def test_variadic_with_multiple_statements_per_iteration(self):
        """Repetition block with multiple statements per iteration."""
        source = """
        @system pe=1, sm=1
        #loop *items |> {
            $(
                &item <| pass
                &item |> &output:L
            ),*
        }
        #loop &x, &y
        """
        graph = parse_lower_expand(source)

        # 2 args -> 2 iterations. Each iteration declares one node (&item);
        # the `|>` statement contributes an edge, not a node, so there are
        # exactly 2 nodes (not 2 * 2).
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}: {nodes}"

        # Should have rep0 and rep1 in names
        assert any("rep0" in n for n in nodes), f"Expected rep0 in {nodes}"
        assert any("rep1" in n for n in nodes), f"Expected rep1 in {nodes}"


class TestVariadicIndexVariable:
    """Test ${_idx} substitution in repetition blocks."""

    def test_idx_variable_expands_to_iteration_index(self):
        """${_idx} becomes 0, 1, 2 in successive iterations via token pasting.

        Tests that ParamRef with _idx parameter substitutes correctly during
        variadic expansion. The _idx value is set to the iteration index
        (0-based) and is available for token pasting concatenation.
        """
        # Construct macro body with a ParamRef containing _idx
        # Node name will be: &node_${_idx} -> ParamRef(param="_idx", prefix="&node_", suffix="")
        param_ref = ParamRef(param="_idx", prefix="&node_", suffix="")
        body_node = IRNode(
            name=param_ref,
            opcode=ArithOp.ADD,
            loc=SourceLoc(0, 0),
        )

        # Create repetition block with the node
        rep_body = IRGraph(
            nodes={"node_placeholder": body_node},
            edges=[],
            macro_defs=[],
            macro_calls=[],
        )
        rep_block = IRRepetitionBlock(
            body=rep_body,
            variadic_param="vals",
            loc=SourceLoc(0, 0),
        )

        # Create macro definition with variadic parameter
        macro_def = MacroDef(
            name="maker",
            params=(MacroParam(name="vals", variadic=True),),
            body=IRGraph(
                nodes={},
                edges=[],
                macro_defs=[],
                macro_calls=[],
            ),
            repetition_blocks=[rep_block],
            loc=SourceLoc(0, 0),
        )

        # Create macro call with 3 arguments
        macro_call = IRMacroCall(
            name="maker",
            positional_args=(42, 100, 200),
            named_args=(),
            loc=SourceLoc(0, 0),
        )

        # Create graph with macro definition and call
        graph = IRGraph(
            nodes={},
            edges=[],
            regions=[],
            data_defs=[],
            macro_defs=[macro_def],
            macro_calls=[macro_call],
        )

        # Expand the graph
        expanded = expand(graph)

        # After expansion, should have 3 nodes with names:
        # #maker_0_rep0.&node_0, #maker_0_rep1.&node_1, #maker_0_rep2.&node_2
        node_names = list(expanded.nodes.keys())
        assert len(node_names) == 3, (
            f"Expected 3 nodes, got {len(node_names)}: {node_names}"
        )

        # Verify that _idx was substituted correctly in node names
        # Each iteration should have node_0, node_1, node_2 respectively
        assert any("&node_0" in name for name in node_names), (
            f"Expected node with &node_0 (iteration 0), got {node_names}"
        )
        assert any("&node_1" in name for name in node_names), (
            f"Expected node with &node_1 (iteration 1), got {node_names}"
        )
        assert any("&node_2" in name for name in node_names), (
            f"Expected node with &node_2 (iteration 2), got {node_names}"
        )


class TestVariadicMixedParams:
    """Test variadic with non-variadic parameters."""

    def test_mixed_params_non_variadic_first(self):
        """Macro with dest, *sources: first param is non-variadic."""
        source = """
        @system pe=1, sm=1
        #route dest, *sources |> {
            $(
                &src <| pass
            ),*
        }
        #route &output, &in1, &in2
        """
        graph = parse_lower_expand(source)

        # &output binds to `dest`; the remaining 2 args go to *sources,
        # so there should be 2 nodes (one per source)
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}: {nodes}"

    def test_mixed_params_three_args_two_non_variadic(self):
        """Macro with a, b, *rest: args 3+ go to rest."""
        source = """
        @system pe=1, sm=1
        #process a, b, *rest |> {
            $(
                &r <| pass
            ),*
        }
        #process &x, &y, &z1, &z2, &z3
        """
        graph = parse_lower_expand(source)

        # 3 args go to *rest -> 3 iterations
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}: {nodes}"


class TestVariadicEdgeCases:
    """Test edge cases: empty variadic, single arg, etc."""

    def test_empty_variadic_no_error(self):
        """Invoke variadic macro with zero args: no error, nothing expanded."""
        source = """
        @system pe=1, sm=1
        #optional *args |> {
            $(
                &x <| pass
            ),*
        }
        #optional
        """
        graph = parse_lower_expand(source)

        # No nodes should be created
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 0, (
            f"Expected 0 nodes for empty variadic, got {len(nodes)}: {nodes}"
        )

    def test_single_variadic_one_iteration(self):
        """Invoke with one variadic arg: one iteration."""
        source = """
        @system pe=1, sm=1
        #single *args |> {
            $(
                &item <| pass
            ),*
        }
        #single &only
        """
        graph = parse_lower_expand(source)

        # One iteration -> one node with rep0
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 1
        assert any("rep0" in n for n in nodes), f"Expected rep0 in {nodes}"
        # Should NOT have rep1
        assert not any("rep1" in n for n in nodes), f"Should not have rep1 in {nodes}"


class TestVariadicGrammarValidation:
    """Test grammar validation: variadic must be last, etc."""

    def test_variadic_not_last_is_error(self):
        """Variadic parameter not last: parser/lower should reject."""
        source = """
        @system pe=1, sm=1
        #bad *args, b |> {
            $(
                &x <| pass
            ),*
        }
        """
        graph = parse_and_lower(source)

        # Lower pass should catch this error
        assert any(e.category == ErrorCategory.NAME for e in graph.errors), (
            f"Expected NAME error for variadic not last, got: {graph.errors}"
        )

    def test_multiple_variadic_is_error(self):
        """Multiple variadic parameters: parser/lower should reject."""
        source = """
        @system pe=1, sm=1
        #bad *a, *b |> {
            $(
                &x <| pass
            ),*
        }
        """
        graph = parse_and_lower(source)

        # Lower pass should catch this error
        assert any(e.category == ErrorCategory.NAME for e in graph.errors), (
            f"Expected NAME error for multiple variadic, got: {graph.errors}"
        )


class TestVariadicIntegration:
    """Integration tests with full pipeline."""

    def test_variadic_with_edges_between_iterations(self):
        """Repetition block followed by an edge statement in the macro body."""
        source = """
        @system pe=1, sm=1
        #chain *items |> {
            $(
                &item <| pass
            ),*
            &item |> &output:L
        }
        #chain &a, &b, &c
        """
        graph = parse_lower_expand(source)

        # 3 nodes from repetition; the edge statement declares no node
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}: {nodes}"

        # The `&item |> &output:L` statement sits outside `$( ... ),*`, so it
        # is not part of the repeated body; only require that an edge
        # survives expansion.
        edges = graph.edges
        assert len(edges) > 0, "Expected at least one edge"

    def test_variadic_can_be_invoked_multiple_times(self):
        """Same variadic macro invoked twice with different args."""
        source = """
        @system pe=1, sm=1
        #expand *items |> {
            $(
                &item <| pass
            ),*
        }
        #expand &a, &b
        #expand &x, &y, &z
        """
        graph = parse_lower_expand(source)

        # First invocation: 2 nodes
        # Second invocation: 3 nodes
        # Total: 5 nodes
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 5, f"Expected 5 nodes, got {len(nodes)}: {nodes}"

        # First invocation should have #expand_0_rep0, #expand_0_rep1
        # Second invocation should have #expand_1_rep0, #expand_1_rep1, #expand_1_rep2
        expand_0 = [n for n in nodes if "#expand_0" in n]
        expand_1 = [n for n in nodes if "#expand_1" in n]
        assert len(expand_0) == 2, (
            f"Expected 2 nodes from first invocation, got {len(expand_0)}"
        )
        assert len(expand_1) == 3, (
            f"Expected 3 nodes from second invocation, got {len(expand_1)}"
        )

    def test_variadic_nested_with_other_macros(self):
        """Variadic macro combined with non-variadic macros."""
        source = """
        @system pe=1, sm=1
        #simple |> {
            &fixed <| pass
        }
        #expand *items |> {
            $(
                &item <| pass
            ),*
        }
        #simple
        #expand &a, &b
        """
        graph = parse_lower_expand(source)

        # 1 from #simple + 2 from #expand = 3 nodes
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}: {nodes}"

        # Should have both macro invocations in names
        simple_nodes = [n for n in nodes if "#simple" in n]
        expand_nodes = [n for n in nodes if "#expand" in n]
        assert len(simple_nodes) == 1, f"Expected 1 #simple node, got {simple_nodes}"
        assert len(expand_nodes) == 2, f"Expected 2 #expand nodes, got {expand_nodes}"


class TestVariadicPositionalRet:
    """Test positional @ret wiring in variadic repetition blocks."""

    def test_bare_ret_maps_to_positional_outputs_by_iteration(self):
        """@ret in iteration N maps to Nth positional output at call site."""
        source = """
        @system pe=1, sm=1
        #fan *vals |> {
            $(
                &v <| pass
                &v |> @ret
            ),*
        }
        &a <| pass
        &b <| pass
        &c <| pass
        #fan &a, &b, &c |> &x, &y, &z
        &x <| pass
        &y <| pass
        &z <| pass
        """
        graph = parse_lower_expand(source)
        assert not graph.errors, f"Unexpected errors: {graph.errors}"

        # Check edges from expanded nodes to positional outputs
        ret_edges = [e for e in graph.edges if e.dest in ("&x", "&y", "&z")]
        dests = [e.dest for e in ret_edges]
        assert "&x" in dests, f"Expected &x in ret edge dests: {dests}"
        assert "&y" in dests, f"Expected &y in ret edge dests: {dests}"
        assert "&z" in dests, f"Expected &z in ret edge dests: {dests}"

    def test_positional_ret_with_two_outputs(self):
        """Two variadic iterations map to two positional outputs."""
        source = """
        @system pe=1, sm=1
        #pair *vals |> {
            $(
                &v <| pass
                &v |> @ret
            ),*
        }
        &a <| pass
        &b <| pass
        &left <| pass
        &right <| pass
        #pair &a, &b |> &left, &right
        """
        graph = parse_lower_expand(source)
        assert not graph.errors, f"Unexpected errors: {graph.errors}"

        ret_edges = [e for e in graph.edges if e.dest in ("&left", "&right")]
        assert len(ret_edges) == 2, (
            f"Expected 2 ret edges, got {len(ret_edges)}: {ret_edges}"
        )

    def test_fewer_outputs_than_iterations_errors(self):
        """More @ret iterations than positional outputs produces errors."""
        source = """
        @system pe=1, sm=1
        #too_many *vals |> {
            $(
                &v <| pass
                &v |> @ret
            ),*
        }
        &a <| pass
        &b <| pass
        &c <| pass
        &only_one <| pass
        #too_many &a, &b, &c |> &only_one
        """
        graph = parse_lower_expand(source)
        macro_errors = [e for e in graph.errors if e.category == ErrorCategory.MACRO]
        assert len(macro_errors) >= 1, (
            f"Expected error for unmatched @ret, got: {graph.errors}"
        )


class TestVariadicFullPipeline:
    """Full pipeline test: variadic macro through assemble and emulator."""

    def test_variadic_macro_assembles_and_runs(self):
        """Variadic macro with positional @ret wiring through full pipeline.

        Each iteration's @ret maps to the next positional output at the call
        site, so `#multi_const 3, 4 |> &sum:L, &sum:R` wires iteration 0 →
        &sum:L and iteration 1 → &sum:R.
        """
        source = """
        @system pe=1, sm=0
        #multi_const *vals |> {
            $(
                &c <| const, ${vals}
                &c |> @ret
            ),*
        }
        &sum <| add
        &out <| pass
        #multi_const 3, 4 |> &sum:L, &sum:R
        &sum |> &out:L
        """
        result = assemble(source)
        assert result is not None
        assert len(result.pe_configs) > 0

        env = simpy.Environment()
        # Named `system` rather than `sys` to avoid shadowing the stdlib module name.
        system = build_topology(env, result.pe_configs, result.sm_configs)
        for setup in result.setup_tokens:
            system.inject(setup)
        for seed in result.seed_tokens:
            system.inject(seed)
        env.run(until=500)

        # Collect every data value emitted by any PE; output_log entries
        # without a `data` attribute are skipped.
        all_values = [
            t.data
            for pe in system.pes.values()
            for t in pe.output_log
            if hasattr(t, "data")
        ]
        assert 7 in all_values, f"Expected 3+4=7 in outputs, got {all_values}"