OR-1 dataflow CPU sketch
1"""Macro expansion pass for the OR1 assembler.
2
3This module implements macro invocation expansion (Phase 2). It processes
4IRMacroCall entries from the lowering pass, expands them by cloning and
5substituting macro bodies, and qualifies expanded names with scope prefixes.
6
7The expand() function receives an IRGraph from lower, processes all macro
8definitions and invocations, and returns a clean IRGraph with all macro
9artefacts removed.
10"""
11
12from __future__ import annotations
13
14import ast
15from dataclasses import replace
16from typing import Optional
17from collections.abc import Iterable
18
19from asm.errors import AssemblyError, ErrorCategory
20from asm.ir import (
21 IRGraph, IRNode, IREdge, IRRegion, RegionKind, ParamRef, ConstExpr,
22 MacroDef, IRMacroCall, CallSiteResult, CallSite, IRRepetitionBlock, SourceLoc,
23 PlacementRef, PortRef, ActSlotRef, ActSlotRange,
24)
25from asm.opcodes import MNEMONIC_TO_OP
26from cm_inst import Port, RoutingOp
27
28MAX_EXPANSION_DEPTH = 32
29
30
31def _levenshtein(a: str, b: str) -> int:
32 """Compute Levenshtein (edit) distance between two strings.
33
34 Note: This is duplicated from asm/resolve.py. If a third copy appears,
35 extract to a shared utility module.
36
37 Args:
38 a: First string
39 b: Second string
40
41 Returns:
42 Minimum edit distance (number of single-character edits)
43 """
44 if len(a) < len(b):
45 return _levenshtein(b, a)
46 if not b:
47 return len(a)
48
49 prev = list(range(len(b) + 1))
50 for i, ca in enumerate(a):
51 curr = [i + 1]
52 for j, cb in enumerate(b):
53 curr.append(min(
54 prev[j + 1] + 1, # deletion
55 curr[j] + 1, # insertion
56 prev[j] + (ca != cb), # substitution
57 ))
58 prev = curr
59 return prev[-1]
60
61
62def _suggest_names(unresolved: str, available_names: Iterable[str]) -> list[str]:
63 """Generate "did you mean" suggestions via Levenshtein distance.
64
65 Compares unresolved name against all available names, returning suggestions
66 with distance <= 3, or the closest match if all distances are > 3.
67
68 Args:
69 unresolved: The unresolved name
70 available_names: Iterable of available macro names
71
72 Returns:
73 List of suggestion strings (may be empty)
74 """
75 if not available_names:
76 return []
77
78 # Compute distances
79 candidates = []
80 for name in available_names:
81 dist = _levenshtein(unresolved, name)
82 candidates.append((dist, name))
83
84 # Sort by distance
85 candidates.sort(key=lambda x: x[0])
86
87 # Return suggestions with distance <= 3, or best if all > 3
88 suggestions = []
89 best_distance = candidates[0][0]
90
91 for dist, name in candidates:
92 if dist <= 3 or dist == best_distance:
93 suggestions.append(f"Did you mean '#{name}'?")
94 else:
95 break
96
97 return suggestions
98
99
def _substitute_param(
    value: object,
    subst_map: dict[str, object],
) -> object:
    """Resolve a ParamRef or bare parameter name against the substitution map.

    A ParamRef carrying a prefix and/or suffix performs token pasting: the
    bound argument is concatenated between the prefix and suffix to form a
    new name. Without prefix/suffix the bound argument is returned as-is
    (so const fields receive the actual int value).

    Args:
        value: The value to substitute (could be ParamRef, int, str, etc.)
        subst_map: Map of formal param names to actual argument values

    Returns:
        The substituted value, or ``value`` unchanged when it is not a
        ParamRef / parameter name or when no binding exists for it.
    """
    if isinstance(value, ParamRef):
        bound = subst_map.get(value.param)
        if bound is None:
            # No binding (or bound to None): leave the reference untouched.
            # Should not happen once arity has been validated.
            return value

        # Reference arguments arrive as dicts, e.g. {"name": "&x"} -> "&x".
        if isinstance(bound, dict) and "name" in bound:
            bound = bound["name"]

        if not (value.prefix or value.suffix):
            return bound

        # Token pasting: ints are stringified so they can be concatenated.
        pasted = str(bound) if isinstance(bound, int) else bound
        return value.prefix + pasted + value.suffix

    if isinstance(value, str):
        # Sigil-prefixed names are never treated as formal parameters; they
        # may be scope-qualified later instead.
        has_sigil = bool(value) and value[0] in "&@$#"
        if not has_sigil and value in subst_map:
            return subst_map[value]

    return value
