"""Tests for function call wiring in the expand pass. Verifies AC4.1 through AC4.10: call syntax parsing, lowering, and wiring. """ import pytest from asm import _get_parser from asm.expand import expand from asm.ir import IRRegion, RegionKind, CallSite from asm.errors import ErrorCategory from cm_inst import Port, RoutingOp from tests.pipeline import parse_and_lower def parse_lower_expand(source: str): """Parse, lower, and expand dfasm source.""" parser = _get_parser() graph = parse_and_lower(parser, source) return expand(graph) # AC4.1: Call syntax with ctx_override edges def test_call_syntax_basic(): """AC4.1: Parse and lower function call with named argument.""" parser = _get_parser() source = """ $add |> { &a <| pass &sum <| pass } &x <| const, 5 $add a=&x |> @result """ graph = parse_and_lower(parser, source) # Verify raw_call_sites was populated assert len(graph.raw_call_sites) == 1 call_site = graph.raw_call_sites[0] assert call_site.func_name == "$add" assert len(call_site.input_args) == 1 # Input args are stored as (param_name, ref_dict) param_name, ref = call_site.input_args[0] assert param_name == "a" assert isinstance(ref, dict) and ref.get("name") == "&x" def test_call_syntax_multiple_args(): """AC4.1: Call with multiple named arguments.""" parser = _get_parser() source = """ $adder |> { &a <| pass &b <| pass } &x <| const, 3 &y <| const, 4 $adder a=&x, b=&y |> @out """ graph = parse_and_lower(parser, source) assert len(graph.raw_call_sites) == 1 call_site = graph.raw_call_sites[0] assert len(call_site.input_args) == 2 # Check args are present with correct parameter names param_names = [arg[0] for arg in call_site.input_args] assert "a" in param_names assert "b" in param_names def test_call_syntax_positional_args(): """AC4.1: Call with positional arguments.""" parser = _get_parser() source = """ $func |> { &a <| pass } &x <| const, 1 $func &x |> @result """ graph = parse_and_lower(parser, source) assert len(graph.raw_call_sites) == 1 call_site = graph.raw_call_sites[0] # Positional args are stored as (None, ref_dict) tuples assert len(call_site.input_args) == 1 assert call_site.input_args[0][0] is None # param_name is None for positional assert call_site.input_args[0][1]["name"] == "&x" # source ref is &x # AC4.2: Synthetic @ret rendezvous node creation def test_synthetic_ret_node_creation(): """AC4.2: Expand creates synthetic $func.@ret pass node.""" source = """ $add |> { &a <| pass &b <| pass &sum <| add &a |> &sum:L &b |> &sum:R &sum |> @ret } &x <| const, 3 &y <| const, 4 $add a=&x, b=&y |> @result """ graph = parse_lower_expand(source) # Check that synthetic node was created assert "$add.@ret" in graph.nodes synthetic_ret = graph.nodes["$add.@ret"] assert synthetic_ret.opcode == RoutingOp.PASS def test_trampoline_node_creation(): """AC4.2: Expand creates per-call-site trampoline pass node.""" source = """ $add |> { &a <| pass &sum <| add &sum |> @ret } &x <| const, 5 $add a=&x |> @result """ graph = parse_lower_expand(source) # Check that trampoline node was created trampoline_found = False for node_name in graph.nodes: if node_name.startswith("$add.__ret_trampoline_"): trampoline_found = True tramp = graph.nodes[node_name] assert tramp.opcode == RoutingOp.PASS break assert trampoline_found, "No trampoline node found" # AC4.3: Named returns with dual outputs def test_named_returns_multiple(): """AC4.3: Multiple named @ret_name variants create separate synthetic nodes.""" source = """ $adder |> { &a <| pass &b <| pass &sum <| add &carry <| pass &a |> &sum:L &b |> &sum:R &sum |> @ret_sum &carry |> @ret_carry } &three <| const, 3 &two <| const, 2 $adder a=&three, b=&two |> sum=@s, carry=@c """ graph = parse_lower_expand(source) # Check for both synthetic nodes assert "$adder.@ret_sum" in graph.nodes assert "$adder.@ret_carry" in graph.nodes # Both should be pass nodes assert graph.nodes["$adder.@ret_sum"].opcode == RoutingOp.PASS assert graph.nodes["$adder.@ret_carry"].opcode == RoutingOp.PASS # AC4.4: Named output wiring def test_named_output_wiring(): """AC4.4: Call with sum=@dest wires trampoline to specified destination.""" source = """ $add |> { &a <| pass &sum <| add &sum |> @ret_sum } &x <| const, 5 $add a=&x |> sum=@my_output """ graph = parse_lower_expand(source) # Verify trampoline exists trampoline_found = False for node_name in graph.nodes: if "__ret_trampoline_" in node_name: trampoline_found = True break assert trampoline_found, "No trampoline node found" # Find edge from trampoline to @my_output with ctx_override tramp_to_output_found = False for edge in graph.edges: if isinstance(edge.dest, str) and "my_output" in edge.dest: if "trampoline" in str(edge.source): tramp_to_output_found = True assert edge.ctx_override == True break assert tramp_to_output_found, "Trampoline to output edge not found" # AC4.5: free_ctx auto-insertion def test_free_ctx_auto_insertion(): """AC4.5: free_ctx node auto-inserted on every return path.""" source = """ $add |> { &a <| pass &sum <| add &sum |> @ret } &x <| const, 5 $add a=&x |> @result """ graph = parse_lower_expand(source) # Check for free_frame node free_frame_found = False for node_name in graph.nodes: if node_name.startswith("$add.__free_frame_"): free_frame_found = True free_frame = graph.nodes[node_name] assert free_frame.opcode == RoutingOp.FREE_FRAME break assert free_frame_found, "No free_frame node found" # Check that trampoline's dest_r wires to free_frame tramp_to_free_found = False for edge in graph.edges: if "trampoline" in edge.source and "free_frame" in edge.dest: tramp_to_free_found = True assert edge.source_port == Port.R # Output from trampoline R port break assert tramp_to_free_found, "Trampoline to free_frame edge not found" # AC4.6: Multiple call sites get distinct contexts and trampolines def test_multiple_call_sites(): """AC4.6: Two calls to same function get separate trampolines and ctx slots.""" source = """ $add |> { &a <| pass &sum <| add &sum |> @ret } &x <| const, 3 &y <| const, 4 $add a=&x |> @r1 $add a=&y |> @r2 """ graph = parse_lower_expand(source) # Check that we have 2 CallSite entries assert len(graph.call_sites) == 2 assert graph.call_sites[0].call_id == 0 assert graph.call_sites[1].call_id == 1 # Check that we have separate trampolines trampoline_names = [ n for n in graph.nodes.keys() if "__ret_trampoline_" in n ] assert len(trampoline_names) == 2, f"Expected 2 trampolines, got {len(trampoline_names)}" # AC4.7: Cross-PE function calls def test_cross_pe_function_calls(): """AC4.7: Call from one PE to function on another PE.""" source = """ @system pe=2, sm=1 $add |> { &a <| pass &sum <| add &sum |> @ret } &x <| const, 5 $add a=&x |> @result """ graph = parse_lower_expand(source) # Verify function region exists func_region = None for region in graph.regions: if region.kind == RegionKind.FUNCTION and region.tag == "$add": func_region = region break assert func_region is not None # Check input edges have ctx_override=True input_ctx_override_found = False for edge in graph.edges: if edge.ctx_override: input_ctx_override_found = True break assert input_ctx_override_found, "Input edge with ctx_override not found" # AC4.9: Named arg not matching any label produces NAME error def test_undefined_argument_label(): """AC4.9: Call with argument that doesn't match any label in function.""" source = """ $add |> { &a <| pass } &x <| const, 5 $add b=&x |> @result """ graph = parse_lower_expand(source) # Should have an error about argument 'b' not matching errors = graph.errors assert len(errors) > 0 assert any( err.category == ErrorCategory.CALL and "b" in err.message for err in errors ) # AC4.10: Call to undefined function produces NAME error def test_undefined_function(): """AC4.10: Call to non-existent function produces NAME error.""" source = """ &x <| const, 5 $nonexistent a=&x |> @result """ graph = parse_lower_expand(source) # Should have an error about undefined function errors = graph.errors assert len(errors) > 0 assert any( err.category == ErrorCategory.CALL and "undefined" in err.message for err in errors ) def test_call_site_metadata(): """CallSite metadata correctly populated.""" source = """ $add |> { &a <| pass &sum <| add &sum |> @ret } &x <| const, 5 $add a=&x |> @result """ graph = parse_lower_expand(source) assert len(graph.call_sites) == 1 call_site = graph.call_sites[0] assert call_site.func_name == "$add" assert call_site.call_id == 0 assert len(call_site.trampoline_nodes) > 0 assert len(call_site.free_frame_nodes) > 0 def test_input_edges_use_inherit_not_ctx_override(): """Input edges from call site to function parameters use INHERIT, not CHANGE_TAG. In the frame-based model, cross-context routing for input edges is handled by the FrameDest's act_id (which differs between caller and function activation). Only return trampolines use ctx_override/CHANGE_TAG because they decode a packed FrameDest from EXTRACT_TAG. """ source = """ $add |> { &a <| pass &sum <| add &sum |> @ret } &x <| const, 5 $add a=&x |> @result """ graph = parse_lower_expand(source) # Since &x has const=5, a trampoline is inserted: &x -> trampoline -> $add.&a # Neither edge should have ctx_override — input routing uses INHERIT mode for edge in graph.edges: if "$add.&a" in str(edge.dest): assert not edge.ctx_override, ( f"Input edge to $add.&a should NOT have ctx_override " f"(frame-based routing uses INHERIT with act_id in FrameDest)" ) # Only return trampoline edges should have ctx_override ret_ctx_override = [e for e in graph.edges if e.ctx_override] assert len(ret_ctx_override) > 0, "Return trampoline should have ctx_override" for e in ret_ctx_override: assert "__ret_trampoline" in e.source, ( f"Only return trampolines should have ctx_override, got: {e.source} -> {e.dest}" ) def test_shared_function_body(): """Function body nodes and edges are shared across multiple call sites.""" source = """ $add |> { &a <| pass &sum <| add &sum |> @ret } &x <| const, 3 &y <| const, 4 $add a=&x |> @r1 $add a=&y |> @r2 """ graph = parse_lower_expand(source) # Verify that we have 2 call sites assert len(graph.call_sites) == 2 # Function body nodes are in the function region, not top-level # Just verify the structure is correct func_region = None for region in graph.regions: if region.kind == RegionKind.FUNCTION and region.tag == "$add": func_region = region break assert func_region is not None # Body should have &a and &sum nodes assert "$add.&a" in func_region.body.nodes assert "$add.&sum" in func_region.body.nodes if __name__ == "__main__": pytest.main([__file__, "-v"])