"""Macro expansion pass for the OR1 assembler. This module implements macro invocation expansion (Phase 2). It processes IRMacroCall entries from the lowering pass, expands them by cloning and substituting macro bodies, and qualifies expanded names with scope prefixes. The expand() function receives an IRGraph from lower, processes all macro definitions and invocations, and returns a clean IRGraph with all macro artefacts removed. """ from __future__ import annotations import ast from dataclasses import replace from typing import Optional from collections.abc import Iterable from asm.errors import AssemblyError, ErrorCategory from asm.ir import ( IRGraph, IRNode, IREdge, IRRegion, RegionKind, ParamRef, ConstExpr, MacroDef, IRMacroCall, CallSiteResult, CallSite, IRRepetitionBlock, SourceLoc, PlacementRef, PortRef, ActSlotRef, ActSlotRange, ) from asm.opcodes import MNEMONIC_TO_OP from cm_inst import Port, RoutingOp MAX_EXPANSION_DEPTH = 32 def _levenshtein(a: str, b: str) -> int: """Compute Levenshtein (edit) distance between two strings. Note: This is duplicated from asm/resolve.py. If a third copy appears, extract to a shared utility module. Args: a: First string b: Second string Returns: Minimum edit distance (number of single-character edits) """ if len(a) < len(b): return _levenshtein(b, a) if not b: return len(a) prev = list(range(len(b) + 1)) for i, ca in enumerate(a): curr = [i + 1] for j, cb in enumerate(b): curr.append(min( prev[j + 1] + 1, # deletion curr[j] + 1, # insertion prev[j] + (ca != cb), # substitution )) prev = curr return prev[-1] def _suggest_names(unresolved: str, available_names: Iterable[str]) -> list[str]: """Generate "did you mean" suggestions via Levenshtein distance. Compares unresolved name against all available names, returning suggestions with distance <= 3, or the closest match if all distances are > 3. Args: unresolved: The unresolved name available_names: Iterable of available macro names Returns: List of suggestion strings (may be empty) """ if not available_names: return [] # Compute distances candidates = [] for name in available_names: dist = _levenshtein(unresolved, name) candidates.append((dist, name)) # Sort by distance candidates.sort(key=lambda x: x[0]) # Return suggestions with distance <= 3, or best if all > 3 suggestions = [] best_distance = candidates[0][0] for dist, name in candidates: if dist <= 3 or dist == best_distance: suggestions.append(f"Did you mean '#{name}'?") else: break return suggestions def _substitute_param( value: object, subst_map: dict[str, object], ) -> object: """Resolve a ParamRef or name against the substitution map. Supports token pasting: ParamRef with prefix/suffix concatenates the parameter value with the prefix and suffix to form a new name. For const fields, returns the actual int value. For names, returns the ref name string (possibly qualified). Args: value: The value to substitute (could be ParamRef, int, str, etc.) subst_map: Map of formal param names to actual argument values Returns: The substituted value, or unchanged if not a ParamRef/param name. If ParamRef has prefix/suffix, returns concatenated string. """ if isinstance(value, ParamRef): # Look up the parameter in the substitution map actual = subst_map.get(value.param) if actual is not None: # Extract name from dict refs (e.g., {"name": "&x"} -> "&x") if isinstance(actual, dict) and "name" in actual: actual = actual["name"] # Handle token pasting with prefix/suffix if value.prefix or value.suffix: # Convert actual value to string if it's an int actual_str = str(actual) if isinstance(actual, int) else actual # Concatenate: prefix + value + suffix return value.prefix + actual_str + value.suffix else: # No prefix/suffix: return the actual value as-is return actual # Parameter not found - return unchanged (should not happen with proper validation) return value # For string names, check if they match a formal parameter if isinstance(value, str): # Don't substitute sigil-prefixed names (they may be qualified later) if value and value[0] in "&@$#": return value # Check if this name is a formal parameter if value in subst_map: return subst_map[value] return value def _eval_node(node, bindings: dict[str, int]) -> int: """Evaluate a single AST node in a constant expression. Args: node: An ast node (Constant, Name, BinOp, or UnaryOp) bindings: Map of parameter names to integer values Returns: The evaluated integer result Raises: ValueError: If node type is unsupported or value is non-numeric """ if isinstance(node, ast.Constant) and isinstance(node.value, int): return node.value elif isinstance(node, ast.Name): if node.id not in bindings: raise ValueError(f"Undefined parameter: {node.id}") val = bindings[node.id] if not isinstance(val, int): raise ValueError(f"Non-numeric value in arithmetic context") return val elif isinstance(node, ast.BinOp): left = _eval_node(node.left, bindings) right = _eval_node(node.right, bindings) if isinstance(node.op, ast.Add): return left + right elif isinstance(node.op, ast.Sub): return left - right elif isinstance(node.op, ast.Mult): return left * right elif isinstance(node.op, ast.FloorDiv): if right == 0: raise ValueError("division by zero") return left // right else: raise ValueError(f"Unsupported operator: {type(node.op).__name__}") elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): return -_eval_node(node.operand, bindings) else: raise ValueError(f"Unsupported expression node: {type(node).__name__}") def _eval_const_expr(expr: str, bindings: dict[str, int]) -> int: """Evaluate a simple arithmetic expression with parameter bindings. Supports: integer literals, +, -, *, // (integer division), parentheses. No eval() call — safe AST walking only. Args: expr: Expression string, e.g. "base + 1" bindings: Map of parameter names to integer values Returns: The evaluated integer result Raises: ValueError: If expression is invalid or contains non-numeric values """ tree = ast.parse(expr, mode='eval') return _eval_node(tree.body, bindings) def _qualify_expanded_name( name: str, macro_scope: str, parent_scope: str = "", func_scope: Optional[str] = None, ) -> str: """Apply scope prefix to an expanded name. Takes a name and applies macro and optional function scopes. Names starting with & are qualified; other sigils pass through. Args: name: The original name from the macro body macro_scope: The macro scope (e.g., "#loop_counted_0") parent_scope: Optional parent macro scope (e.g., "#outer_0") func_scope: Optional function scope (e.g., "$main") Returns: The qualified name """ if not name: return name # Check if it's a label (starts with &) if name.startswith("&"): # Build full scope: [func_scope.][parent_scope.]macro_scope.name if func_scope and parent_scope: # Triple-scoped: $func.#parent_N.#macro_M.&label return f"{func_scope}.{parent_scope}.{macro_scope}.{name}" elif parent_scope: # Double-scoped: #parent_N.#macro_M.&label return f"{parent_scope}.{macro_scope}.{name}" elif func_scope: # Double-scoped: $func.#macro_N.&label return f"{func_scope}.{macro_scope}.{name}" else: # Single-scoped: #macro_N.&label return f"{macro_scope}.{name}" # Other sigils (@, $, #) pass through unqualified return name def _clone_and_substitute_node( node: IRNode, macro_scope: str, subst_map: dict[str, object], func_scope: Optional[str] = None, parent_scope: str = "", ) -> tuple[IRNode, list[AssemblyError]]: """Deep-clone a node and substitute parameters. Args: node: The template node from the macro body macro_scope: The macro scope for qualification subst_map: Map of formal params to actual arguments func_scope: Optional function scope parent_scope: Optional parent macro scope Returns: Tuple of (new IRNode with substitutions applied and name qualified, errors list) """ errors = [] # Substitute the const field new_const = _substitute_param(node.const, subst_map) # If const is a ConstExpr, evaluate it if isinstance(new_const, ConstExpr): try: # Build bindings dict from subst_map, converting all to int bindings = {} for param_name in new_const.params: if param_name in subst_map: val = subst_map[param_name] if not isinstance(val, int): errors.append(AssemblyError( loc=new_const.loc, category=ErrorCategory.VALUE, message=f"Non-numeric value '{val}' in arithmetic context", )) # Return node with ConstExpr unchanged (will be caught later) substituted_name = _substitute_param(node.name, subst_map) if not isinstance(substituted_name, str): substituted_name = str(substituted_name) new_name = _qualify_expanded_name(substituted_name, macro_scope, parent_scope, func_scope) return replace(node, name=new_name, const=new_const), errors bindings[param_name] = val # Evaluate the expression evaluated = _eval_const_expr(new_const.expression, bindings) new_const = evaluated except ValueError as e: errors.append(AssemblyError( loc=new_const.loc, category=ErrorCategory.VALUE, message=str(e), )) # Resolve opcode if it's a ParamRef new_opcode = node.opcode if isinstance(new_opcode, ParamRef): resolved = _substitute_param(new_opcode, subst_map) if isinstance(resolved, str): if resolved in MNEMONIC_TO_OP: new_opcode = MNEMONIC_TO_OP[resolved] else: errors.append(AssemblyError( loc=node.loc, category=ErrorCategory.MACRO, message=f"'{resolved}' is not a valid opcode mnemonic", )) new_opcode = node.opcode else: errors.append(AssemblyError( loc=node.loc, category=ErrorCategory.MACRO, message=f"opcode parameter must resolve to an opcode mnemonic, got {type(resolved).__name__}", )) new_opcode = node.opcode # Resolve placement if it's a PlacementRef new_pe = node.pe if isinstance(new_pe, PlacementRef): resolved = _substitute_param(new_pe.param, subst_map) if isinstance(resolved, str) and resolved.startswith("pe"): try: new_pe = int(resolved[2:]) except ValueError: errors.append(AssemblyError( loc=node.loc, category=ErrorCategory.MACRO, message=f"placement parameter must resolve to 'peN', got '{resolved}'", )) new_pe = None elif isinstance(resolved, int): new_pe = resolved else: errors.append(AssemblyError( loc=node.loc, category=ErrorCategory.MACRO, message=f"placement parameter must resolve to 'peN', got {type(resolved).__name__}", )) new_pe = None # Resolve act_slot if it's a ActSlotRef new_act_slot = node.act_slot if isinstance(new_act_slot, ActSlotRef): resolved = _substitute_param(new_act_slot.param, subst_map) if isinstance(resolved, int): new_act_slot = ActSlotRange(start=resolved, end=resolved) else: errors.append(AssemblyError( loc=node.loc, category=ErrorCategory.MACRO, message=f"act_slot parameter must resolve to an integer, got {type(resolved).__name__}", )) new_act_slot = None # Substitute the node name (may be a ParamRef with token pasting) substituted_name = _substitute_param(node.name, subst_map) # Ensure name is a string before qualification if not isinstance(substituted_name, str): substituted_name = str(substituted_name) # Qualify the node name new_name = _qualify_expanded_name(substituted_name, macro_scope, parent_scope, func_scope) return replace(node, name=new_name, const=new_const, opcode=new_opcode, pe=new_pe, act_slot=new_act_slot), errors def _clone_and_substitute_edge( edge: IREdge, macro_scope: str, subst_map: dict[str, object], func_scope: Optional[str] = None, parent_scope: str = "", ) -> tuple[IREdge, list[AssemblyError]]: """Deep-clone an edge and substitute/qualify names. Args: edge: The template edge from the macro body macro_scope: The macro scope for qualification subst_map: Map of formal params to actual arguments func_scope: Optional function scope parent_scope: Optional parent macro scope Returns: Tuple of (new IREdge with names qualified, list of errors) """ errors: list[AssemblyError] = [] # Substitute source and dest names. # Track whether each was a ParamRef — substituted refs are external # and must NOT be qualified with the macro scope. source_was_param = isinstance(edge.source, ParamRef) source = _substitute_param(edge.source, subst_map) if not isinstance(source, str): source = str(source) dest_was_param = isinstance(edge.dest, ParamRef) dest = _substitute_param(edge.dest, subst_map) if not isinstance(dest, str): dest = str(dest) # Only qualify names that came from the macro body template directly. # Substituted parameter refs point to external names and stay unqualified. if not source_was_param: source = _qualify_expanded_name(source, macro_scope, parent_scope, func_scope) if not dest_was_param: dest = _qualify_expanded_name(dest, macro_scope, parent_scope, func_scope) # Resolve PortRef on dest port new_port = edge.port if isinstance(new_port, PortRef): resolved = _substitute_param(new_port.param, subst_map) if isinstance(resolved, str): if resolved == "L": new_port = Port.L elif resolved == "R": new_port = Port.R else: errors.append(AssemblyError( loc=edge.loc, category=ErrorCategory.MACRO, message=f"port parameter must resolve to 'L' or 'R', got '{resolved}'", )) new_port = Port.L elif isinstance(resolved, Port): new_port = resolved else: errors.append(AssemblyError( loc=edge.loc, category=ErrorCategory.MACRO, message=f"port parameter must resolve to 'L' or 'R', got '{resolved}'", )) new_port = Port.L # Resolve PortRef on source port new_source_port = edge.source_port if isinstance(new_source_port, PortRef): resolved = _substitute_param(new_source_port.param, subst_map) if isinstance(resolved, str): if resolved == "L": new_source_port = Port.L elif resolved == "R": new_source_port = Port.R else: errors.append(AssemblyError( loc=edge.loc, category=ErrorCategory.MACRO, message=f"source port parameter must resolve to 'L' or 'R', got '{resolved}'", )) new_source_port = None elif isinstance(resolved, Port): new_source_port = resolved else: errors.append(AssemblyError( loc=edge.loc, category=ErrorCategory.MACRO, message=f"source port parameter must resolve to 'L' or 'R', got '{resolved}'", )) new_source_port = None return replace(edge, source=source, dest=dest, port=new_port, source_port=new_source_port), errors def _add_expansion_context( error: AssemblyError, call: IRMacroCall, builtin_line_offset: int = 0, ) -> AssemblyError: """Add expansion context to an error. Appends "expanded from #macro_name at line N, column C" to the context_lines to trace the error back to the macro invocation site. Args: error: The error to enhance call: The IRMacroCall being expanded builtin_line_offset: Lines to subtract for display (built-in macro prefix) Returns: New AssemblyError with expansion context added to context_lines """ display_line = call.loc.line if builtin_line_offset > 0 and display_line > builtin_line_offset: display_line -= builtin_line_offset expansion_context = ( f"expanded from #{call.name} at line {display_line}, " f"column {call.loc.column}" ) return replace( error, context_lines=list(error.context_lines) + [expansion_context], ) def _expand_repetition_block( rep_block: IRRepetitionBlock, variadic_args: list[object], macro_scope: str, subst_map: dict[str, object], func_scope: Optional[str] = None, parent_scope: str = "", ) -> tuple[dict[str, IRNode], list[IREdge], list[AssemblyError]]: """Expand a repetition block once per variadic argument. For each iteration, clones the body, substitutes the variadic param and ${_idx} (iteration index), and qualifies names. Args: rep_block: The IRRepetitionBlock to expand variadic_args: List of actual arguments for the variadic parameter macro_scope: The macro scope for qualification subst_map: Base substitution map (will be extended with variadic param and _idx) func_scope: Optional function scope parent_scope: Optional parent macro scope Returns: Tuple of (expanded_nodes dict, expanded_edges list, errors list) """ errors = [] expanded_nodes: dict[str, IRNode] = {} expanded_edges: list[IREdge] = [] # Iterate over variadic arguments for idx, arg_value in enumerate(variadic_args): # Create iteration-specific substitution map iter_subst_map = dict(subst_map) iter_subst_map[rep_block.variadic_param] = arg_value iter_subst_map["_idx"] = idx # Make iteration index available as parameter # Clone and substitute nodes from the repetition body for node_name, node in rep_block.body.nodes.items(): # Create a unique name for this iteration # Qualify the node name with macro scope and iteration suffix qualified_node, node_errors = _clone_and_substitute_node( node, f"{macro_scope}_rep{idx}", iter_subst_map, func_scope, parent_scope, ) errors.extend(node_errors) expanded_nodes[qualified_node.name] = qualified_node # Clone and substitute edges from the repetition body for edge in rep_block.body.edges: qualified_edge, edge_errors = _clone_and_substitute_edge( edge, f"{macro_scope}_rep{idx}", iter_subst_map, func_scope, parent_scope, ) errors.extend(edge_errors) expanded_edges.append(qualified_edge) return expanded_nodes, expanded_edges, errors def _expand_call( call: IRMacroCall, macro_table: dict[str, MacroDef], expansion_counter: list[int], func_scope: Optional[str] = None, parent_scope: str = "", depth: int = 0, builtin_line_offset: int = 0, ) -> tuple[dict[str, IRNode], list[IREdge], list[AssemblyError]]: """Process a single macro call. Looks up the macro, validates arity, builds substitution map, clones the body, and performs parameter substitution and name qualification. Args: call: The IRMacroCall to expand macro_table: Map of macro names to MacroDef objects expansion_counter: [int] list for mutable counter (incremented per expansion) func_scope: Optional function scope the call is in parent_scope: Optional parent macro scope (for nested macros) depth: Recursion depth (error if exceeds 32) builtin_line_offset: Lines to subtract for display in error context Returns: Tuple of (expanded_nodes dict, expanded_edges list, errors list) """ errors = [] # Check depth limit if depth > MAX_EXPANSION_DEPTH: error = AssemblyError( loc=call.loc, category=ErrorCategory.MACRO, message=f"macro expansion depth exceeds {MAX_EXPANSION_DEPTH} (likely infinite recursion in macro '{call.name}')", ) return {}, [], [error] # Look up macro definition if call.name not in macro_table: suggestions = _suggest_names(call.name, macro_table.keys()) error = AssemblyError( loc=call.loc, category=ErrorCategory.MACRO, message=f"undefined macro '#{call.name}'", suggestions=suggestions, ) return {}, [], [error] macro_def = macro_table[call.name] # Validate arity and separate variadic arguments total_args = len(call.positional_args) + len(call.named_args) # Count required parameters (non-variadic) required_params = [p for p in macro_def.params if not p.variadic] variadic_param = next((p for p in macro_def.params if p.variadic), None) if variadic_param: # With variadic: need at least as many args as required params if total_args < len(required_params): error = AssemblyError( loc=call.loc, category=ErrorCategory.MACRO, message=f"macro '#{call.name}' expects at least {len(required_params)} argument(s), got {total_args}", ) return {}, [], [error] else: # Without variadic: exact match required expected_count = len(macro_def.params) if total_args != expected_count: error = AssemblyError( loc=call.loc, category=ErrorCategory.MACRO, message=f"macro '#{call.name}' expects {expected_count} argument(s), got {total_args}", ) return {}, [], [error] # Build substitution map subst_map: dict[str, object] = {} variadic_args: list[object] = [] # Add positional arguments for i, actual_value in enumerate(call.positional_args): if i < len(required_params): # Regular parameter param_name = required_params[i].name subst_map[param_name] = actual_value elif variadic_param: # Extra arguments go to variadic parameter variadic_args.append(actual_value) # Add named arguments (to required params only; named variadic args not supported) for param_name, actual_value in call.named_args: subst_map[param_name] = actual_value # Generate unique macro scope expansion_id = expansion_counter[0] expansion_counter[0] += 1 macro_scope = f"#{call.name}_{expansion_id}" # Recursively expand and qualify the macro body, including nested calls def _expand_body_recursive( body: IRGraph, depth: int, ) -> tuple[dict[str, IRNode], list[IREdge], list[AssemblyError]]: """Recursively expand all macro calls in a body graph and its regions.""" body_errors: list[AssemblyError] = [] body_nodes: dict[str, IRNode] = {} body_edges: list[IREdge] = [] # Qualify and add the body's own nodes for node_name, node in body.nodes.items(): qualified_node, node_errors = _clone_and_substitute_node(node, macro_scope, subst_map, func_scope, parent_scope) # Add expansion context to node-level errors (const expression evaluation, etc.) for error in node_errors: body_errors.append(_add_expansion_context(error, call, builtin_line_offset)) body_nodes[qualified_node.name] = qualified_node # Qualify the body's own edges for edge in body.edges: qualified_edge, edge_errors = _clone_and_substitute_edge(edge, macro_scope, subst_map, func_scope, parent_scope) body_errors.extend(edge_errors) body_edges.append(qualified_edge) # Expand macro calls at this body level # Nested calls have current macro_scope as their parent_scope for nested_call in body.macro_calls: nested_expanded_nodes, nested_expanded_edges, nested_errors = _expand_call( nested_call, macro_table, expansion_counter, func_scope, macro_scope, # Current macro scope becomes parent for nested depth + 1, builtin_line_offset, ) # Add expansion context to nested errors (trace them back to the nested call) for error in nested_errors: body_errors.append(_add_expansion_context(error, nested_call, builtin_line_offset)) body_nodes.update(nested_expanded_nodes) # Filter out leaked @ret edges from failed inner expansions to prevent # spurious "defines output(s) @ret" errors at the outer macro level for nested_edge in nested_expanded_edges: if isinstance(nested_edge.dest, str) and nested_edge.dest.startswith("@ret"): continue body_edges.append(nested_edge) # Expand repetition blocks (Phase 6 variadic macros) if variadic_param: for rep_block in macro_def.repetition_blocks: # Only expand blocks for the current variadic parameter if rep_block.variadic_param == variadic_param.name: rep_nodes, rep_edges, rep_errors = _expand_repetition_block( rep_block, variadic_args, macro_scope, subst_map, func_scope, parent_scope, ) body_errors.extend(rep_errors) body_nodes.update(rep_nodes) body_edges.extend(rep_edges) # Recursively expand regions in the body for region in body.regions: region_func_scope = region.tag if region.kind == RegionKind.FUNCTION else func_scope region_nodes, region_edges, region_errors = _expand_body_recursive( region.body, depth + 1, ) body_errors.extend(region_errors) body_nodes.update(region_nodes) body_edges.extend(region_edges) return body_nodes, body_edges, body_errors expanded_nodes, expanded_edges, nested_errors = _expand_body_recursive( macro_def.body, depth, ) errors.extend(nested_errors) # Rewrite @ret edges: either substitute with output destinations, or # report an error if the macro body uses @ret but the call site doesn't # provide output wiring. has_ret_edges = any( isinstance(e.dest, str) and e.dest.startswith("@ret") for e in expanded_edges ) if has_ret_edges and not call.output_dests: # Collect the @ret markers for the error message ret_markers = sorted({ e.dest for e in expanded_edges if isinstance(e.dest, str) and e.dest.startswith("@ret") }) errors.append(AssemblyError( loc=call.loc, category=ErrorCategory.MACRO, message=f"macro '#{call.name}' defines output(s) {', '.join(ret_markers)} but call site has no '|>' output wiring", )) if call.output_dests: rewritten_edges = [] all_outputs = list(call.output_dests) # Build ordered list of positional outputs for bare @ret resolution. # Each bare @ret consumes the next positional output in order, # enabling variadic macros to wire each iteration to a separate dest: # #macro *vals |> { $( &c <| const, ${vals}; &c |> @ret ),* } # #macro 3, 4 |> &x, &y ← iteration 0 → &x, iteration 1 → &y positional_outputs: list[str] = [] for output in all_outputs: if isinstance(output, dict): if "name" in output and "ref" in output: continue # Named output — not positional name = output.get("name", None) if name is not None: positional_outputs.append(name) else: positional_outputs.append(str(output)) positional_idx = 0 for edge in expanded_edges: if not (isinstance(edge.dest, str) and edge.dest.startswith("@ret")): rewritten_edges.append(edge) continue # This edge targets an @ret marker — resolve it ret_dest = edge.dest # e.g. "@ret" or "@ret_body" dest_name = None # Try named match: @ret_body -> output with name="body" if ret_dest.startswith("@ret_"): expected_suffix = ret_dest[5:] # "body" from "@ret_body" for output in all_outputs: if isinstance(output, dict) and "name" in output and "ref" in output: if output["name"] == expected_suffix: ref = output["ref"] dest_name = ref.get("name", ref) if isinstance(ref, dict) else str(ref) break # Bare @ret -> next positional output (advances counter) if dest_name is None and ret_dest == "@ret": if positional_idx < len(positional_outputs): dest_name = positional_outputs[positional_idx] positional_idx += 1 if dest_name is None: errors.append(AssemblyError( loc=call.loc, category=ErrorCategory.MACRO, message=f"macro '#{call.name}' has output marker '{ret_dest}' but no matching output destination in call site", )) rewritten_edges.append(edge) continue # Replace the @ret destination with the concrete node reference rewritten_edges.append(replace(edge, dest=dest_name)) expanded_edges = rewritten_edges # Propagate errors from macro body template for body_error in macro_def.body.errors: # Adjust source location to point to the call site adjusted_error = replace( body_error, loc=call.loc, suggestions=list(body_error.suggestions) + [ f"defined in macro #{macro_def.name} at line {macro_def.loc.line}" ], ) errors.append(adjusted_error) return expanded_nodes, expanded_edges, errors def _expand_graph_recursive( graph: IRGraph, macro_table: dict[str, MacroDef], expansion_counter: list[int], func_scope: Optional[str] = None, builtin_line_offset: int = 0, ) -> tuple[IRGraph, list[AssemblyError]]: """Recursively expand macros in a graph and its regions. Args: graph: The IRGraph to expand macro_table: Map of macro names to MacroDef objects expansion_counter: [int] list for mutable counter func_scope: Optional function scope for name qualification builtin_line_offset: Lines to subtract for display in error context Returns: Tuple of (new_graph, all_errors) """ new_errors: list[AssemblyError] = [] expanded_nodes: dict[str, IRNode] = dict(graph.nodes) expanded_edges: list[IREdge] = list(graph.edges) # Collect all macro calls from this graph level # Note: The lower pass doesn't populate macro_calls in regions, # so we also need to collect from macro_calls in the graph all_calls_at_level = list(graph.macro_calls) # Expand all macro calls at this level for call in all_calls_at_level: # Determine the function scope for nested calls call_func_scope = func_scope call_expanded_nodes, call_expanded_edges, call_errors = _expand_call( call, macro_table, expansion_counter, call_func_scope, "", # No parent scope at top level builtin_line_offset=builtin_line_offset, ) for error in call_errors: new_errors.append(_add_expansion_context(error, call, builtin_line_offset)) expanded_nodes.update(call_expanded_nodes) expanded_edges.extend(call_expanded_edges) # Recursively expand regions (function bodies, etc.) new_regions: list[IRRegion] = [] for region in graph.regions: # For function regions, pass the region tag as the func_scope for name qualification region_func_scope = region.tag if region.kind == RegionKind.FUNCTION else func_scope new_body, region_errors = _expand_graph_recursive( region.body, macro_table, expansion_counter, region_func_scope, builtin_line_offset, ) new_errors.extend(region_errors) new_region = replace(region, body=new_body) new_regions.append(new_region) # Create new graph with expanded content and no macro artefacts new_graph = replace( graph, nodes=expanded_nodes, edges=expanded_edges, regions=new_regions, macro_defs=[], # Remove all macro defs macro_calls=[], # Remove all macro calls ) return new_graph, new_errors def _wire_call_site( call_site: CallSiteResult, graph: IRGraph, call_id: int, wired_nodes: dict[str, IRNode], wired_edges: list[IREdge], processed_ret_nodes: set, function_ret_destinations: dict[str, set], ) -> tuple[CallSite, list[AssemblyError]]: """Process a single function call site and wire it into the graph. This function: 1. Finds the function definition in the graph's regions 2. Matches input arguments to function labels 3. Synthesises @ret rendezvous nodes (shared across call sites) 4. Creates per-call-site trampolines and free_frame nodes 5. Wires everything together with ctx_override edges Args: call_site: The CallSiteResult from the lower pass graph: The IRGraph containing regions (functions) call_id: Unique ID for this call site wired_nodes: Dictionary to accumulate generated nodes wired_edges: List to accumulate generated edges processed_ret_nodes: Cache of already-synthesised @ret nodes (func_name.@ret -> node_name) Returns: Tuple of (CallSite metadata, errors list) """ errors = [] # Find the function definition in the graph's regions func_region = None for region in graph.regions: if region.kind == RegionKind.FUNCTION and region.tag == call_site.func_name: func_region = region break if func_region is None: error = AssemblyError( loc=call_site.loc, category=ErrorCategory.CALL, message=f"undefined function '{call_site.func_name}'", ) return CallSite( func_name=call_site.func_name, call_id=call_id, ), [error] # Collect all nodes in the function body (including nested regions) func_all_nodes = {} func_all_edges = [] def _collect_from_region(r: IRGraph): func_all_nodes.update(r.nodes) func_all_edges.extend(r.edges) for sub_region in r.regions: _collect_from_region(sub_region.body) _collect_from_region(func_region.body) input_edge_names = [] trampoline_nodes = [] free_frame_nodes = [] # Process input arguments: match each to a label in the function for param_name, source_ref in call_site.input_args: # source_ref may be a dict with {"name": "..."} or a simple string if isinstance(source_ref, dict): source_name = source_ref.get("name", str(source_ref)) else: source_name = str(source_ref) # Look for a label ¶m_name in the function target_label = f"{call_site.func_name}.&{param_name}" if target_label not in func_all_nodes: error = AssemblyError( loc=call_site.loc, category=ErrorCategory.CALL, message=f"argument '{param_name}' does not match any label in '{call_site.func_name}'", ) errors.append(error) continue # Check if source node has a const (AC5.3: const+CTX_OVRD conflict) # If so, insert a pass trampoline between source and target source_node = graph.nodes.get(source_name) if source_node is not None and source_node.const is not None: # Insert a pass trampoline to separate const from ctx_override tramp_name = f"{call_site.func_name}.__input_tramp_{call_id}_{param_name}" tramp_node = IRNode( name=tramp_name, opcode=RoutingOp.PASS, loc=call_site.loc, ) wired_nodes[tramp_name] = tramp_node # Wire: source -> trampoline (no ctx_override, inherits ctx) source_to_tramp = IREdge( source=source_name, dest=tramp_name, port=Port.L, loc=call_site.loc, ) wired_edges.append(source_to_tramp) # Wire: trampoline -> target (no ctx_override — INHERIT mode reads # the destination FrameDest from the frame, which already encodes # the function's act_id. CHANGE_TAG is wrong here because the left # operand is raw data, not a packed FrameDest.) tramp_to_target = IREdge( source=tramp_name, dest=target_label, port=Port.L, loc=call_site.loc, ) wired_edges.append(tramp_to_target) else: # No conflict — direct edge (no ctx_override — the destination # node's act_id in the FrameDest handles cross-context routing) input_edge = IREdge( source=source_name, dest=target_label, port=Port.L, loc=call_site.loc, ) wired_edges.append(input_edge) edge_name = f"{call_site.func_name}.__input_{call_id}_{param_name}" input_edge_names.append(edge_name) # Get @ret destinations for this function (pre-computed during expand setup) ret_destinations = set() if function_ret_destinations and call_site.func_name in function_ret_destinations: ret_destinations = function_ret_destinations[call_site.func_name] # For each @ret variant, create a per-call-site trampoline # (synthetic nodes are already created during expand pass setup) for ret_dest in ret_destinations: # Determine the synthetic node name: $func.@ret or $func.@ret_name synthetic_node_name = f"{call_site.func_name}.{ret_dest}" # Synthetic node should already exist from expand setup if synthetic_node_name not in processed_ret_nodes: # This shouldn't happen, but create it just in case synthetic_pass_node = IRNode( name=synthetic_node_name, opcode=RoutingOp.PASS, loc=call_site.loc, ) wired_nodes[synthetic_node_name] = synthetic_pass_node processed_ret_nodes.add(synthetic_node_name) # Create a per-call-site trampoline pass node trampoline_name = f"{call_site.func_name}.__ret_trampoline_{call_id}_{ret_dest[1:]}" trampoline_node = IRNode( name=trampoline_name, opcode=RoutingOp.PASS, dest_l=None, # Will be wired below dest_r=None, # Will be wired below loc=call_site.loc, ) wired_nodes[trampoline_name] = trampoline_node trampoline_nodes.append(trampoline_name) # Create edge from synthetic @ret node to trampoline ret_to_tramp_edge = IREdge( source=synthetic_node_name, dest=trampoline_name, port=Port.L, loc=call_site.loc, ) wired_edges.append(ret_to_tramp_edge) # Find the corresponding output destination from call_site.output_dests # output_dests is a flat tuple of dicts: each dict is either a named_output # {"name": "...", "ref": {...}} or positional_output {...} output_dest = None dest_name = f"@__unmatched_{ret_dest}" # Iterate directly over flattened output_dests all_outputs = list(call_site.output_dests) if call_site.output_dests else [] # Try to find named output matching ret_dest for output in all_outputs: if isinstance(output, dict): output_name = output.get("name") # ret_dest is "@ret_name", so we need to match "name" part (without @ prefix and "ret_" prefix) # Possible forms: @ret (bare), @ret_sum, @ret_carry, etc. expected_suffix = ret_dest[5:] if ret_dest.startswith("@ret_") else "" # "sum" from "@ret_sum" if output_name and output_name == expected_suffix: # Found named output output_ref = output.get("ref") if isinstance(output_ref, dict): dest_name = output_ref.get("name", "@__unmatched") else: dest_name = str(output_ref) break # If not found by name and this is @ret (bare), try positional mapping if dest_name.startswith("@__unmatched") and ret_dest == "@ret": for output in all_outputs: # Named outputs have "name" key that matches a @ret_name label # Positional outputs don't have this structure has_label_name = isinstance(output, dict) and "name" in output and "ref" in output if not has_label_name: # This is a positional output if isinstance(output, dict): # Positional output is stored as a ref dict with just "name" key dest_name = output.get("name", "@__unmatched") else: # Non-dict positional output dest_name = str(output) break # Wire trampoline dest_l to the caller's output destination with ctx_override=True tramp_to_output_edge = IREdge( source=trampoline_name, dest=dest_name, port=Port.L, source_port=Port.L, # Output from trampoline's L port ctx_override=True, loc=call_site.loc, ) wired_edges.append(tramp_to_output_edge) # Create a free_frame node (one per call site, not per @ret variant) # Wire it to trampoline's dest_r free_frame_name = f"{call_site.func_name}.__free_frame_{call_id}" if free_frame_name not in wired_nodes: # Only create once per call site free_frame_node = IRNode( name=free_frame_name, opcode=RoutingOp.FREE_FRAME, loc=call_site.loc, ) wired_nodes[free_frame_name] = free_frame_node free_frame_nodes.append(free_frame_name) # Wire trampoline dest_r to free_frame tramp_to_free_edge = IREdge( source=trampoline_name, dest=free_frame_name, port=Port.L, source_port=Port.R, # Output from trampoline's R port loc=call_site.loc, ) wired_edges.append(tramp_to_free_edge) # Create CallSite metadata call_site_metadata = CallSite( func_name=call_site.func_name, call_id=call_id, input_edges=tuple(input_edge_names), trampoline_nodes=tuple(trampoline_nodes), free_frame_nodes=tuple(free_frame_nodes), loc=call_site.loc, ) return call_site_metadata, errors def expand(graph: IRGraph) -> IRGraph: """Expand all macro calls in an IRGraph. The expand pass processes all MacroDef and IRMacroCall entries from lowering, substitutes parameters, qualifies names, and recursively expands nested macros. The output graph contains no macro definitions or invocation artefacts. Steps: 1. Collect all MacroDef entries into a macro_table 2. Recursively expand all IRMacroCall entries (depth limit 32) 3. For each call: validate arity, build substitution map, clone body, substitute params, qualify names, splice into output 4. Strip all MacroDef and IRMacroCall entries from output 5. Return new IRGraph with only concrete nodes/edges Args: graph: The IRGraph from the lower pass Returns: New IRGraph with all macros expanded and no macro artefacts """ # Collect all macro definitions into a table macro_table: dict[str, MacroDef] = {} for macro_def in graph.macro_defs: macro_table[macro_def.name] = macro_def # Initialize expansion counter expansion_counter: list[int] = [0] # Recursively expand the graph starting at top level expanded_graph, expansion_errors = _expand_graph_recursive( graph, macro_table, expansion_counter, builtin_line_offset=graph.builtin_line_offset, ) # Scan function regions to find all @ret destinations, create synthetic nodes, and track them synthetic_ret_nodes = {} # Map of synthetic_node_name -> IRNode function_ret_destinations = {} # Map of func_name -> set of @ret destinations new_regions = [] for region in expanded_graph.regions: if region.kind == RegionKind.FUNCTION: # Find all @ret destinations in function body edges ret_destinations = set() for edge in region.body.edges: if isinstance(edge.dest, str) and edge.dest.startswith("@ret"): ret_destinations.add(edge.dest) # Store the destinations for later use by _wire_call_site function_ret_destinations[region.tag] = ret_destinations # Create synthetic pass nodes for each @ret destination for ret_dest in ret_destinations: synthetic_node_name = f"{region.tag}.{ret_dest}" if synthetic_node_name not in synthetic_ret_nodes: synthetic_node = IRNode( name=synthetic_node_name, opcode=RoutingOp.PASS, ) synthetic_ret_nodes[synthetic_node_name] = synthetic_node # Update edges in function body to point to synthetic @ret nodes new_body_edges = [] for edge in region.body.edges: new_dest = edge.dest # If destination starts with @ret, replace with synthetic node if isinstance(edge.dest, str) and edge.dest.startswith("@ret"): synthetic_node_name = f"{region.tag}.{edge.dest}" new_dest = synthetic_node_name new_body_edges.append(replace(edge, dest=new_dest)) new_body = replace(region.body, edges=new_body_edges) new_region = replace(region, body=new_body) new_regions.append(new_region) else: new_regions.append(region) expanded_graph = replace(expanded_graph, regions=new_regions) # Add synthetic nodes to the top-level graph wired_nodes = dict(expanded_graph.nodes) wired_nodes.update(synthetic_ret_nodes) # Process function call sites wired_call_sites = [] call_site_errors = [] wired_edges = list(expanded_graph.edges) processed_ret_nodes = set(synthetic_ret_nodes.keys()) # Track which synthetic nodes were created call_id_counter = 0 for call_site_result in expanded_graph.raw_call_sites: call_site_metadata, errors = _wire_call_site( call_site_result, expanded_graph, call_id_counter, wired_nodes, wired_edges, processed_ret_nodes, function_ret_destinations, ) wired_call_sites.append(call_site_metadata) call_site_errors.extend(errors) call_id_counter += 1 # Create final graph with wired call sites final_graph = replace( expanded_graph, nodes=wired_nodes, edges=wired_edges, call_sites=wired_call_sites, raw_call_sites=(), # Clear raw call sites after processing ) # Accumulate all errors all_errors = list(graph.errors) + expansion_errors + call_site_errors # Return with error list updated return replace(final_graph, errors=all_errors) __all__ = ["expand"]