149
150
151def _eval_node(node, bindings: dict[str, int]) -> int:
152 """Evaluate a single AST node in a constant expression.
153
154 Args:
155 node: An ast node (Constant, Name, BinOp, or UnaryOp)
156 bindings: Map of parameter names to integer values
157
158 Returns:
159 The evaluated integer result
160
161 Raises:
162 ValueError: If node type is unsupported or value is non-numeric
163 """
164 if isinstance(node, ast.Constant) and isinstance(node.value, int):
165 return node.value
166 elif isinstance(node, ast.Name):
167 if node.id not in bindings:
168 raise ValueError(f"Undefined parameter: {node.id}")
169 val = bindings[node.id]
170 if not isinstance(val, int):
171 raise ValueError(f"Non-numeric value in arithmetic context")
172 return val
173 elif isinstance(node, ast.BinOp):
174 left = _eval_node(node.left, bindings)
175 right = _eval_node(node.right, bindings)
176 if isinstance(node.op, ast.Add):
177 return left + right
178 elif isinstance(node.op, ast.Sub):
179 return left - right
180 elif isinstance(node.op, ast.Mult):
181 return left * right
182 elif isinstance(node.op, ast.FloorDiv):
183 if right == 0:
184 raise ValueError("division by zero")
185 return left // right
186 else:
187 raise ValueError(f"Unsupported operator: {type(node.op).__name__}")
188 elif isinstance(node, ast.UnaryOp) and isinstance(node.op, ast.USub):
189 return -_eval_node(node.operand, bindings)
190 else:
191 raise ValueError(f"Unsupported expression node: {type(node).__name__}")
192
193
194def _eval_const_expr(expr: str, bindings: dict[str, int]) -> int:
195 """Evaluate a simple arithmetic expression with parameter bindings.
196
197 Supports: integer literals, +, -, *, // (integer division), parentheses.
198 No eval() call — safe AST walking only.
199
200 Args:
201 expr: Expression string, e.g. "base + 1"
202 bindings: Map of parameter names to integer values
203
204 Returns:
205 The evaluated integer result
206
207 Raises:
208 ValueError: If expression is invalid or contains non-numeric values
209 """
210 tree = ast.parse(expr, mode='eval')
211 return _eval_node(tree.body, bindings)
212
213
214def _qualify_expanded_name(
215 name: str,
216 macro_scope: str,
217 parent_scope: str = "",
218 func_scope: Optional[str] = None,
219) -> str:
220 """Apply scope prefix to an expanded name.
221
222 Takes a name and applies macro and optional function scopes.
223 Names starting with & are qualified; other sigils pass through.
224
225 Args:
226 name: The original name from the macro body
227 macro_scope: The macro scope (e.g., "#loop_counted_0")
228 parent_scope: Optional parent macro scope (e.g., "#outer_0")
229 func_scope: Optional function scope (e.g., "$main")
230
231 Returns:
232 The qualified name
233 """
234 if not name:
235 return name
236
237 # Check if it's a label (starts with &)
238 if name.startswith("&"):
239 # Build full scope: [func_scope.][parent_scope.]macro_scope.name
240 if func_scope and parent_scope:
241 # Triple-scoped: $func.#parent_N.#macro_M.&label
242 return f"{func_scope}.{parent_scope}.{macro_scope}.{name}"
243 elif parent_scope:
244 # Double-scoped: #parent_N.#macro_M.&label
245 return f"{parent_scope}.{macro_scope}.{name}"
246 elif func_scope:
247 # Double-scoped: $func.#macro_N.&label
248 return f"{func_scope}.{macro_scope}.{name}"
249 else:
250 # Single-scoped: #macro_N.&label
251 return f"{macro_scope}.{name}"
252
253 # Other sigils (@, $, #) pass through unqualified
254 return name
255
256
def _clone_and_substitute_node(
    node: IRNode,
    macro_scope: str,
    subst_map: dict[str, object],
    func_scope: Optional[str] = None,
    parent_scope: str = "",
) -> tuple[IRNode, list[AssemblyError]]:
    """Clone a template node, substituting parameters and qualifying its name.

    The const, opcode, placement (pe), act_slot and name fields are each
    resolved against ``subst_map``. Problems are reported as AssemblyError
    entries rather than raised.

    Args:
        node: The template node from the macro body
        macro_scope: The macro scope for qualification
        subst_map: Map of formal params to actual arguments
        func_scope: Optional function scope
        parent_scope: Optional parent macro scope

    Returns:
        Tuple of (new IRNode with substitutions applied and name qualified, errors list)
    """
    errors: list[AssemblyError] = []

    def _qualified_name() -> str:
        """Substitute the node name (may involve token pasting) and qualify it."""
        substituted = _substitute_param(node.name, subst_map)
        if not isinstance(substituted, str):
            substituted = str(substituted)
        return _qualify_expanded_name(substituted, macro_scope, parent_scope, func_scope)

    # Substitute the const field; a ConstExpr additionally needs evaluating.
    new_const = _substitute_param(node.const, subst_map)
    if isinstance(new_const, ConstExpr):
        # Collect integer bindings for the parameters the expression uses.
        # Parameters without a binding are skipped here and surface as an
        # "Undefined parameter" error during evaluation.
        bindings: dict[str, int] = {}
        for param_name in new_const.params:
            if param_name not in subst_map:
                continue
            val = subst_map[param_name]
            if not isinstance(val, int):
                errors.append(AssemblyError(
                    loc=new_const.loc,
                    category=ErrorCategory.VALUE,
                    message=f"Non-numeric value '{val}' in arithmetic context",
                ))
                # Bail out early: the ConstExpr is left in place (to be
                # caught later) and the remaining fields are not resolved.
                return replace(node, name=_qualified_name(), const=new_const), errors
            bindings[param_name] = val

        try:
            new_const = _eval_const_expr(new_const.expression, bindings)
        except ValueError as e:
            # new_const is still the ConstExpr here: the assignment above
            # never happened because evaluation raised.
            errors.append(AssemblyError(
                loc=new_const.loc,
                category=ErrorCategory.VALUE,
                message=str(e),
            ))
        except SyntaxError as e:
            # ast.parse signals a malformed expression with SyntaxError;
            # report it as a diagnostic instead of crashing the pass.
            errors.append(AssemblyError(
                loc=new_const.loc,
                category=ErrorCategory.VALUE,
                message=f"Invalid constant expression '{new_const.expression}': {e.msg}",
            ))

    # Resolve opcode if it's a ParamRef; on failure the ParamRef is kept.
    new_opcode = node.opcode
    if isinstance(new_opcode, ParamRef):
        resolved = _substitute_param(new_opcode, subst_map)
        if isinstance(resolved, str):
            if resolved in MNEMONIC_TO_OP:
                new_opcode = MNEMONIC_TO_OP[resolved]
            else:
                errors.append(AssemblyError(
                    loc=node.loc,
                    category=ErrorCategory.MACRO,
                    message=f"'{resolved}' is not a valid opcode mnemonic",
                ))
        else:
            errors.append(AssemblyError(
                loc=node.loc,
                category=ErrorCategory.MACRO,
                message=f"opcode parameter must resolve to an opcode mnemonic, got {type(resolved).__name__}",
            ))

    # Resolve placement if it's a PlacementRef: accepts "peN" or a bare int.
    new_pe = node.pe
    if isinstance(new_pe, PlacementRef):
        resolved = _substitute_param(new_pe.param, subst_map)
        if isinstance(resolved, str) and resolved.startswith("pe"):
            try:
                new_pe = int(resolved[2:])
            except ValueError:
                errors.append(AssemblyError(
                    loc=node.loc,
                    category=ErrorCategory.MACRO,
                    message=f"placement parameter must resolve to 'peN', got '{resolved}'",
                ))
                new_pe = None
        elif isinstance(resolved, int):
            new_pe = resolved
        else:
            errors.append(AssemblyError(
                loc=node.loc,
                category=ErrorCategory.MACRO,
                message=f"placement parameter must resolve to 'peN', got {type(resolved).__name__}",
            ))
            new_pe = None

    # Resolve act_slot if it's an ActSlotRef: an int becomes a one-slot range.
    new_act_slot = node.act_slot
    if isinstance(new_act_slot, ActSlotRef):
        resolved = _substitute_param(new_act_slot.param, subst_map)
        if isinstance(resolved, int):
            new_act_slot = ActSlotRange(start=resolved, end=resolved)
        else:
            errors.append(AssemblyError(
                loc=node.loc,
                category=ErrorCategory.MACRO,
                message=f"act_slot parameter must resolve to an integer, got {type(resolved).__name__}",
            ))
            new_act_slot = None

    return replace(node, name=_qualified_name(), const=new_const, opcode=new_opcode,
                   pe=new_pe, act_slot=new_act_slot), errors
