OR-1 dataflow CPU sketch
at main 1313 lines 51 kB view raw
1"""Macro expansion pass for the OR1 assembler. 2 3This module implements macro invocation expansion (Phase 2). It processes 4IRMacroCall entries from the lowering pass, expands them by cloning and 5substituting macro bodies, and qualifies expanded names with scope prefixes. 6 7The expand() function receives an IRGraph from lower, processes all macro 8definitions and invocations, and returns a clean IRGraph with all macro 9artefacts removed. 10""" 11 12from __future__ import annotations 13 14import ast 15from dataclasses import replace 16from typing import Optional 17from collections.abc import Iterable 18 19from asm.errors import AssemblyError, ErrorCategory 20from asm.ir import ( 21 IRGraph, IRNode, IREdge, IRRegion, RegionKind, ParamRef, ConstExpr, 22 MacroDef, IRMacroCall, CallSiteResult, CallSite, IRRepetitionBlock, SourceLoc, 23 PlacementRef, PortRef, ActSlotRef, ActSlotRange, 24) 25from asm.opcodes import MNEMONIC_TO_OP 26from cm_inst import Port, RoutingOp 27 28MAX_EXPANSION_DEPTH = 32 29 30 31def _levenshtein(a: str, b: str) -> int: 32 """Compute Levenshtein (edit) distance between two strings. 33 34 Note: This is duplicated from asm/resolve.py. If a third copy appears, 35 extract to a shared utility module. 36 37 Args: 38 a: First string 39 b: Second string 40 41 Returns: 42 Minimum edit distance (number of single-character edits) 43 """ 44 if len(a) < len(b): 45 return _levenshtein(b, a) 46 if not b: 47 return len(a) 48 49 prev = list(range(len(b) + 1)) 50 for i, ca in enumerate(a): 51 curr = [i + 1] 52 for j, cb in enumerate(b): 53 curr.append(min( 54 prev[j + 1] + 1, # deletion 55 curr[j] + 1, # insertion 56 prev[j] + (ca != cb), # substitution 57 )) 58 prev = curr 59 return prev[-1] 60 61 62def _suggest_names(unresolved: str, available_names: Iterable[str]) -> list[str]: 63 """Generate "did you mean" suggestions via Levenshtein distance. 64 65 Compares unresolved name against all available names, returning suggestions 66 with distance <= 3, or the closest match if all distances are > 3. 67 68 Args: 69 unresolved: The unresolved name 70 available_names: Iterable of available macro names 71 72 Returns: 73 List of suggestion strings (may be empty) 74 """ 75 if not available_names: 76 return [] 77 78 # Compute distances 79 candidates = [] 80 for name in available_names: 81 dist = _levenshtein(unresolved, name) 82 candidates.append((dist, name)) 83 84 # Sort by distance 85 candidates.sort(key=lambda x: x[0]) 86 87 # Return suggestions with distance <= 3, or best if all > 3 88 suggestions = [] 89 best_distance = candidates[0][0] 90 91 for dist, name in candidates: 92 if dist <= 3 or dist == best_distance: 93 suggestions.append(f"Did you mean '#{name}'?") 94 else: 95 break 96 97 return suggestions 98 99 100def _substitute_param( 101 value: object, 102 subst_map: dict[str, object], 103) -> object: 104 """Resolve a ParamRef or name against the substitution map. 105 106 Supports token pasting: ParamRef with prefix/suffix concatenates the 107 parameter value with the prefix and suffix to form a new name. 108 109 For const fields, returns the actual int value. 110 For names, returns the ref name string (possibly qualified). 111 112 Args: 113 value: The value to substitute (could be ParamRef, int, str, etc.) 114 subst_map: Map of formal param names to actual argument values 115 116 Returns: 117 The substituted value, or unchanged if not a ParamRef/param name. 118 If ParamRef has prefix/suffix, returns concatenated string. 119 """ 120 if isinstance(value, ParamRef): 121 # Look up the parameter in the substitution map 122 actual = subst_map.get(value.param) 123 if actual is not None: 124 # Extract name from dict refs (e.g., {"name": "&x"} -> "&x") 125 if isinstance(actual, dict) and "name" in actual: 126 actual = actual["name"] 127 # Handle token pasting with prefix/suffix 128 if value.prefix or value.suffix: 129 # Convert actual value to string if it's an int 130 actual_str = str(actual) if isinstance(actual, int) else actual 131 # Concatenate: prefix + value + suffix 132 return value.prefix + actual_str + value.suffix 133 else: 134 # No prefix/suffix: return the actual value as-is 135 return actual 136 # Parameter not found - return unchanged (should not happen with proper validation) 137 return value 138 139 # For string names, check if they match a formal parameter 140 if isinstance(value, str): 141 # Don't substitute sigil-prefixed names (they may be qualified later) 142 if value and value[0] in "&@$#": 143 return value 144 # Check if this name is a formal parameter 145 if value in subst_map: 146 return subst_map[value] 147 148 return value 149 150 151def _eval_node(node, bindings: dict[str, int]) -> int: 152 """Evaluate a single AST node in a constant expression. 153 154 Args: 155 node: An ast node (Constant, Name, BinOp, or UnaryOp) 156 bindings: Map of parameter names to integer values 157 158 Returns: 159 The evaluated integer result 160 161 Raises: 162 ValueError: If node type is unsupported or value is non-numeric 163 """ 164 if isinstance(node, ast.Constant) and isinstance(node.value, int): 165 return node.value 166 elif isinstance(node, ast.Name): 167 if node.id not in bindings: 168 raise ValueError(f"Undefined parameter: {node.id}") 169 val = bindings[node.id] 170 if not isinstance(val, int): 171 raise ValueError(f"Non-numeric value in arithmetic context") 172 return val 173 elif isinstance(node, ast.BinOp): 174 left = _eval_node(node.left, bindings) 175 right = _eval_node(node.right, bindings) 176 if isinstance(node.op, ast.Add): 177 return left + right 178 elif isinstance(node.op, ast.Sub): 179 return left - right 180 elif isinstance(node.op, ast.Mult): 181 return left * right 182 elif isinstance(node.op, ast.FloorDiv): 183 if right == 0: 184 raise ValueError("division by zero") 185 return left // right 186 else: 187 raise ValueError(f"Unsupported operator: {type(node.op).__name__}") 188 elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub): 189 return -_eval_node(node.operand, bindings) 190 else: 191 raise ValueError(f"Unsupported expression node: {type(node).__name__}") 192 193 194def _eval_const_expr(expr: str, bindings: dict[str, int]) -> int: 195 """Evaluate a simple arithmetic expression with parameter bindings. 196 197 Supports: integer literals, +, -, *, // (integer division), parentheses. 198 No eval() call — safe AST walking only. 199 200 Args: 201 expr: Expression string, e.g. "base + 1" 202 bindings: Map of parameter names to integer values 203 204 Returns: 205 The evaluated integer result 206 207 Raises: 208 ValueError: If expression is invalid or contains non-numeric values 209 """ 210 tree = ast.parse(expr, mode='eval') 211 return _eval_node(tree.body, bindings) 212 213 214def _qualify_expanded_name( 215 name: str, 216 macro_scope: str, 217 parent_scope: str = "", 218 func_scope: Optional[str] = None, 219) -> str: 220 """Apply scope prefix to an expanded name. 221 222 Takes a name and applies macro and optional function scopes. 223 Names starting with & are qualified; other sigils pass through. 224 225 Args: 226 name: The original name from the macro body 227 macro_scope: The macro scope (e.g., "#loop_counted_0") 228 parent_scope: Optional parent macro scope (e.g., "#outer_0") 229 func_scope: Optional function scope (e.g., "$main") 230 231 Returns: 232 The qualified name 233 """ 234 if not name: 235 return name 236 237 # Check if it's a label (starts with &) 238 if name.startswith("&"): 239 # Build full scope: [func_scope.][parent_scope.]macro_scope.name 240 if func_scope and parent_scope: 241 # Triple-scoped: $func.#parent_N.#macro_M.&label 242 return f"{func_scope}.{parent_scope}.{macro_scope}.{name}" 243 elif parent_scope: 244 # Double-scoped: #parent_N.#macro_M.&label 245 return f"{parent_scope}.{macro_scope}.{name}" 246 elif func_scope: 247 # Double-scoped: $func.#macro_N.&label 248 return f"{func_scope}.{macro_scope}.{name}" 249 else: 250 # Single-scoped: #macro_N.&label 251 return f"{macro_scope}.{name}" 252 253 # Other sigils (@, $, #) pass through unqualified 254 return name 255 256 257def _clone_and_substitute_node( 258 node: IRNode, 259 macro_scope: str, 260 subst_map: dict[str, object], 261 func_scope: Optional[str] = None, 262 parent_scope: str = "", 263) -> tuple[IRNode, list[AssemblyError]]: 264 """Deep-clone a node and substitute parameters. 265 266 Args: 267 node: The template node from the macro body 268 macro_scope: The macro scope for qualification 269 subst_map: Map of formal params to actual arguments 270 func_scope: Optional function scope 271 parent_scope: Optional parent macro scope 272 273 Returns: 274 Tuple of (new IRNode with substitutions applied and name qualified, errors list) 275 """ 276 errors = [] 277 278 # Substitute the const field 279 new_const = _substitute_param(node.const, subst_map) 280 281 # If const is a ConstExpr, evaluate it 282 if isinstance(new_const, ConstExpr): 283 try: 284 # Build bindings dict from subst_map, converting all to int 285 bindings = {} 286 for param_name in new_const.params: 287 if param_name in subst_map: 288 val = subst_map[param_name] 289 if not isinstance(val, int): 290 errors.append(AssemblyError( 291 loc=new_const.loc, 292 category=ErrorCategory.VALUE, 293 message=f"Non-numeric value '{val}' in arithmetic context", 294 )) 295 # Return node with ConstExpr unchanged (will be caught later) 296 substituted_name = _substitute_param(node.name, subst_map) 297 if not isinstance(substituted_name, str): 298 substituted_name = str(substituted_name) 299 new_name = _qualify_expanded_name(substituted_name, macro_scope, parent_scope, func_scope) 300 return replace(node, name=new_name, const=new_const), errors 301 bindings[param_name] = val 302 # Evaluate the expression 303 evaluated = _eval_const_expr(new_const.expression, bindings) 304 new_const = evaluated 305 except ValueError as e: 306 errors.append(AssemblyError( 307 loc=new_const.loc, 308 category=ErrorCategory.VALUE, 309 message=str(e), 310 )) 311 312 # Resolve opcode if it's a ParamRef 313 new_opcode = node.opcode 314 if isinstance(new_opcode, ParamRef): 315 resolved = _substitute_param(new_opcode, subst_map) 316 if isinstance(resolved, str): 317 if resolved in MNEMONIC_TO_OP: 318 new_opcode = MNEMONIC_TO_OP[resolved] 319 else: 320 errors.append(AssemblyError( 321 loc=node.loc, 322 category=ErrorCategory.MACRO, 323 message=f"'{resolved}' is not a valid opcode mnemonic", 324 )) 325 new_opcode = node.opcode 326 else: 327 errors.append(AssemblyError( 328 loc=node.loc, 329 category=ErrorCategory.MACRO, 330 message=f"opcode parameter must resolve to an opcode mnemonic, got {type(resolved).__name__}", 331 )) 332 new_opcode = node.opcode 333 334 # Resolve placement if it's a PlacementRef 335 new_pe = node.pe 336 if isinstance(new_pe, PlacementRef): 337 resolved = _substitute_param(new_pe.param, subst_map) 338 if isinstance(resolved, str) and resolved.startswith("pe"): 339 try: 340 new_pe = int(resolved[2:]) 341 except ValueError: 342 errors.append(AssemblyError( 343 loc=node.loc, 344 category=ErrorCategory.MACRO, 345 message=f"placement parameter must resolve to 'peN', got '{resolved}'", 346 )) 347 new_pe = None 348 elif isinstance(resolved, int): 349 new_pe = resolved 350 else: 351 errors.append(AssemblyError( 352 loc=node.loc, 353 category=ErrorCategory.MACRO, 354 message=f"placement parameter must resolve to 'peN', got {type(resolved).__name__}", 355 )) 356 new_pe = None 357 358 # Resolve act_slot if it's a ActSlotRef 359 new_act_slot = node.act_slot 360 if isinstance(new_act_slot, ActSlotRef): 361 resolved = _substitute_param(new_act_slot.param, subst_map) 362 if isinstance(resolved, int): 363 new_act_slot = ActSlotRange(start=resolved, end=resolved) 364 else: 365 errors.append(AssemblyError( 366 loc=node.loc, 367 category=ErrorCategory.MACRO, 368 message=f"act_slot parameter must resolve to an integer, got {type(resolved).__name__}", 369 )) 370 new_act_slot = None 371 372 # Substitute the node name (may be a ParamRef with token pasting) 373 substituted_name = _substitute_param(node.name, subst_map) 374 375 # Ensure name is a string before qualification 376 if not isinstance(substituted_name, str): 377 substituted_name = str(substituted_name) 378 379 # Qualify the node name 380 new_name = _qualify_expanded_name(substituted_name, macro_scope, parent_scope, func_scope) 381 382 return replace(node, name=new_name, const=new_const, opcode=new_opcode, 383 pe=new_pe, act_slot=new_act_slot), errors 384 385 386def _clone_and_substitute_edge( 387 edge: IREdge, 388 macro_scope: str, 389 subst_map: dict[str, object], 390 func_scope: Optional[str] = None, 391 parent_scope: str = "", 392) -> tuple[IREdge, list[AssemblyError]]: 393 """Deep-clone an edge and substitute/qualify names. 394 395 Args: 396 edge: The template edge from the macro body 397 macro_scope: The macro scope for qualification 398 subst_map: Map of formal params to actual arguments 399 func_scope: Optional function scope 400 parent_scope: Optional parent macro scope 401 402 Returns: 403 Tuple of (new IREdge with names qualified, list of errors) 404 """ 405 errors: list[AssemblyError] = [] 406 # Substitute source and dest names. 407 # Track whether each was a ParamRef — substituted refs are external 408 # and must NOT be qualified with the macro scope. 409 source_was_param = isinstance(edge.source, ParamRef) 410 source = _substitute_param(edge.source, subst_map) 411 if not isinstance(source, str): 412 source = str(source) 413 414 dest_was_param = isinstance(edge.dest, ParamRef) 415 dest = _substitute_param(edge.dest, subst_map) 416 if not isinstance(dest, str): 417 dest = str(dest) 418 419 # Only qualify names that came from the macro body template directly. 420 # Substituted parameter refs point to external names and stay unqualified. 421 if not source_was_param: 422 source = _qualify_expanded_name(source, macro_scope, parent_scope, func_scope) 423 if not dest_was_param: 424 dest = _qualify_expanded_name(dest, macro_scope, parent_scope, func_scope) 425 426 # Resolve PortRef on dest port 427 new_port = edge.port 428 if isinstance(new_port, PortRef): 429 resolved = _substitute_param(new_port.param, subst_map) 430 if isinstance(resolved, str): 431 if resolved == "L": 432 new_port = Port.L 433 elif resolved == "R": 434 new_port = Port.R 435 else: 436 errors.append(AssemblyError( 437 loc=edge.loc, 438 category=ErrorCategory.MACRO, 439 message=f"port parameter must resolve to 'L' or 'R', got '{resolved}'", 440 )) 441 new_port = Port.L 442 elif isinstance(resolved, Port): 443 new_port = resolved 444 else: 445 errors.append(AssemblyError( 446 loc=edge.loc, 447 category=ErrorCategory.MACRO, 448 message=f"port parameter must resolve to 'L' or 'R', got '{resolved}'", 449 )) 450 new_port = Port.L 451 452 # Resolve PortRef on source port 453 new_source_port = edge.source_port 454 if isinstance(new_source_port, PortRef): 455 resolved = _substitute_param(new_source_port.param, subst_map) 456 if isinstance(resolved, str): 457 if resolved == "L": 458 new_source_port = Port.L 459 elif resolved == "R": 460 new_source_port = Port.R 461 else: 462 errors.append(AssemblyError( 463 loc=edge.loc, 464 category=ErrorCategory.MACRO, 465 message=f"source port parameter must resolve to 'L' or 'R', got '{resolved}'", 466 )) 467 new_source_port = None 468 elif isinstance(resolved, Port): 469 new_source_port = resolved 470 else: 471 errors.append(AssemblyError( 472 loc=edge.loc, 473 category=ErrorCategory.MACRO, 474 message=f"source port parameter must resolve to 'L' or 'R', got '{resolved}'", 475 )) 476 new_source_port = None 477 478 return replace(edge, source=source, dest=dest, port=new_port, source_port=new_source_port), errors 479 480 481def _add_expansion_context( 482 error: AssemblyError, 483 call: IRMacroCall, 484 builtin_line_offset: int = 0, 485) -> AssemblyError: 486 """Add expansion context to an error. 487 488 Appends "expanded from #macro_name at line N, column C" to the 489 context_lines to trace the error back to the macro invocation site. 490 491 Args: 492 error: The error to enhance 493 call: The IRMacroCall being expanded 494 builtin_line_offset: Lines to subtract for display (built-in macro prefix) 495 496 Returns: 497 New AssemblyError with expansion context added to context_lines 498 """ 499 display_line = call.loc.line 500 if builtin_line_offset > 0 and display_line > builtin_line_offset: 501 display_line -= builtin_line_offset 502 expansion_context = ( 503 f"expanded from #{call.name} at line {display_line}, " 504 f"column {call.loc.column}" 505 ) 506 return replace( 507 error, 508 context_lines=list(error.context_lines) + [expansion_context], 509 ) 510 511 512def _expand_repetition_block( 513 rep_block: IRRepetitionBlock, 514 variadic_args: list[object], 515 macro_scope: str, 516 subst_map: dict[str, object], 517 func_scope: Optional[str] = None, 518 parent_scope: str = "", 519) -> tuple[dict[str, IRNode], list[IREdge], list[AssemblyError]]: 520 """Expand a repetition block once per variadic argument. 521 522 For each iteration, clones the body, substitutes the variadic param 523 and ${_idx} (iteration index), and qualifies names. 524 525 Args: 526 rep_block: The IRRepetitionBlock to expand 527 variadic_args: List of actual arguments for the variadic parameter 528 macro_scope: The macro scope for qualification 529 subst_map: Base substitution map (will be extended with variadic param and _idx) 530 func_scope: Optional function scope 531 parent_scope: Optional parent macro scope 532 533 Returns: 534 Tuple of (expanded_nodes dict, expanded_edges list, errors list) 535 """ 536 errors = [] 537 expanded_nodes: dict[str, IRNode] = {} 538 expanded_edges: list[IREdge] = [] 539 540 # Iterate over variadic arguments 541 for idx, arg_value in enumerate(variadic_args): 542 # Create iteration-specific substitution map 543 iter_subst_map = dict(subst_map) 544 iter_subst_map[rep_block.variadic_param] = arg_value 545 iter_subst_map["_idx"] = idx # Make iteration index available as parameter 546 547 # Clone and substitute nodes from the repetition body 548 for node_name, node in rep_block.body.nodes.items(): 549 # Create a unique name for this iteration 550 # Qualify the node name with macro scope and iteration suffix 551 qualified_node, node_errors = _clone_and_substitute_node( 552 node, 553 f"{macro_scope}_rep{idx}", 554 iter_subst_map, 555 func_scope, 556 parent_scope, 557 ) 558 errors.extend(node_errors) 559 expanded_nodes[qualified_node.name] = qualified_node 560 561 # Clone and substitute edges from the repetition body 562 for edge in rep_block.body.edges: 563 qualified_edge, edge_errors = _clone_and_substitute_edge( 564 edge, 565 f"{macro_scope}_rep{idx}", 566 iter_subst_map, 567 func_scope, 568 parent_scope, 569 ) 570 errors.extend(edge_errors) 571 expanded_edges.append(qualified_edge) 572 573 return expanded_nodes, expanded_edges, errors 574 575 576def _expand_call( 577 call: IRMacroCall, 578 macro_table: dict[str, MacroDef], 579 expansion_counter: list[int], 580 func_scope: Optional[str] = None, 581 parent_scope: str = "", 582 depth: int = 0, 583 builtin_line_offset: int = 0, 584) -> tuple[dict[str, IRNode], list[IREdge], list[AssemblyError]]: 585 """Process a single macro call. 586 587 Looks up the macro, validates arity, builds substitution map, clones 588 the body, and performs parameter substitution and name qualification. 589 590 Args: 591 call: The IRMacroCall to expand 592 macro_table: Map of macro names to MacroDef objects 593 expansion_counter: [int] list for mutable counter (incremented per expansion) 594 func_scope: Optional function scope the call is in 595 parent_scope: Optional parent macro scope (for nested macros) 596 depth: Recursion depth (error if exceeds 32) 597 builtin_line_offset: Lines to subtract for display in error context 598 599 Returns: 600 Tuple of (expanded_nodes dict, expanded_edges list, errors list) 601 """ 602 errors = [] 603 604 # Check depth limit 605 if depth > MAX_EXPANSION_DEPTH: 606 error = AssemblyError( 607 loc=call.loc, 608 category=ErrorCategory.MACRO, 609 message=f"macro expansion depth exceeds {MAX_EXPANSION_DEPTH} (likely infinite recursion in macro '{call.name}')", 610 ) 611 return {}, [], [error] 612 613 # Look up macro definition 614 if call.name not in macro_table: 615 suggestions = _suggest_names(call.name, macro_table.keys()) 616 error = AssemblyError( 617 loc=call.loc, 618 category=ErrorCategory.MACRO, 619 message=f"undefined macro '#{call.name}'", 620 suggestions=suggestions, 621 ) 622 return {}, [], [error] 623 624 macro_def = macro_table[call.name] 625 626 # Validate arity and separate variadic arguments 627 total_args = len(call.positional_args) + len(call.named_args) 628 629 # Count required parameters (non-variadic) 630 required_params = [p for p in macro_def.params if not p.variadic] 631 variadic_param = next((p for p in macro_def.params if p.variadic), None) 632 633 if variadic_param: 634 # With variadic: need at least as many args as required params 635 if total_args < len(required_params): 636 error = AssemblyError( 637 loc=call.loc, 638 category=ErrorCategory.MACRO, 639 message=f"macro '#{call.name}' expects at least {len(required_params)} argument(s), got {total_args}", 640 ) 641 return {}, [], [error] 642 else: 643 # Without variadic: exact match required 644 expected_count = len(macro_def.params) 645 if total_args != expected_count: 646 error = AssemblyError( 647 loc=call.loc, 648 category=ErrorCategory.MACRO, 649 message=f"macro '#{call.name}' expects {expected_count} argument(s), got {total_args}", 650 ) 651 return {}, [], [error] 652 653 # Build substitution map 654 subst_map: dict[str, object] = {} 655 variadic_args: list[object] = [] 656 657 # Add positional arguments 658 for i, actual_value in enumerate(call.positional_args): 659 if i < len(required_params): 660 # Regular parameter 661 param_name = required_params[i].name 662 subst_map[param_name] = actual_value 663 elif variadic_param: 664 # Extra arguments go to variadic parameter 665 variadic_args.append(actual_value) 666 667 # Add named arguments (to required params only; named variadic args not supported) 668 for param_name, actual_value in call.named_args: 669 subst_map[param_name] = actual_value 670 671 # Generate unique macro scope 672 expansion_id = expansion_counter[0] 673 expansion_counter[0] += 1 674 macro_scope = f"#{call.name}_{expansion_id}" 675 676 # Recursively expand and qualify the macro body, including nested calls 677 def _expand_body_recursive( 678 body: IRGraph, 679 depth: int, 680 ) -> tuple[dict[str, IRNode], list[IREdge], list[AssemblyError]]: 681 """Recursively expand all macro calls in a body graph and its regions.""" 682 body_errors: list[AssemblyError] = [] 683 body_nodes: dict[str, IRNode] = {} 684 body_edges: list[IREdge] = [] 685 686 # Qualify and add the body's own nodes 687 for node_name, node in body.nodes.items(): 688 qualified_node, node_errors = _clone_and_substitute_node(node, macro_scope, subst_map, func_scope, parent_scope) 689 # Add expansion context to node-level errors (const expression evaluation, etc.) 690 for error in node_errors: 691 body_errors.append(_add_expansion_context(error, call, builtin_line_offset)) 692 body_nodes[qualified_node.name] = qualified_node 693 694 # Qualify the body's own edges 695 for edge in body.edges: 696 qualified_edge, edge_errors = _clone_and_substitute_edge(edge, macro_scope, subst_map, func_scope, parent_scope) 697 body_errors.extend(edge_errors) 698 body_edges.append(qualified_edge) 699 700 # Expand macro calls at this body level 701 # Nested calls have current macro_scope as their parent_scope 702 for nested_call in body.macro_calls: 703 nested_expanded_nodes, nested_expanded_edges, nested_errors = _expand_call( 704 nested_call, 705 macro_table, 706 expansion_counter, 707 func_scope, 708 macro_scope, # Current macro scope becomes parent for nested 709 depth + 1, 710 builtin_line_offset, 711 ) 712 # Add expansion context to nested errors (trace them back to the nested call) 713 for error in nested_errors: 714 body_errors.append(_add_expansion_context(error, nested_call, builtin_line_offset)) 715 body_nodes.update(nested_expanded_nodes) 716 # Filter out leaked @ret edges from failed inner expansions to prevent 717 # spurious "defines output(s) @ret" errors at the outer macro level 718 for nested_edge in nested_expanded_edges: 719 if isinstance(nested_edge.dest, str) and nested_edge.dest.startswith("@ret"): 720 continue 721 body_edges.append(nested_edge) 722 723 # Expand repetition blocks (Phase 6 variadic macros) 724 if variadic_param: 725 for rep_block in macro_def.repetition_blocks: 726 # Only expand blocks for the current variadic parameter 727 if rep_block.variadic_param == variadic_param.name: 728 rep_nodes, rep_edges, rep_errors = _expand_repetition_block( 729 rep_block, 730 variadic_args, 731 macro_scope, 732 subst_map, 733 func_scope, 734 parent_scope, 735 ) 736 body_errors.extend(rep_errors) 737 body_nodes.update(rep_nodes) 738 body_edges.extend(rep_edges) 739 740 # Recursively expand regions in the body 741 for region in body.regions: 742 region_func_scope = region.tag if region.kind == RegionKind.FUNCTION else func_scope 743 region_nodes, region_edges, region_errors = _expand_body_recursive( 744 region.body, 745 depth + 1, 746 ) 747 body_errors.extend(region_errors) 748 body_nodes.update(region_nodes) 749 body_edges.extend(region_edges) 750 751 return body_nodes, body_edges, body_errors 752 753 expanded_nodes, expanded_edges, nested_errors = _expand_body_recursive( 754 macro_def.body, 755 depth, 756 ) 757 errors.extend(nested_errors) 758 759 # Rewrite @ret edges: either substitute with output destinations, or 760 # report an error if the macro body uses @ret but the call site doesn't 761 # provide output wiring. 762 has_ret_edges = any( 763 isinstance(e.dest, str) and e.dest.startswith("@ret") 764 for e in expanded_edges 765 ) 766 767 if has_ret_edges and not call.output_dests: 768 # Collect the @ret markers for the error message 769 ret_markers = sorted({ 770 e.dest for e in expanded_edges 771 if isinstance(e.dest, str) and e.dest.startswith("@ret") 772 }) 773 errors.append(AssemblyError( 774 loc=call.loc, 775 category=ErrorCategory.MACRO, 776 message=f"macro '#{call.name}' defines output(s) {', '.join(ret_markers)} but call site has no '|>' output wiring", 777 )) 778 779 if call.output_dests: 780 rewritten_edges = [] 781 all_outputs = list(call.output_dests) 782 783 # Build ordered list of positional outputs for bare @ret resolution. 784 # Each bare @ret consumes the next positional output in order, 785 # enabling variadic macros to wire each iteration to a separate dest: 786 # #macro *vals |> { $( &c <| const, ${vals}; &c |> @ret ),* } 787 # #macro 3, 4 |> &x, &y ← iteration 0 → &x, iteration 1 → &y 788 positional_outputs: list[str] = [] 789 for output in all_outputs: 790 if isinstance(output, dict): 791 if "name" in output and "ref" in output: 792 continue # Named output — not positional 793 name = output.get("name", None) 794 if name is not None: 795 positional_outputs.append(name) 796 else: 797 positional_outputs.append(str(output)) 798 positional_idx = 0 799 800 for edge in expanded_edges: 801 if not (isinstance(edge.dest, str) and edge.dest.startswith("@ret")): 802 rewritten_edges.append(edge) 803 continue 804 805 # This edge targets an @ret marker — resolve it 806 ret_dest = edge.dest # e.g. "@ret" or "@ret_body" 807 dest_name = None 808 809 # Try named match: @ret_body -> output with name="body" 810 if ret_dest.startswith("@ret_"): 811 expected_suffix = ret_dest[5:] # "body" from "@ret_body" 812 for output in all_outputs: 813 if isinstance(output, dict) and "name" in output and "ref" in output: 814 if output["name"] == expected_suffix: 815 ref = output["ref"] 816 dest_name = ref.get("name", ref) if isinstance(ref, dict) else str(ref) 817 break 818 819 # Bare @ret -> next positional output (advances counter) 820 if dest_name is None and ret_dest == "@ret": 821 if positional_idx < len(positional_outputs): 822 dest_name = positional_outputs[positional_idx] 823 positional_idx += 1 824 825 if dest_name is None: 826 errors.append(AssemblyError( 827 loc=call.loc, 828 category=ErrorCategory.MACRO, 829 message=f"macro '#{call.name}' has output marker '{ret_dest}' but no matching output destination in call site", 830 )) 831 rewritten_edges.append(edge) 832 continue 833 834 # Replace the @ret destination with the concrete node reference 835 rewritten_edges.append(replace(edge, dest=dest_name)) 836 837 expanded_edges = rewritten_edges 838 839 # Propagate errors from macro body template 840 for body_error in macro_def.body.errors: 841 # Adjust source location to point to the call site 842 adjusted_error = replace( 843 body_error, 844 loc=call.loc, 845 suggestions=list(body_error.suggestions) + [ 846 f"defined in macro #{macro_def.name} at line {macro_def.loc.line}" 847 ], 848 ) 849 errors.append(adjusted_error) 850 851 return expanded_nodes, expanded_edges, errors 852 853 854def _expand_graph_recursive( 855 graph: IRGraph, 856 macro_table: dict[str, MacroDef], 857 expansion_counter: list[int], 858 func_scope: Optional[str] = None, 859 builtin_line_offset: int = 0, 860) -> tuple[IRGraph, list[AssemblyError]]: 861 """Recursively expand macros in a graph and its regions. 862 863 Args: 864 graph: The IRGraph to expand 865 macro_table: Map of macro names to MacroDef objects 866 expansion_counter: [int] list for mutable counter 867 func_scope: Optional function scope for name qualification 868 builtin_line_offset: Lines to subtract for display in error context 869 870 Returns: 871 Tuple of (new_graph, all_errors) 872 """ 873 new_errors: list[AssemblyError] = [] 874 expanded_nodes: dict[str, IRNode] = dict(graph.nodes) 875 expanded_edges: list[IREdge] = list(graph.edges) 876 877 # Collect all macro calls from this graph level 878 # Note: The lower pass doesn't populate macro_calls in regions, 879 # so we also need to collect from macro_calls in the graph 880 all_calls_at_level = list(graph.macro_calls) 881 882 # Expand all macro calls at this level 883 for call in all_calls_at_level: 884 # Determine the function scope for nested calls 885 call_func_scope = func_scope 886 call_expanded_nodes, call_expanded_edges, call_errors = _expand_call( 887 call, 888 macro_table, 889 expansion_counter, 890 call_func_scope, 891 "", # No parent scope at top level 892 builtin_line_offset=builtin_line_offset, 893 ) 894 for error in call_errors: 895 new_errors.append(_add_expansion_context(error, call, builtin_line_offset)) 896 expanded_nodes.update(call_expanded_nodes) 897 expanded_edges.extend(call_expanded_edges) 898 899 # Recursively expand regions (function bodies, etc.) 900 new_regions: list[IRRegion] = [] 901 for region in graph.regions: 902 # For function regions, pass the region tag as the func_scope for name qualification 903 region_func_scope = region.tag if region.kind == RegionKind.FUNCTION else func_scope 904 new_body, region_errors = _expand_graph_recursive( 905 region.body, 906 macro_table, 907 expansion_counter, 908 region_func_scope, 909 builtin_line_offset, 910 ) 911 new_errors.extend(region_errors) 912 new_region = replace(region, body=new_body) 913 new_regions.append(new_region) 914 915 # Create new graph with expanded content and no macro artefacts 916 new_graph = replace( 917 graph, 918 nodes=expanded_nodes, 919 edges=expanded_edges, 920 regions=new_regions, 921 macro_defs=[], # Remove all macro defs 922 macro_calls=[], # Remove all macro calls 923 ) 924 925 return new_graph, new_errors 926 927 928def _wire_call_site( 929 call_site: CallSiteResult, 930 graph: IRGraph, 931 call_id: int, 932 wired_nodes: dict[str, IRNode], 933 wired_edges: list[IREdge], 934 processed_ret_nodes: set, 935 function_ret_destinations: dict[str, set], 936) -> tuple[CallSite, list[AssemblyError]]: 937 """Process a single function call site and wire it into the graph. 938 939 This function: 940 1. Finds the function definition in the graph's regions 941 2. Matches input arguments to function labels 942 3. Synthesises @ret rendezvous nodes (shared across call sites) 943 4. Creates per-call-site trampolines and free_frame nodes 944 5. Wires everything together with ctx_override edges 945 946 Args: 947 call_site: The CallSiteResult from the lower pass 948 graph: The IRGraph containing regions (functions) 949 call_id: Unique ID for this call site 950 wired_nodes: Dictionary to accumulate generated nodes 951 wired_edges: List to accumulate generated edges 952 processed_ret_nodes: Cache of already-synthesised @ret nodes (func_name.@ret -> node_name) 953 954 Returns: 955 Tuple of (CallSite metadata, errors list) 956 """ 957 errors = [] 958 959 # Find the function definition in the graph's regions 960 func_region = None 961 for region in graph.regions: 962 if region.kind == RegionKind.FUNCTION and region.tag == call_site.func_name: 963 func_region = region 964 break 965 966 if func_region is None: 967 error = AssemblyError( 968 loc=call_site.loc, 969 category=ErrorCategory.CALL, 970 message=f"undefined function '{call_site.func_name}'", 971 ) 972 return CallSite( 973 func_name=call_site.func_name, 974 call_id=call_id, 975 ), [error] 976 977 # Collect all nodes in the function body (including nested regions) 978 func_all_nodes = {} 979 func_all_edges = [] 980 981 def _collect_from_region(r: IRGraph): 982 func_all_nodes.update(r.nodes) 983 func_all_edges.extend(r.edges) 984 for sub_region in r.regions: 985 _collect_from_region(sub_region.body) 986 987 _collect_from_region(func_region.body) 988 989 input_edge_names = [] 990 trampoline_nodes = [] 991 free_frame_nodes = [] 992 993 # Process input arguments: match each to a label in the function 994 for param_name, source_ref in call_site.input_args: 995 # source_ref may be a dict with {"name": "..."} or a simple string 996 if isinstance(source_ref, dict): 997 source_name = source_ref.get("name", str(source_ref)) 998 else: 999 source_name = str(source_ref) 1000 1001 # Look for a label &param_name in the function 1002 target_label = f"{call_site.func_name}.&{param_name}" 1003 1004 if target_label not in func_all_nodes: 1005 error = AssemblyError( 1006 loc=call_site.loc, 1007 category=ErrorCategory.CALL, 1008 message=f"argument '{param_name}' does not match any label in '{call_site.func_name}'", 1009 ) 1010 errors.append(error) 1011 continue 1012 1013 # Check if source node has a const (AC5.3: const+CTX_OVRD conflict) 1014 # If so, insert a pass trampoline between source and target 1015 source_node = graph.nodes.get(source_name) 1016 if source_node is not None and source_node.const is not None: 1017 # Insert a pass trampoline to separate const from ctx_override 1018 tramp_name = f"{call_site.func_name}.__input_tramp_{call_id}_{param_name}" 1019 tramp_node = IRNode( 1020 name=tramp_name, 1021 opcode=RoutingOp.PASS, 1022 loc=call_site.loc, 1023 ) 1024 wired_nodes[tramp_name] = tramp_node 1025 1026 # Wire: source -> trampoline (no ctx_override, inherits ctx) 1027 source_to_tramp = IREdge( 1028 source=source_name, 1029 dest=tramp_name, 1030 port=Port.L, 1031 loc=call_site.loc, 1032 ) 1033 wired_edges.append(source_to_tramp) 1034 1035 # Wire: trampoline -> target (no ctx_override — INHERIT mode reads 1036 # the destination FrameDest from the frame, which already encodes 1037 # the function's act_id. CHANGE_TAG is wrong here because the left 1038 # operand is raw data, not a packed FrameDest.) 1039 tramp_to_target = IREdge( 1040 source=tramp_name, 1041 dest=target_label, 1042 port=Port.L, 1043 loc=call_site.loc, 1044 ) 1045 wired_edges.append(tramp_to_target) 1046 else: 1047 # No conflict — direct edge (no ctx_override — the destination 1048 # node's act_id in the FrameDest handles cross-context routing) 1049 input_edge = IREdge( 1050 source=source_name, 1051 dest=target_label, 1052 port=Port.L, 1053 loc=call_site.loc, 1054 ) 1055 wired_edges.append(input_edge) 1056 1057 edge_name = f"{call_site.func_name}.__input_{call_id}_{param_name}" 1058 input_edge_names.append(edge_name) 1059 1060 # Get @ret destinations for this function (pre-computed during expand setup) 1061 ret_destinations = set() 1062 if function_ret_destinations and call_site.func_name in function_ret_destinations: 1063 ret_destinations = function_ret_destinations[call_site.func_name] 1064 1065 # For each @ret variant, create a per-call-site trampoline 1066 # (synthetic nodes are already created during expand pass setup) 1067 for ret_dest in ret_destinations: 1068 # Determine the synthetic node name: $func.@ret or $func.@ret_name 1069 synthetic_node_name = f"{call_site.func_name}.{ret_dest}" 1070 1071 # Synthetic node should already exist from expand setup 1072 if synthetic_node_name not in processed_ret_nodes: 1073 # This shouldn't happen, but create it just in case 1074 synthetic_pass_node = IRNode( 1075 name=synthetic_node_name, 1076 opcode=RoutingOp.PASS, 1077 loc=call_site.loc, 1078 ) 1079 wired_nodes[synthetic_node_name] = synthetic_pass_node 1080 processed_ret_nodes.add(synthetic_node_name) 1081 1082 # Create a per-call-site trampoline pass node 1083 trampoline_name = f"{call_site.func_name}.__ret_trampoline_{call_id}_{ret_dest[1:]}" 1084 trampoline_node = IRNode( 1085 name=trampoline_name, 1086 opcode=RoutingOp.PASS, 1087 dest_l=None, # Will be wired below 1088 dest_r=None, # Will be wired below 1089 loc=call_site.loc, 1090 ) 1091 wired_nodes[trampoline_name] = trampoline_node 1092 trampoline_nodes.append(trampoline_name) 1093 1094 # Create edge from synthetic @ret node to trampoline 1095 ret_to_tramp_edge = IREdge( 1096 source=synthetic_node_name, 1097 dest=trampoline_name, 1098 port=Port.L, 1099 loc=call_site.loc, 1100 ) 1101 wired_edges.append(ret_to_tramp_edge) 1102 1103 # Find the corresponding output destination from call_site.output_dests 1104 # output_dests is a flat tuple of dicts: each dict is either a named_output 1105 # {"name": "...", "ref": {...}} or positional_output {...} 1106 output_dest = None 1107 dest_name = f"@__unmatched_{ret_dest}" 1108 1109 # Iterate directly over flattened output_dests 1110 all_outputs = list(call_site.output_dests) if call_site.output_dests else [] 1111 1112 # Try to find named output matching ret_dest 1113 for output in all_outputs: 1114 if isinstance(output, dict): 1115 output_name = output.get("name") 1116 # ret_dest is "@ret_name", so we need to match "name" part (without @ prefix and "ret_" prefix) 1117 # Possible forms: @ret (bare), @ret_sum, @ret_carry, etc. 1118 expected_suffix = ret_dest[5:] if ret_dest.startswith("@ret_") else "" # "sum" from "@ret_sum" 1119 if output_name and output_name == expected_suffix: 1120 # Found named output 1121 output_ref = output.get("ref") 1122 if isinstance(output_ref, dict): 1123 dest_name = output_ref.get("name", "@__unmatched") 1124 else: 1125 dest_name = str(output_ref) 1126 break 1127 1128 # If not found by name and this is @ret (bare), try positional mapping 1129 if dest_name.startswith("@__unmatched") and ret_dest == "@ret": 1130 for output in all_outputs: 1131 # Named outputs have "name" key that matches a @ret_name label 1132 # Positional outputs don't have this structure 1133 has_label_name = isinstance(output, dict) and "name" in output and "ref" in output 1134 if not has_label_name: 1135 # This is a positional output 1136 if isinstance(output, dict): 1137 # Positional output is stored as a ref dict with just "name" key 1138 dest_name = output.get("name", "@__unmatched") 1139 else: 1140 # Non-dict positional output 1141 dest_name = str(output) 1142 break 1143 1144 # Wire trampoline dest_l to the caller's output destination with ctx_override=True 1145 tramp_to_output_edge = IREdge( 1146 source=trampoline_name, 1147 dest=dest_name, 1148 port=Port.L, 1149 source_port=Port.L, # Output from trampoline's L port 1150 ctx_override=True, 1151 loc=call_site.loc, 1152 ) 1153 wired_edges.append(tramp_to_output_edge) 1154 1155 # Create a free_frame node (one per call site, not per @ret variant) 1156 # Wire it to trampoline's dest_r 1157 free_frame_name = f"{call_site.func_name}.__free_frame_{call_id}" 1158 if free_frame_name not in wired_nodes: 1159 # Only create once per call site 1160 free_frame_node = IRNode( 1161 name=free_frame_name, 1162 opcode=RoutingOp.FREE_FRAME, 1163 loc=call_site.loc, 1164 ) 1165 wired_nodes[free_frame_name] = free_frame_node 1166 free_frame_nodes.append(free_frame_name) 1167 1168 # Wire trampoline dest_r to free_frame 1169 tramp_to_free_edge = IREdge( 1170 source=trampoline_name, 1171 dest=free_frame_name, 1172 port=Port.L, 1173 source_port=Port.R, # Output from trampoline's R port 1174 loc=call_site.loc, 1175 ) 1176 wired_edges.append(tramp_to_free_edge) 1177 1178 # Create CallSite metadata 1179 call_site_metadata = CallSite( 1180 func_name=call_site.func_name, 1181 call_id=call_id, 1182 input_edges=tuple(input_edge_names), 1183 trampoline_nodes=tuple(trampoline_nodes), 1184 free_frame_nodes=tuple(free_frame_nodes), 1185 loc=call_site.loc, 1186 ) 1187 1188 return call_site_metadata, errors 1189 1190 1191def expand(graph: IRGraph) -> IRGraph: 1192 """Expand all macro calls in an IRGraph. 1193 1194 The expand pass processes all MacroDef and IRMacroCall entries from 1195 lowering, substitutes parameters, qualifies names, and recursively 1196 expands nested macros. The output graph contains no macro definitions 1197 or invocation artefacts. 1198 1199 Steps: 1200 1. Collect all MacroDef entries into a macro_table 1201 2. Recursively expand all IRMacroCall entries (depth limit 32) 1202 3. For each call: validate arity, build substitution map, clone body, 1203 substitute params, qualify names, splice into output 1204 4. Strip all MacroDef and IRMacroCall entries from output 1205 5. Return new IRGraph with only concrete nodes/edges 1206 1207 Args: 1208 graph: The IRGraph from the lower pass 1209 1210 Returns: 1211 New IRGraph with all macros expanded and no macro artefacts 1212 """ 1213 # Collect all macro definitions into a table 1214 macro_table: dict[str, MacroDef] = {} 1215 for macro_def in graph.macro_defs: 1216 macro_table[macro_def.name] = macro_def 1217 1218 # Initialize expansion counter 1219 expansion_counter: list[int] = [0] 1220 1221 # Recursively expand the graph starting at top level 1222 expanded_graph, expansion_errors = _expand_graph_recursive( 1223 graph, 1224 macro_table, 1225 expansion_counter, 1226 builtin_line_offset=graph.builtin_line_offset, 1227 ) 1228 1229 # Scan function regions to find all @ret destinations, create synthetic nodes, and track them 1230 synthetic_ret_nodes = {} # Map of synthetic_node_name -> IRNode 1231 function_ret_destinations = {} # Map of func_name -> set of @ret destinations 1232 new_regions = [] 1233 for region in expanded_graph.regions: 1234 if region.kind == RegionKind.FUNCTION: 1235 # Find all @ret destinations in function body edges 1236 ret_destinations = set() 1237 for edge in region.body.edges: 1238 if isinstance(edge.dest, str) and edge.dest.startswith("@ret"): 1239 ret_destinations.add(edge.dest) 1240 1241 # Store the destinations for later use by _wire_call_site 1242 function_ret_destinations[region.tag] = ret_destinations 1243 1244 # Create synthetic pass nodes for each @ret destination 1245 for ret_dest in ret_destinations: 1246 synthetic_node_name = f"{region.tag}.{ret_dest}" 1247 if synthetic_node_name not in synthetic_ret_nodes: 1248 synthetic_node = IRNode( 1249 name=synthetic_node_name, 1250 opcode=RoutingOp.PASS, 1251 ) 1252 synthetic_ret_nodes[synthetic_node_name] = synthetic_node 1253 1254 # Update edges in function body to point to synthetic @ret nodes 1255 new_body_edges = [] 1256 for edge in region.body.edges: 1257 new_dest = edge.dest 1258 # If destination starts with @ret, replace with synthetic node 1259 if isinstance(edge.dest, str) and edge.dest.startswith("@ret"): 1260 synthetic_node_name = f"{region.tag}.{edge.dest}" 1261 new_dest = synthetic_node_name 1262 new_body_edges.append(replace(edge, dest=new_dest)) 1263 1264 new_body = replace(region.body, edges=new_body_edges) 1265 new_region = replace(region, body=new_body) 1266 new_regions.append(new_region) 1267 else: 1268 new_regions.append(region) 1269 1270 expanded_graph = replace(expanded_graph, regions=new_regions) 1271 1272 # Add synthetic nodes to the top-level graph 1273 wired_nodes = dict(expanded_graph.nodes) 1274 wired_nodes.update(synthetic_ret_nodes) 1275 1276 # Process function call sites 1277 wired_call_sites = [] 1278 call_site_errors = [] 1279 wired_edges = list(expanded_graph.edges) 1280 processed_ret_nodes = set(synthetic_ret_nodes.keys()) # Track which synthetic nodes were created 1281 1282 call_id_counter = 0 1283 for call_site_result in expanded_graph.raw_call_sites: 1284 call_site_metadata, errors = _wire_call_site( 1285 call_site_result, 1286 expanded_graph, 1287 call_id_counter, 1288 wired_nodes, 1289 wired_edges, 1290 processed_ret_nodes, 1291 function_ret_destinations, 1292 ) 1293 wired_call_sites.append(call_site_metadata) 1294 call_site_errors.extend(errors) 1295 call_id_counter += 1 1296 1297 # Create final graph with wired call sites 1298 final_graph = replace( 1299 expanded_graph, 1300 nodes=wired_nodes, 1301 edges=wired_edges, 1302 call_sites=wired_call_sites, 1303 raw_call_sites=(), # Clear raw call sites after processing 1304 ) 1305 1306 # Accumulate all errors 1307 all_errors = list(graph.errors) + expansion_errors + call_site_errors 1308 1309 # Return with error list updated 1310 return replace(final_graph, errors=all_errors) 1311 1312 1313__all__ = ["expand"]