384
385
def _clone_and_substitute_edge(
    edge: IREdge,
    macro_scope: str,
    subst_map: dict[str, object],
    func_scope: Optional[str] = None,
    parent_scope: str = "",
) -> tuple[IREdge, list[AssemblyError]]:
    """Clone a template edge, substituting parameters and qualifying endpoints.

    Args:
        edge: The template edge from the macro body
        macro_scope: The macro scope for qualification
        subst_map: Map of formal params to actual arguments
        func_scope: Optional function scope
        parent_scope: Optional parent macro scope

    Returns:
        Tuple of (new IREdge with names qualified, list of errors)
    """
    errors: list[AssemblyError] = []

    def _endpoint(raw: object) -> str:
        """Substitute one endpoint name and qualify it when appropriate."""
        substituted = _substitute_param(raw, subst_map)
        text = substituted if isinstance(substituted, str) else str(substituted)
        # An endpoint that was a ParamRef now points at an external name
        # supplied by the caller and must NOT receive the macro scope; only
        # names written directly in the macro body are qualified.
        if isinstance(raw, ParamRef):
            return text
        return _qualify_expanded_name(text, macro_scope, parent_scope, func_scope)

    def _port(port_ref: PortRef, label: str, fallback: object) -> object:
        """Resolve a PortRef to a Port, recording an error on failure."""
        resolved = _substitute_param(port_ref.param, subst_map)
        if isinstance(resolved, str):
            if resolved == "L":
                return Port.L
            if resolved == "R":
                return Port.R
        elif isinstance(resolved, Port):
            return resolved
        errors.append(AssemblyError(
            loc=edge.loc,
            category=ErrorCategory.MACRO,
            message=f"{label} parameter must resolve to 'L' or 'R', got '{resolved}'",
        ))
        return fallback

    new_source = _endpoint(edge.source)
    new_dest = _endpoint(edge.dest)

    # Destination port: an unresolvable reference falls back to Port.L.
    dest_port = edge.port
    if isinstance(dest_port, PortRef):
        dest_port = _port(dest_port, "port", Port.L)

    # Source port: an unresolvable reference falls back to None.
    src_port = edge.source_port
    if isinstance(src_port, PortRef):
        src_port = _port(src_port, "source port", None)

    return replace(edge, source=new_source, dest=new_dest, port=dest_port, source_port=src_port), errors
479
480
481def _add_expansion_context(
482 error: AssemblyError,
483 call: IRMacroCall,
484 builtin_line_offset: int = 0,
485) -> AssemblyError:
486 """Add expansion context to an error.
487
488 Appends "expanded from #macro_name at line N, column C" to the
489 context_lines to trace the error back to the macro invocation site.
490
491 Args:
492 error: The error to enhance
493 call: The IRMacroCall being expanded
494 builtin_line_offset: Lines to subtract for display (built-in macro prefix)
495
496 Returns:
497 New AssemblyError with expansion context added to context_lines
498 """
499 display_line = call.loc.line
500 if builtin_line_offset > 0 and display_line > builtin_line_offset:
501 display_line -= builtin_line_offset
502 expansion_context = (
503 f"expanded from #{call.name} at line {display_line}, "
504 f"column {call.loc.column}"
505 )
506 return replace(
507 error,
508 context_lines=list(error.context_lines) + [expansion_context],
509 )
510
511
def _expand_repetition_block(
    rep_block: IRRepetitionBlock,
    variadic_args: list[object],
    macro_scope: str,
    subst_map: dict[str, object],
    func_scope: Optional[str] = None,
    parent_scope: str = "",
) -> tuple[dict[str, IRNode], list[IREdge], list[AssemblyError]]:
    """Instantiate a repetition block once for every variadic argument.

    Each iteration clones the block body with the variadic parameter bound
    to that iteration's argument and ``_idx`` bound to the iteration index.
    Names are qualified under a per-iteration scope ``<macro_scope>_rep<idx>``.

    Args:
        rep_block: The IRRepetitionBlock to expand
        variadic_args: List of actual arguments for the variadic parameter
        macro_scope: The macro scope for qualification
        subst_map: Base substitution map (copied, never mutated)
        func_scope: Optional function scope
        parent_scope: Optional parent macro scope

    Returns:
        Tuple of (expanded_nodes dict, expanded_edges list, errors list)
    """
    collected_errors: list[AssemblyError] = []
    nodes_out: dict[str, IRNode] = {}
    edges_out: list[IREdge] = []

    for iteration, argument in enumerate(variadic_args):
        iteration_scope = f"{macro_scope}_rep{iteration}"

        # Per-iteration bindings layered over the base map; "_idx" exposes
        # the iteration index as an ordinary parameter.
        iteration_bindings = {
            **subst_map,
            rep_block.variadic_param: argument,
            "_idx": iteration,
        }

        for template_node in rep_block.body.nodes.values():
            cloned_node, node_errors = _clone_and_substitute_node(
                template_node,
                iteration_scope,
                iteration_bindings,
                func_scope,
                parent_scope,
            )
            collected_errors.extend(node_errors)
            nodes_out[cloned_node.name] = cloned_node

        for template_edge in rep_block.body.edges:
            cloned_edge, edge_errors = _clone_and_substitute_edge(
                template_edge,
                iteration_scope,
                iteration_bindings,
                func_scope,
                parent_scope,
            )
            collected_errors.extend(edge_errors)
            edges_out.append(cloned_edge)

    return nodes_out, edges_out, collected_errors
574
575
def _expand_call(
    call: IRMacroCall,
    macro_table: dict[str, MacroDef],
    expansion_counter: list[int],
    func_scope: Optional[str] = None,
    parent_scope: str = "",
    depth: int = 0,
    builtin_line_offset: int = 0,
) -> tuple[dict[str, IRNode], list[IREdge], list[AssemblyError]]:
    """Process a single macro call.

    Looks up the macro, validates arity, builds the substitution map, clones
    the body (including nested macro calls, repetition blocks and nested
    regions), qualifies names, and rewrites ``@ret`` output markers to the
    call site's output destinations.

    Errors returned from this function do NOT yet carry the "expanded from"
    context line for ``call`` itself; the caller is expected to add it with
    _add_expansion_context. Errors coming from nested calls are annotated
    here with the nested call's context.

    Args:
        call: The IRMacroCall to expand
        macro_table: Map of macro names to MacroDef objects
        expansion_counter: [int] list for mutable counter (incremented per expansion)
        func_scope: Optional function scope the call is in
        parent_scope: Optional parent macro scope (for nested macros)
        depth: Recursion depth (error if exceeds MAX_EXPANSION_DEPTH)
        builtin_line_offset: Lines to subtract for display in error context

    Returns:
        Tuple of (expanded_nodes dict, expanded_edges list, errors list)
    """
    errors: list[AssemblyError] = []

    def _targets_ret(edge: IREdge) -> bool:
        """Return True when the edge's destination is an @ret output marker."""
        return isinstance(edge.dest, str) and edge.dest.startswith("@ret")

    # Check depth limit (guards against self-recursive macro definitions)
    if depth > MAX_EXPANSION_DEPTH:
        return {}, [], [AssemblyError(
            loc=call.loc,
            category=ErrorCategory.MACRO,
            message=f"macro expansion depth exceeds {MAX_EXPANSION_DEPTH} (likely infinite recursion in macro '{call.name}')",
        )]

    # Look up macro definition
    if call.name not in macro_table:
        return {}, [], [AssemblyError(
            loc=call.loc,
            category=ErrorCategory.MACRO,
            message=f"undefined macro '#{call.name}'",
            suggestions=_suggest_names(call.name, macro_table.keys()),
        )]

    macro_def = macro_table[call.name]

    # Validate arity and separate variadic arguments
    total_args = len(call.positional_args) + len(call.named_args)
    required_params = [p for p in macro_def.params if not p.variadic]
    variadic_param = next((p for p in macro_def.params if p.variadic), None)

    if variadic_param:
        # With variadic: need at least as many args as required params
        if total_args < len(required_params):
            return {}, [], [AssemblyError(
                loc=call.loc,
                category=ErrorCategory.MACRO,
                message=f"macro '#{call.name}' expects at least {len(required_params)} argument(s), got {total_args}",
            )]
    else:
        # Without variadic: exact match required
        expected_count = len(macro_def.params)
        if total_args != expected_count:
            return {}, [], [AssemblyError(
                loc=call.loc,
                category=ErrorCategory.MACRO,
                message=f"macro '#{call.name}' expects {expected_count} argument(s), got {total_args}",
            )]

    # Build substitution map
    subst_map: dict[str, object] = {}
    variadic_args: list[object] = []

    # Positional arguments fill the non-variadic params in order; any surplus
    # is collected for the variadic parameter.
    for i, actual_value in enumerate(call.positional_args):
        if i < len(required_params):
            subst_map[required_params[i].name] = actual_value
        elif variadic_param:
            variadic_args.append(actual_value)

    # Named arguments (named variadic args not supported).
    # NOTE(review): names are not checked against the formal parameter list,
    # and a named argument may silently override a positional one — confirm
    # whether that should be reported as an error.
    for param_name, actual_value in call.named_args:
        subst_map[param_name] = actual_value

    # Generate unique macro scope
    expansion_id = expansion_counter[0]
    expansion_counter[0] += 1
    macro_scope = f"#{call.name}_{expansion_id}"

    def _expand_body_recursive(
        body: IRGraph,
        level: int,
        is_root: bool,
    ) -> tuple[dict[str, IRNode], list[IREdge], list[AssemblyError]]:
        """Expand one body graph (and, recursively, its regions).

        ``is_root`` is True only for the macro's top-level body.
        """
        body_errors: list[AssemblyError] = []
        body_nodes: dict[str, IRNode] = {}
        body_edges: list[IREdge] = []

        # Qualify and add the body's own nodes. Context for ``call`` is added
        # by whoever invoked _expand_call, so it is not added here as well
        # (doing both produced the same context line twice).
        for template_node in body.nodes.values():
            qualified_node, node_errors = _clone_and_substitute_node(
                template_node, macro_scope, subst_map, func_scope, parent_scope)
            body_errors.extend(node_errors)
            body_nodes[qualified_node.name] = qualified_node

        # Qualify the body's own edges
        for template_edge in body.edges:
            qualified_edge, edge_errors = _clone_and_substitute_edge(
                template_edge, macro_scope, subst_map, func_scope, parent_scope)
            body_errors.extend(edge_errors)
            body_edges.append(qualified_edge)

        # Expand macro calls at this body level.
        # Nested calls have the current macro_scope as their parent_scope.
        for nested_call in body.macro_calls:
            nested_nodes, nested_edges, nested_errors = _expand_call(
                nested_call,
                macro_table,
                expansion_counter,
                func_scope,
                macro_scope,
                level + 1,
                builtin_line_offset,
            )
            # Trace nested errors back to the nested call site
            for nested_error in nested_errors:
                body_errors.append(
                    _add_expansion_context(nested_error, nested_call, builtin_line_offset))
            body_nodes.update(nested_nodes)
            # Filter out leaked @ret edges from failed inner expansions to
            # prevent spurious "defines output(s) @ret" errors at this level.
            body_edges.extend(e for e in nested_edges if not _targets_ret(e))

        # Expand repetition blocks (variadic macros). They belong to the macro
        # definition rather than to any particular body, so expand them once,
        # at the root level only; repeating this for every nested region
        # would append duplicate edges.
        if is_root and variadic_param:
            for rep_block in macro_def.repetition_blocks:
                # Only expand blocks for the current variadic parameter
                if rep_block.variadic_param != variadic_param.name:
                    continue
                rep_nodes, rep_edges, rep_errors = _expand_repetition_block(
                    rep_block,
                    variadic_args,
                    macro_scope,
                    subst_map,
                    func_scope,
                    parent_scope,
                )
                body_errors.extend(rep_errors)
                body_nodes.update(rep_nodes)
                body_edges.extend(rep_edges)

        # Recursively expand regions in the body; their contents are
        # flattened into the same node/edge collections.
        # NOTE(review): a FUNCTION region's tag is not applied as the function
        # scope for names inside it (the enclosing func_scope is used) —
        # confirm whether that is intended.
        for region in body.regions:
            region_nodes, region_edges, region_errors = _expand_body_recursive(
                region.body,
                level + 1,
                False,
            )
            body_errors.extend(region_errors)
            body_nodes.update(region_nodes)
            body_edges.extend(region_edges)

        return body_nodes, body_edges, body_errors

    expanded_nodes, expanded_edges, nested_errors = _expand_body_recursive(
        macro_def.body,
        depth,
        True,
    )
    errors.extend(nested_errors)

    # Rewrite @ret edges: either substitute with output destinations, or
    # report an error if the macro body uses @ret but the call site doesn't
    # provide output wiring.
    if not call.output_dests:
        ret_markers = sorted({e.dest for e in expanded_edges if _targets_ret(e)})
        if ret_markers:
            errors.append(AssemblyError(
                loc=call.loc,
                category=ErrorCategory.MACRO,
                message=f"macro '#{call.name}' defines output(s) {', '.join(ret_markers)} but call site has no '|>' output wiring",
            ))
    else:
        rewritten_edges = []
        all_outputs = list(call.output_dests)

        # Build ordered list of positional outputs for bare @ret resolution.
        # Each bare @ret consumes the next positional output in order,
        # enabling variadic macros to wire each iteration to a separate dest:
        #   #macro *vals |> { $( &c <| const, ${vals}; &c |> @ret ),* }
        #   #macro 3, 4 |> &x, &y   <- iteration 0 -> &x, iteration 1 -> &y
        positional_outputs: list[str] = []
        for output in all_outputs:
            if isinstance(output, dict):
                if "name" in output and "ref" in output:
                    continue  # Named output — not positional
                name = output.get("name", None)
                if name is not None:
                    positional_outputs.append(name)
            else:
                positional_outputs.append(str(output))
        positional_idx = 0

        for edge in expanded_edges:
            if not _targets_ret(edge):
                rewritten_edges.append(edge)
                continue

            # This edge targets an @ret marker — resolve it
            ret_dest = edge.dest  # e.g. "@ret" or "@ret_body"
            dest_name = None

            # Try named match: @ret_body -> output with name="body"
            if ret_dest.startswith("@ret_"):
                expected_suffix = ret_dest[5:]  # "body" from "@ret_body"
                for output in all_outputs:
                    if isinstance(output, dict) and "name" in output and "ref" in output:
                        if output["name"] == expected_suffix:
                            ref = output["ref"]
                            dest_name = ref.get("name", ref) if isinstance(ref, dict) else str(ref)
                            break

            # Bare @ret -> next positional output (advances counter)
            if dest_name is None and ret_dest == "@ret":
                if positional_idx < len(positional_outputs):
                    dest_name = positional_outputs[positional_idx]
                    positional_idx += 1

            if dest_name is None:
                errors.append(AssemblyError(
                    loc=call.loc,
                    category=ErrorCategory.MACRO,
                    message=f"macro '#{call.name}' has output marker '{ret_dest}' but no matching output destination in call site",
                ))
                # Keep the unresolved edge so the marker stays visible.
                rewritten_edges.append(edge)
                continue

            # Replace the @ret destination with the concrete node reference
            rewritten_edges.append(replace(edge, dest=dest_name))

        expanded_edges = rewritten_edges

    # Propagate errors recorded on the macro body template, relocated to the
    # call site and annotated with where the macro was defined.
    for body_error in macro_def.body.errors:
        errors.append(replace(
            body_error,
            loc=call.loc,
            suggestions=list(body_error.suggestions) + [
                f"defined in macro #{macro_def.name} at line {macro_def.loc.line}"
            ],
        ))

    return expanded_nodes, expanded_edges, errors
852
853
def _expand_graph_recursive(
    graph: IRGraph,
    macro_table: dict[str, MacroDef],
    expansion_counter: list[int],
    func_scope: Optional[str] = None,
    builtin_line_offset: int = 0,
) -> tuple[IRGraph, list[AssemblyError]]:
    """Expand every macro call in a graph and, recursively, in its regions.

    Args:
        graph: The IRGraph to expand
        macro_table: Map of macro names to MacroDef objects
        expansion_counter: [int] list for mutable counter
        func_scope: Optional function scope for name qualification
        builtin_line_offset: Lines to subtract for display in error context

    Returns:
        Tuple of (new_graph, all_errors). The new graph holds the original
        nodes/edges plus everything produced by expansion, with macro_defs
        and macro_calls emptied.
    """
    collected_errors: list[AssemblyError] = []
    # Start from copies so the input graph's containers are left untouched.
    merged_nodes: dict[str, IRNode] = dict(graph.nodes)
    merged_edges: list[IREdge] = list(graph.edges)

    # Expand the calls recorded directly on this graph level. These are
    # top-level invocations, so they have no parent macro scope.
    for call in list(graph.macro_calls):
        call_nodes, call_edges, call_errors = _expand_call(
            call,
            macro_table,
            expansion_counter,
            func_scope,
            "",
            builtin_line_offset=builtin_line_offset,
        )
        collected_errors.extend(
            _add_expansion_context(call_error, call, builtin_line_offset)
            for call_error in call_errors
        )
        merged_nodes.update(call_nodes)
        merged_edges.extend(call_edges)

    # Recurse into regions (function bodies, etc.)
    rebuilt_regions: list[IRRegion] = []
    for region in graph.regions:
        # Inside a function region, the region's tag becomes the function
        # scope used for name qualification; other regions inherit ours.
        scope_for_region = func_scope
        if region.kind == RegionKind.FUNCTION:
            scope_for_region = region.tag

        expanded_body, region_errors = _expand_graph_recursive(
            region.body,
            macro_table,
            expansion_counter,
            scope_for_region,
            builtin_line_offset,
        )
        collected_errors.extend(region_errors)
        rebuilt_regions.append(replace(region, body=expanded_body))

    # Assemble the result with all macro artefacts stripped.
    expanded_graph = replace(
        graph,
        nodes=merged_nodes,
        edges=merged_edges,
        regions=rebuilt_regions,
        macro_defs=[],
        macro_calls=[],
    )
    return expanded_graph, collected_errors
926
927
def _output_ref_name(ref) -> str:
    """Return the destination name carried by an output reference."""
    if isinstance(ref, dict):
        return ref.get("name", "@__unmatched")
    return str(ref)


def _resolve_ret_output_dest(ret_dest: str, outputs: list) -> str:
    """Pick the caller-side destination name for one @ret variant."""
    # Named variants ("@ret_sum") match a named output {"name": "sum", "ref": ...}.
    # Requiring the "ref" key keeps a positional output (a bare ref dict that
    # only has "name") from being mistaken for a named one.
    if ret_dest.startswith("@ret_"):
        wanted = ret_dest[5:]  # "sum" from "@ret_sum"
        if wanted:
            for output in outputs:
                if (
                    isinstance(output, dict)
                    and "ref" in output
                    and output.get("name") == wanted
                ):
                    return _output_ref_name(output["ref"])

    # Bare "@ret" maps onto the first positional output, i.e. the first entry
    # that is not a {"name": ..., "ref": ...} named output.
    if ret_dest == "@ret":
        for output in outputs:
            is_named = (
                isinstance(output, dict) and "name" in output and "ref" in output
            )
            if not is_named:
                return _output_ref_name(output)

    return f"@__unmatched_{ret_dest}"


def _wire_call_site(
    call_site: CallSiteResult,
    graph: IRGraph,
    call_id: int,
    wired_nodes: dict[str, IRNode],
    wired_edges: list[IREdge],
    processed_ret_nodes: set,
    function_ret_destinations: dict[str, set],
) -> tuple[CallSite, list[AssemblyError]]:
    """Process a single function call site and wire it into the graph.

    This function:
    1. Finds the function definition in the graph's regions
    2. Matches input arguments to function labels
    3. Reuses (or, as a fallback, creates) the shared @ret rendezvous nodes
    4. Creates per-call-site trampolines and a free_frame node
    5. Wires trampolines to the caller's outputs with ctx_override edges

    Args:
        call_site: The CallSiteResult from the lower pass
        graph: The IRGraph containing regions (functions)
        call_id: Unique ID for this call site
        wired_nodes: Dictionary to accumulate generated nodes (mutated)
        wired_edges: List to accumulate generated edges (mutated)
        processed_ret_nodes: Names of already-synthesised @ret nodes
            ("<func_name>.<@ret...>"); extended when a fallback node is created
        function_ret_destinations: Map of function name to the set of @ret
            destinations found in that function's body

    Returns:
        Tuple of (CallSite metadata, errors list)
    """
    errors = []
    func_name = call_site.func_name
    loc = call_site.loc

    # Find the function definition in the graph's regions
    func_region = None
    for region in graph.regions:
        if region.kind == RegionKind.FUNCTION and region.tag == func_name:
            func_region = region
            break

    if func_region is None:
        error = AssemblyError(
            loc=loc,
            category=ErrorCategory.CALL,
            message=f"undefined function '{func_name}'",
        )
        return CallSite(
            func_name=func_name,
            call_id=call_id,
        ), [error]

    # Collect all nodes in the function body (including nested regions)
    func_all_nodes = {}

    def _collect_from_region(r: IRGraph):
        func_all_nodes.update(r.nodes)
        for sub_region in r.regions:
            _collect_from_region(sub_region.body)

    _collect_from_region(func_region.body)

    input_edge_names = []
    trampoline_nodes = []
    free_frame_nodes = []

    # Process input arguments: match each to a label in the function
    for param_name, source_ref in call_site.input_args:
        # source_ref may be a dict with {"name": "..."} or a simple string
        if isinstance(source_ref, dict):
            source_name = source_ref.get("name", str(source_ref))
        else:
            source_name = str(source_ref)

        # Look for a label &param_name in the function
        target_label = f"{func_name}.&{param_name}"

        if target_label not in func_all_nodes:
            errors.append(AssemblyError(
                loc=loc,
                category=ErrorCategory.CALL,
                message=f"argument '{param_name}' does not match any label in '{func_name}'",
            ))
            continue

        # Check if source node has a const (AC5.3: const+CTX_OVRD conflict)
        # If so, insert a pass trampoline between source and target
        source_node = graph.nodes.get(source_name)
        if source_node is not None and source_node.const is not None:
            tramp_name = f"{func_name}.__input_tramp_{call_id}_{param_name}"
            wired_nodes[tramp_name] = IRNode(
                name=tramp_name,
                opcode=RoutingOp.PASS,
                loc=loc,
            )

            # Wire: source -> trampoline (no ctx_override, inherits ctx)
            wired_edges.append(IREdge(
                source=source_name,
                dest=tramp_name,
                port=Port.L,
                loc=loc,
            ))

            # Wire: trampoline -> target (no ctx_override — INHERIT mode reads
            # the destination FrameDest from the frame, which already encodes
            # the function's act_id. CHANGE_TAG is wrong here because the left
            # operand is raw data, not a packed FrameDest.)
            wired_edges.append(IREdge(
                source=tramp_name,
                dest=target_label,
                port=Port.L,
                loc=loc,
            ))
        else:
            # No conflict — direct edge (no ctx_override — the destination
            # node's act_id in the FrameDest handles cross-context routing)
            wired_edges.append(IREdge(
                source=source_name,
                dest=target_label,
                port=Port.L,
                loc=loc,
            ))

        input_edge_names.append(f"{func_name}.__input_{call_id}_{param_name}")

    # Get @ret destinations for this function (pre-computed during expand setup)
    ret_destinations = set()
    if function_ret_destinations and func_name in function_ret_destinations:
        ret_destinations = function_ret_destinations[func_name]

    # output_dests is a flat tuple of dicts: each dict is either a named_output
    # {"name": "...", "ref": {...}} or positional_output {...}
    all_outputs = list(call_site.output_dests) if call_site.output_dests else []

    # Sets iterate in hash order, which differs between interpreter runs for
    # strings; sort so the emitted nodes and edges have a reproducible order.
    for ret_dest in sorted(ret_destinations):
        # Synthetic node name: $func.@ret or $func.@ret_name
        synthetic_node_name = f"{func_name}.{ret_dest}"

        # Synthetic node should already exist from expand setup
        if synthetic_node_name not in processed_ret_nodes:
            # This shouldn't happen, but create it just in case
            wired_nodes[synthetic_node_name] = IRNode(
                name=synthetic_node_name,
                opcode=RoutingOp.PASS,
                loc=loc,
            )
            processed_ret_nodes.add(synthetic_node_name)

        # Create a per-call-site trampoline pass node (ret_dest[1:] drops "@")
        trampoline_name = f"{func_name}.__ret_trampoline_{call_id}_{ret_dest[1:]}"
        wired_nodes[trampoline_name] = IRNode(
            name=trampoline_name,
            opcode=RoutingOp.PASS,
            dest_l=None,  # Wired by the edges below
            dest_r=None,  # Wired by the edges below
            loc=loc,
        )
        trampoline_nodes.append(trampoline_name)

        # Edge from synthetic @ret node to trampoline
        wired_edges.append(IREdge(
            source=synthetic_node_name,
            dest=trampoline_name,
            port=Port.L,
            loc=loc,
        ))

        # Wire trampoline's L output to the caller's destination with
        # ctx_override=True. Unmatched variants keep an "@__unmatched..."
        # placeholder name; no error is raised here.
        dest_name = _resolve_ret_output_dest(ret_dest, all_outputs)
        wired_edges.append(IREdge(
            source=trampoline_name,
            dest=dest_name,
            port=Port.L,
            source_port=Port.L,  # Output from trampoline's L port
            ctx_override=True,
            loc=loc,
        ))

        # One free_frame node per call site (not per @ret variant)
        free_frame_name = f"{func_name}.__free_frame_{call_id}"
        if free_frame_name not in wired_nodes:
            wired_nodes[free_frame_name] = IRNode(
                name=free_frame_name,
                opcode=RoutingOp.FREE_FRAME,
                loc=loc,
            )
            free_frame_nodes.append(free_frame_name)

        # Every trampoline's R output feeds the shared free_frame node
        wired_edges.append(IREdge(
            source=trampoline_name,
            dest=free_frame_name,
            port=Port.L,
            source_port=Port.R,  # Output from trampoline's R port
            loc=loc,
        ))

    call_site_metadata = CallSite(
        func_name=func_name,
        call_id=call_id,
        input_edges=tuple(input_edge_names),
        trampoline_nodes=tuple(trampoline_nodes),
        free_frame_nodes=tuple(free_frame_nodes),
        loc=loc,
    )

    return call_site_metadata, errors
1189
1190
def expand(graph: IRGraph) -> IRGraph:
    """Expand all macro calls in an IRGraph and wire function call sites.

    The expand pass processes all MacroDef and IRMacroCall entries from
    lowering, substitutes parameters, qualifies names, and recursively
    expands nested macros. The output graph contains no macro definitions
    or invocation artefacts.

    Steps:
    1. Collect all MacroDef entries into a macro_table
    2. Recursively expand all IRMacroCall entries (depth limit 32)
    3. For each call: validate arity, build substitution map, clone body,
       substitute params, qualify names, splice into output
    4. Strip all MacroDef and IRMacroCall entries from output
    5. For every function region, redirect edges whose destination starts
       with "@ret" to a synthetic "<func>.<@ret...>" pass node
    6. Wire each raw call site via _wire_call_site and clear raw_call_sites
    7. Return new IRGraph with the accumulated errors attached

    Args:
        graph: The IRGraph from the lower pass

    Returns:
        New IRGraph with all macros expanded, call sites wired, and errors
        set to the input graph's errors followed by expansion and call-site
        errors
    """
    # Collect all macro definitions into a table (later duplicates win)
    macro_table: dict[str, MacroDef] = {
        macro_def.name: macro_def for macro_def in graph.macro_defs
    }

    # Single-element list passed to the recursive expander as the counter
    expansion_counter: list[int] = [0]

    # Recursively expand the graph starting at top level
    expanded_graph, expansion_errors = _expand_graph_recursive(
        graph,
        macro_table,
        expansion_counter,
        builtin_line_offset=graph.builtin_line_offset,
    )

    synthetic_ret_nodes = {}  # synthetic_node_name -> IRNode
    function_ret_destinations = {}  # func_name -> set of @ret destinations
    new_regions = []
    for region in expanded_graph.regions:
        if region.kind != RegionKind.FUNCTION:
            new_regions.append(region)
            continue

        # One pass over the body's own edges: record each @ret destination
        # and redirect the edge to its function-qualified synthetic node.
        ret_destinations = set()
        new_body_edges = []
        for edge in region.body.edges:
            new_dest = edge.dest
            if isinstance(new_dest, str) and new_dest.startswith("@ret"):
                ret_destinations.add(new_dest)
                new_dest = f"{region.tag}.{new_dest}"
            new_body_edges.append(replace(edge, dest=new_dest))

        # Store the destinations for later use by _wire_call_site
        function_ret_destinations[region.tag] = ret_destinations

        # Sorted so that node insertion order does not depend on the
        # per-run string hash seed.
        for ret_dest in sorted(ret_destinations):
            synthetic_node_name = f"{region.tag}.{ret_dest}"
            if synthetic_node_name not in synthetic_ret_nodes:
                synthetic_ret_nodes[synthetic_node_name] = IRNode(
                    name=synthetic_node_name,
                    opcode=RoutingOp.PASS,
                )

        new_body = replace(region.body, edges=new_body_edges)
        new_regions.append(replace(region, body=new_body))

    expanded_graph = replace(expanded_graph, regions=new_regions)

    # Add synthetic nodes to the top-level graph
    wired_nodes = dict(expanded_graph.nodes)
    wired_nodes.update(synthetic_ret_nodes)

    # Process function call sites; call_id is the position in raw_call_sites
    wired_call_sites = []
    call_site_errors = []
    wired_edges = list(expanded_graph.edges)
    processed_ret_nodes = set(synthetic_ret_nodes)

    for call_id, call_site_result in enumerate(expanded_graph.raw_call_sites):
        call_site_metadata, errors = _wire_call_site(
            call_site_result,
            expanded_graph,
            call_id,
            wired_nodes,
            wired_edges,
            processed_ret_nodes,
            function_ret_destinations,
        )
        wired_call_sites.append(call_site_metadata)
        call_site_errors.extend(errors)

    # Accumulate all errors: pre-existing, then expansion, then call-site
    all_errors = list(graph.errors) + expansion_errors + call_site_errors

    # Final graph with wired call sites and the updated error list
    return replace(
        expanded_graph,
        nodes=wired_nodes,
        edges=wired_edges,
        call_sites=wired_call_sites,
        raw_call_sites=(),  # Clear raw call sites after processing
        errors=all_errors,
    )
1311
1312
# Public API of this module: only the top-level expand() entry point.
__all__ = ["expand"]