"""Lower pass: Convert Lark CST to IR graph. This module implements a Lark Transformer that converts a parse tree from the dfasm grammar into an IRGraph. The transformer handles: - Instruction definitions and node creation - Plain, strong, and weak edge routing - Function and location regions - Data definitions - System configuration pragmas - Name qualification (scoping) - Error collection for reserved names and duplicates """ from typing import Any, Optional, Union, Tuple, List, Dict from dataclasses import replace import re from lark import Transformer, v_args, Tree from lark.lexer import Token as LarkToken from asm.ir import ( IRGraph, IRNode, IREdge, IRRegion, RegionKind, IRDataDef, SystemConfig, SourceLoc, NameRef, ResolvedDest, MacroParam, ParamRef, MacroDef, IRMacroCall, CallSiteResult, IRRepetitionBlock, PlacementRef, PortRef, ActSlotRef, ActSlotRange, ) from asm.errors import AssemblyError, ErrorCategory from asm.opcodes import MNEMONIC_TO_OP from cm_inst import ALUOp, MemOp, Port, RoutingOp # Reserved names that cannot be used as node definitions _RESERVED_NAMES = frozenset({"@system", "@io", "@debug"}) # Pattern for detecting ${param} token pasting in identifiers _PASTE_PATTERN = re.compile(r'^(.*?)\$\{([a-zA-Z_][a-zA-Z0-9_]*)\}(.*)$') def _filter_args(args: tuple) -> list: """Filter out LarkTokens from argument list.""" return [arg for arg in args if not isinstance(arg, LarkToken)] def _normalize_port(value: Union[int, Port, PortRef]) -> Union[Port, PortRef]: """Normalize a port value to Port enum, preserving PortRef for macro templates. Args: value: An int (0/1), Port enum, or PortRef (macro parameter) Returns: Port enum value, or PortRef passed through for later expansion """ if isinstance(value, PortRef): return value if isinstance(value, Port): return value if isinstance(value, int): return Port.L if value == 0 else (Port.R if value == 1 else Port.L) return Port.L # Structured statement result types class StatementResult: """Base class for statement processing results.""" pass class NodeResult(StatementResult): """Result from inst_def: one or more IRNodes.""" def __init__(self, nodes: Dict[str, IRNode]): self.nodes = nodes class EdgeResult(StatementResult): """Result from plain_edge or anonymous edges: IREdges.""" def __init__(self, edges: List[IREdge]): self.edges = edges class FunctionResult(StatementResult): """Result from func_def: an IRRegion.""" def __init__(self, region: IRRegion): self.region = region class LocationResult(StatementResult): """Result from location_dir: an IRRegion.""" def __init__(self, region: IRRegion): self.region = region class DataDefResult(StatementResult): """Result from data_def: IRDataDefs.""" def __init__(self, data_defs: List[IRDataDef]): self.data_defs = data_defs class MacroDefResult(StatementResult): """Result from macro_def: a MacroDef.""" def __init__(self, macro_def: MacroDef): self.macro_def = macro_def class MacroCallResult(StatementResult): """Result from macro_call_stmt: an IRMacroCall.""" def __init__(self, macro_call: IRMacroCall): self.macro_call = macro_call class CallSiteResultStatement(StatementResult): """Result from call_stmt: a CallSiteResult.""" def __init__(self, call_site_result: CallSiteResult): self.call_site_result = call_site_result class RepetitionBlockResult(StatementResult): """Result from repetition_block: an IRRepetitionBlock.""" def __init__(self, repetition_block: IRRepetitionBlock): self.repetition_block = repetition_block class CompositeResult(StatementResult): """Result combining nodes and edges (for strong/weak edges).""" def __init__(self, nodes: Dict[str, IRNode], edges: List[IREdge]): self.nodes = nodes self.edges = edges class LowerTransformer(Transformer): """Transformer that converts a CST into an IRGraph. The transformer collects statement results and then in the `start` rule organizes them into the final IRGraph structure. """ def __init__(self): super().__init__() self._anon_counter: int = 0 self._errors: list[AssemblyError] = [] self._defined_names: dict[str, SourceLoc] = {} self._system: Optional[SystemConfig] = None def _qualify_name(self, name, func_scope: Optional[str]): """Apply function scope qualification to a name. ParamRef values pass through unchanged — they are resolved during macro expansion, not during lowering. """ if isinstance(name, ParamRef): return name if isinstance(name, str) and name.startswith("&") and func_scope: return f"{func_scope}.{name}" return name def _extract_loc(self, meta: Any) -> SourceLoc: """Extract SourceLoc from Lark's meta object.""" return SourceLoc( line=meta.line, column=meta.column, end_line=meta.end_line if hasattr(meta, "end_line") else None, end_column=meta.end_column if hasattr(meta, "end_column") else None, ) def _gen_anon_name(self, func_scope: Optional[str]) -> str: """Generate an anonymous node name, qualified by current scope.""" name = f"&__anon_{self._anon_counter}" self._anon_counter += 1 return self._qualify_name(name, func_scope) def _check_reserved_name(self, name: str, loc: SourceLoc) -> bool: """Check if name is reserved. Return True if reserved (and add error).""" if name in _RESERVED_NAMES: self._errors.append(AssemblyError( loc=loc, category=ErrorCategory.NAME, message=f"Reserved name '{name}' cannot be used as a node definition" )) return True return False def _check_duplicate_name(self, name: str, loc: SourceLoc) -> bool: """Check for duplicate definition. Return True if duplicate (and add error).""" if name in self._defined_names: prev_loc = self._defined_names[name] self._errors.append(AssemblyError( loc=loc, category=ErrorCategory.SCOPE, message=f"Duplicate label '{name}'", suggestions=[f"First defined at line {prev_loc.line}"] )) return True self._defined_names[name] = loc return False def _process_statements( self, statements: list, func_scope: Optional[str] = None ) -> Tuple[Dict[str, IRNode], List[IREdge], List[IRRegion], List[IRDataDef], List]: """Process a list of statement results and collect them into containers.""" nodes = {} edges = [] regions = [] data_defs = [] call_sites = [] # Reset defined names for this scope prev_defined_names = self._defined_names self._defined_names = {} for stmt in statements: if isinstance(stmt, NodeResult): # Qualify and add nodes for node_name, node in stmt.nodes.items(): qualified_name = self._qualify_name(node_name, func_scope) if not self._check_duplicate_name(qualified_name, node.loc): # Update node with qualified name qualified_node = replace(node, name=qualified_name) nodes[qualified_name] = qualified_node elif isinstance(stmt, EdgeResult): # Qualify and add edges for edge in stmt.edges: qualified_edge = replace( edge, source=self._qualify_name(edge.source, func_scope), dest=self._qualify_name(edge.dest, func_scope), ) edges.append(qualified_edge) elif isinstance(stmt, CompositeResult): # Composite: both nodes and edges (strong/weak edges) for node_name, node in stmt.nodes.items(): qualified_name = self._qualify_name(node_name, func_scope) if not self._check_duplicate_name(qualified_name, node.loc): qualified_node = replace(node, name=qualified_name) nodes[qualified_name] = qualified_node for edge in stmt.edges: qualified_edge = replace( edge, source=self._qualify_name(edge.source, func_scope), dest=self._qualify_name(edge.dest, func_scope), ) edges.append(qualified_edge) elif isinstance(stmt, FunctionResult): regions.append(stmt.region) elif isinstance(stmt, LocationResult): regions.append(stmt.region) elif isinstance(stmt, DataDefResult): data_defs.extend(stmt.data_defs) elif isinstance(stmt, MacroDefResult): # Macro definitions are stored separately, not as regions pass # Collected at the start() level elif isinstance(stmt, MacroCallResult): # Macro calls are stored separately pass # Collected at the start() level elif isinstance(stmt, CallSiteResultStatement): # Call sites are stored separately call_sites.append(stmt.call_site_result) # Restore defined names self._defined_names = prev_defined_names return nodes, edges, regions, data_defs, call_sites def start(self, items: list) -> IRGraph: """Process the entire program and return an IRGraph. Post-processing: Groups statements following location_dir into that region's body. """ # First pass: collect all items nodes, edges, regions, data_defs, call_sites = self._process_statements(items, None) # Second pass: post-process location regions to collect subsequent statements # Find LocationResult objects and collect subsequent statements into their body location_results = [r for r in regions if r.kind == RegionKind.LOCATION] # Track which nodes, data_defs, and edges are moved into location regions moved_node_names = set() moved_data_names = set() moved_edge_sources = set() # Track edges by (source, dest) tuple if location_results: # Build a mapping of location regions to their collected body for loc_region in location_results: # Find the position of this region in the items list # by matching the tag body_nodes = {} body_edges = [] body_data_defs = [] # Collect subsequent non-region statements collecting = False for item in items: if isinstance(item, LocationResult) and item.region.tag == loc_region.tag: collecting = True continue if collecting: # Stop at next region boundary if isinstance(item, (FunctionResult, LocationResult)): break # Collect into location body if isinstance(item, NodeResult): body_nodes.update(item.nodes) moved_node_names.update(item.nodes.keys()) elif isinstance(item, EdgeResult): body_edges.extend(item.edges) moved_edge_sources.update((e.source, e.dest) for e in item.edges) elif isinstance(item, DataDefResult): body_data_defs.extend(item.data_defs) moved_data_names.update(d.name for d in item.data_defs) elif isinstance(item, CompositeResult): body_nodes.update(item.nodes) moved_node_names.update(item.nodes.keys()) body_edges.extend(item.edges) moved_edge_sources.update((e.source, e.dest) for e in item.edges) # Update the location region with collected body if body_nodes or body_edges or body_data_defs: new_body = IRGraph( nodes=body_nodes, edges=body_edges, regions=[], data_defs=body_data_defs, ) # Find and replace this region in the regions list regions = [ IRRegion( tag=r.tag, kind=r.kind, body=new_body if r.tag == loc_region.tag else r.body, loc=r.loc, ) for r in regions ] # Remove items that were moved into location regions from top-level containers nodes = {k: v for k, v in nodes.items() if k not in moved_node_names} data_defs = [d for d in data_defs if d.name not in moved_data_names] edges = [e for e in edges if (e.source, e.dest) not in moved_edge_sources] # Collect macro definitions and calls from items macro_defs = [] macro_calls = [] for item in items: if isinstance(item, MacroDefResult): macro_defs.append(item.macro_def) elif isinstance(item, MacroCallResult): macro_calls.append(item.macro_call) return IRGraph( nodes=nodes, edges=edges, regions=regions, data_defs=data_defs, system=self._system, errors=self._errors, macro_defs=macro_defs, macro_calls=macro_calls, raw_call_sites=tuple(call_sites), ) @v_args(inline=True) def inline_const(self, value) -> Union[int, ParamRef]: """Parse inline constant (space-separated, e.g., 'add 7' or '${param}').""" if isinstance(value, ParamRef): return value return int(str(value), 0) @v_args(inline=True, meta=True) def inst_def(self, meta, *args) -> StatementResult: """Process instruction definition.""" loc = self._extract_loc(meta) # Filter out tokens (FLOW_IN, etc.) - keep only transformed results args_list = _filter_args(args) # First arg is qualified_ref_dict, second is opcode, rest are arguments qualified_ref_dict = args_list[0] opcode = args_list[1] remaining_args = args_list[2:] if len(args_list) > 2 else [] # Extract name (will be qualified later in _process_statements) name = qualified_ref_dict["name"] # Check reserved names if self._check_reserved_name(name, loc): return NodeResult({}) # If opcode is None (invalid), skip node creation (error already added) if opcode is None: return NodeResult({}) # Extract placement (PE qualifier) pe = None if "placement" in qualified_ref_dict and qualified_ref_dict["placement"]: placement_val = qualified_ref_dict["placement"] if isinstance(placement_val, PlacementRef): pe = placement_val elif isinstance(placement_val, str) and placement_val.startswith("pe"): try: pe = int(placement_val[2:]) except ValueError: pass # Extract activation slot qualifier act_slot = qualified_ref_dict.get("act_slot") # Extract const and named args from arguments # Check if first remaining arg is an inline_const (int directly after opcode) const = None args_dict = {} positional_count = 0 for arg in remaining_args: if isinstance(arg, tuple): # named_arg arg_name, arg_value = arg args_dict[arg_name] = arg_value else: # positional argument if positional_count == 0: if isinstance(arg, dict) and isinstance(arg.get("name"), ParamRef): const = arg["name"] elif not isinstance(arg, dict): const = arg positional_count += 1 # Create IRNode node = IRNode( name=name, opcode=opcode, dest_l=None, dest_r=None, const=const, pe=pe, act_slot=act_slot, loc=loc, args=args_dict if args_dict else None, ) return NodeResult({name: node}) @v_args(inline=True, meta=True) def plain_edge(self, meta, *args) -> StatementResult: """Process plain edge (wiring between named nodes). The source's port (if specified) becomes source_port (output slot). The dest's port (if specified) becomes port (input port), defaulting to L. """ loc = self._extract_loc(meta) args_list = _filter_args(args) source_dict = args_list[0] dest_list = args_list[1] source_name = source_dict["name"] # Source port is from the source's port specification source_port = source_dict.get("port") if "port" in source_dict else None # Normalize source_port to Port if it's a raw int (convert 0→L, 1→R) if source_port is not None: source_port = _normalize_port(source_port) edges = [] for dest_dict in dest_list: dest_name = dest_dict["name"] # Dest port is from the dest's port specification, defaults to L raw_port = dest_dict.get("port") port_explicit = raw_port is not None if raw_port is None: dest_port = Port.L else: dest_port = _normalize_port(raw_port) edge = IREdge( source=source_name, dest=dest_name, port=dest_port, source_port=source_port, port_explicit=port_explicit, loc=loc, ) edges.append(edge) return EdgeResult(edges) def _wire_anonymous_node( self, opcode: Union[ALUOp, MemOp], inputs: list, outputs: list, loc: SourceLoc, const_value: Optional[int] = None, is_seed: bool = False, ) -> StatementResult: """Wire inputs and outputs for an anonymous edge node. Generates the IRNode for an anonymous edge and all associated edges (both input and output wiring). This logic is shared between strong_edge and weak_edge, which differ only in how they parse their arguments. Args: opcode: The instruction opcode inputs: List of input reference dicts with "name" and optional "port" outputs: List of output reference dicts with "name" and optional "port" loc: Source location for error reporting const_value: Optional constant value for the node is_seed: If True, mark the node as a seed (no IRAM slot, emits seed token) Returns: CompositeResult with anonymous node and all input/output edges """ # Generate anonymous node (not qualified yet) anon_name = f"&__anon_{self._anon_counter}" self._anon_counter += 1 # Create anonymous IRNode anon_node = IRNode( name=anon_name, opcode=opcode, const=const_value, loc=loc, seed=is_seed, ) # Wire inputs: first input → Port.L, second → Port.R edges = [] for idx, input_arg in enumerate(inputs): if isinstance(input_arg, dict) and "name" in input_arg: # It's a qualified_ref input_name = input_arg["name"] input_port = Port.L if idx == 0 else Port.R edge = IREdge( source=input_name, dest=anon_name, port=input_port, source_port=None, loc=loc, ) edges.append(edge) # Wire outputs for output_dict in outputs: output_name = output_dict["name"] raw_port = output_dict.get("port") out_port_explicit = raw_port is not None if raw_port is None: output_port = Port.L else: output_port = _normalize_port(raw_port) edge = IREdge( source=anon_name, dest=output_name, port=output_port, source_port=None, port_explicit=out_port_explicit, loc=loc, ) edges.append(edge) # Return both the node and edges return CompositeResult({anon_name: anon_node}, edges) @v_args(inline=True, meta=True) def strong_edge(self, meta, *args) -> StatementResult: """Process strong inline edge (anonymous node with inputs and outputs). Syntax: opcode input [, input ...] |> output [, output ...] Special case: `const N |> &dest` creates a seed node — a CONST node that emits a seed token at startup without occupying an IRAM slot. """ loc = self._extract_loc(meta) args_list = _filter_args(args) opcode = args_list[0] remaining_args = args_list[1:] # If opcode is None (invalid), skip edge creation (error already added) if opcode is None: return CompositeResult({}, []) # Split arguments into inputs and outputs inputs = [] outputs = [] processing_outputs = False const_value = None for arg in remaining_args: if isinstance(arg, list): # This is ref_list processing_outputs = True outputs = arg elif not processing_outputs: if isinstance(arg, int): const_value = arg else: inputs.append(arg) # Detect seed pattern: `const N |> &dest` is_seed = ( isinstance(opcode, RoutingOp) and opcode == RoutingOp.CONST and const_value is not None and len(inputs) == 0 ) # Wire the anonymous node and its edges return self._wire_anonymous_node(opcode, inputs, outputs, loc, const_value=const_value, is_seed=is_seed) @v_args(inline=True, meta=True) def weak_edge(self, meta, *args) -> StatementResult: """Process weak inline edge (outputs then opcode then inputs). Syntax: outputs... opcode inputs... Semantically identical to strong_edge but syntactically reversed. """ loc = self._extract_loc(meta) args_list = _filter_args(args) output_list = args_list[0] opcode = args_list[1] remaining_args = args_list[2:] if len(args_list) > 2 else [] # If opcode is None (invalid), skip edge creation (error already added) if opcode is None: return CompositeResult({}, []) inputs = list(remaining_args) outputs = output_list # Wire the anonymous node and its edges return self._wire_anonymous_node(opcode, inputs, outputs, loc) def func_def(self, args: list) -> StatementResult: """Process function definition (region with nested scope).""" # Without v_args decorator, args come as a list with LarkToken terminals mixed in # Filter out tokens and extract the actual data args_list = _filter_args(args) # args[0] is func_ref dict, rest are statement results func_ref_dict = args_list[0] if args_list else {} func_name = func_ref_dict.get("name", "$unknown") if isinstance(func_ref_dict, dict) else "$unknown" statement_results = args_list[1:] if len(args_list) > 1 else [] # Try to extract location from the raw args (may have meta on Tree nodes) loc = SourceLoc(0, 0) for arg in args: if hasattr(arg, 'meta'): try: loc = self._extract_loc(arg.meta) break except (AttributeError, TypeError): pass # Process the statements with the function scope func_nodes, func_edges, func_regions, func_data_defs, func_call_sites = self._process_statements( statement_results, func_scope=func_name ) # Collect macro_calls from function body statements func_macro_calls = [] for stmt in statement_results: if isinstance(stmt, MacroCallResult): func_macro_calls.append(stmt.macro_call) # Create IRRegion for the function body_graph = IRGraph( nodes=func_nodes, edges=func_edges, regions=func_regions, data_defs=func_data_defs, macro_calls=func_macro_calls, raw_call_sites=tuple(func_call_sites), ) region = IRRegion( tag=func_name, kind=RegionKind.FUNCTION, body=body_graph, loc=loc, ) return FunctionResult(region) def _apply_paste_patterns(self, body: IRGraph) -> IRGraph: """Post-process macro body to replace ${param} patterns with ParamRef. Scans all node names and edge endpoints in the body for ${param} patterns and constructs ParamRef instances with appropriate prefix/suffix fields. This post-processing approach avoids the bottom-up traversal issue where Lark processes node_ref/label_ref terminals before macro_def is invoked. Args: body: The constructed IRGraph from macro body processing Returns: New IRGraph with all ${param} patterns replaced by ParamRef instances """ # Process all nodes to replace ${param} patterns in their names new_nodes = {} for node_name, node in body.nodes.items(): match = _PASTE_PATTERN.match(node.name) if match: # Node name contains ${param} pattern new_name = ParamRef( param=match.group(2), prefix=match.group(1), suffix=match.group(3), ) new_nodes[node_name] = replace(node, name=new_name) else: new_nodes[node_name] = node # Process all edges to replace ${param} patterns in source/dest new_edges = [] for edge in body.edges: new_source = edge.source new_dest = edge.dest # Check source for pattern if isinstance(edge.source, str): match = _PASTE_PATTERN.match(edge.source) if match: new_source = ParamRef( param=match.group(2), prefix=match.group(1), suffix=match.group(3), ) # Check dest for pattern if isinstance(edge.dest, str): match = _PASTE_PATTERN.match(edge.dest) if match: new_dest = ParamRef( param=match.group(2), prefix=match.group(1), suffix=match.group(3), ) # Add edge with potential replacements if new_source != edge.source or new_dest != edge.dest: new_edges.append(replace(edge, source=new_source, dest=new_dest)) else: new_edges.append(edge) # Return new IRGraph with updated nodes and edges return replace(body, nodes=new_nodes, edges=new_edges) @v_args(meta=True) def macro_def(self, meta, args: list) -> StatementResult: """Process macro definition (template with parameters). Uses @v_args(meta=True) to receive source location metadata. """ # Extract macro name from first IDENT terminal (before filtering) macro_name = "unknown" for arg in args: if isinstance(arg, LarkToken): macro_name = str(arg) break # Extract location from meta loc = self._extract_loc(meta) # Check for reserved name (starts with "ret") if macro_name.startswith("ret"): self._errors.append(AssemblyError( loc=loc, category=ErrorCategory.NAME, message=f"Macro name '#{macro_name}' uses reserved prefix 'ret'", )) return MacroDefResult(MacroDef(name=macro_name, params=(), body=IRGraph(), loc=loc)) # Separate params from body statements params: list[MacroParam] = [] statement_results: list = [] variadic_param_name: Optional[str] = None for item in args: if isinstance(item, list) and all(isinstance(p, tuple) and len(p) == 2 for p in item): # This is the macro_params result (list of (name, variadic) tuples) seen_names: set[str] = set() for param_name, is_variadic in item: if param_name in seen_names: self._errors.append(AssemblyError( loc=loc, category=ErrorCategory.NAME, message=f"Duplicate parameter name '{param_name}' in macro '#{macro_name}'", )) else: seen_names.add(param_name) if is_variadic: # Validate: variadic param must be last if variadic_param_name is not None: self._errors.append(AssemblyError( loc=loc, category=ErrorCategory.NAME, message=f"Multiple variadic parameters in macro '#{macro_name}' (only one allowed)", )) variadic_param_name = param_name elif variadic_param_name is not None: # Non-variadic param after variadic param self._errors.append(AssemblyError( loc=loc, category=ErrorCategory.NAME, message=f"Variadic parameter must be last in macro '#{macro_name}'", )) params.append(MacroParam(name=param_name, variadic=is_variadic)) elif isinstance(item, StatementResult): statement_results.append(item) # Process body statements (no function scope — macros don't create ctx scopes) body_nodes, body_edges, body_regions, body_data_defs, body_call_sites = self._process_statements( statement_results, func_scope=None ) # Collect macro_calls and repetition_blocks from body statements body_macro_calls = [] repetition_blocks = [] for stmt in statement_results: if isinstance(stmt, MacroCallResult): body_macro_calls.append(stmt.macro_call) elif isinstance(stmt, RepetitionBlockResult): # Update variadic_param in the repetition block if we have a variadic param rep_block = stmt.repetition_block if variadic_param_name and rep_block.variadic_param == "": # Replace the placeholder with the actual variadic param name rep_block = replace(rep_block, variadic_param=variadic_param_name) repetition_blocks.append(rep_block) body = IRGraph( nodes=body_nodes, edges=body_edges, regions=body_regions, data_defs=body_data_defs, macro_calls=body_macro_calls, raw_call_sites=tuple(body_call_sites), ) # Post-process to apply ${param} token pasting patterns body = self._apply_paste_patterns(body) macro = MacroDef( name=macro_name, params=tuple(params), body=body, repetition_blocks=repetition_blocks, loc=loc, ) return MacroDefResult(macro) def macro_params(self, args: list) -> list[tuple]: """Process macro parameter list. Returns list of (name, variadic) tuples. Note: Comma tokens and other non-tuple/string types from the grammar are silently skipped during iteration. """ result = [] for arg in args: if isinstance(arg, tuple): # From macro_param rule (variadic_param or regular_param) result.append(arg) elif isinstance(arg, str): # Fallback for simple string params result.append((arg, False)) # Other token types (commas) are silently skipped return result def variadic_param(self, args: list) -> tuple: """Process a variadic macro parameter (*name). Returns (name, True) tuple. """ # args will be [VARIADIC_token, IDENT_token] # IDENT is always the last token per the grammar rule name = str(args[-1]) return (name, True) def regular_param(self, args: list) -> tuple: """Process a regular macro parameter (name). Returns (name, False) tuple. """ # args will be [IDENT_token] if args: name = str(args[0].value if hasattr(args[0], 'value') else args[0]) else: name = "unknown" return (name, False) @v_args(meta=True) def repetition_block(self, meta, args: list) -> StatementResult: """Process repetition block: $( body ),*. The repetition block syntax within macro bodies will be expanded in the expand pass. Here we collect the body as an IRGraph. Creates an IRRepetitionBlock with an empty string placeholder for variadic_param. The placeholder will be resolved during macro_def processing by matching against the macro's actual variadic parameter. """ loc = self._extract_loc(meta) # Filter statement results from args statement_results = [arg for arg in args if isinstance(arg, StatementResult)] # Process body statements body_nodes, body_edges, body_regions, body_data_defs, body_call_sites = self._process_statements( statement_results, func_scope=None ) body = IRGraph( nodes=body_nodes, edges=body_edges, regions=body_regions, data_defs=body_data_defs, raw_call_sites=tuple(body_call_sites), ) # Apply token pasting patterns to the body body = self._apply_paste_patterns(body) # Create a placeholder IRRepetitionBlock # The variadic_param will be resolved in the expand pass # For now, use empty string as a placeholder rep_block = IRRepetitionBlock( body=body, variadic_param="", # Placeholder, resolved in expand pass loc=loc, ) return RepetitionBlockResult(rep_block) @v_args(meta=True) def macro_call_stmt(self, meta, args: list) -> StatementResult: """Process standalone macro invocation.""" loc = self._extract_loc(meta) # Extract macro name from first IDENT terminal macro_name = "unknown" for arg in args: if isinstance(arg, LarkToken): macro_name = str(arg) break positional_args = [] named_args: dict[str, object] = {} output_dests = () found_name = False for item in args: if isinstance(item, LarkToken): if not found_name: # First LarkToken is the macro name found_name = True continue if item.type in ("OPCODE", "IDENT"): # Bare opcode or identifier as macro argument — wrap as string positional_args.append(str(item)) continue # Skip other tokens (FLOW_OUT, commas, etc.) continue elif isinstance(item, list) and all(isinstance(x, dict) for x in item): # call_output_list result — list of output dest dicts output_dests = tuple(item) elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], str): # Named argument from named_arg rule (name, value) named_args[item[0]] = item[1] elif isinstance(item, dict) and "name" in item: # Positional argument (qualified_ref or value) positional_args.append(item) elif item is not None: # Other argument types (int literals, etc.) positional_args.append(item) macro_call = IRMacroCall( name=macro_name, positional_args=tuple(positional_args), named_args=tuple(named_args.items()), output_dests=output_dests, loc=loc, ) return MacroCallResult(macro_call) @v_args(meta=True) def call_stmt(self, meta, args: list) -> StatementResult: """Process function call statement. The call_stmt grammar rule is: call_stmt: func_ref argument ("," argument)* FLOW_OUT call_output_list Args are: [func_ref_dict, arg1, arg2, ..., call_output_list] """ loc = self._extract_loc(meta) # Filter out LarkTokens (FLOW_OUT) args_list = _filter_args(args) if not args_list: self._errors.append(AssemblyError( loc=loc, category=ErrorCategory.PARSE, message="call_stmt requires function name and arguments" )) return CallSiteResultStatement(CallSiteResult( func_name="$unknown", input_args=(), output_dests=(), loc=loc, )) # First arg is func_ref dict func_ref_dict = args_list[0] func_name = func_ref_dict.get("name", "$unknown") # Process remaining args: arguments come before output_dests # We need to find where call_output_list starts (it's a list of dicts/named outputs) input_args = [] output_dests = [] for i, item in enumerate(args_list[1:], start=1): if isinstance(item, list): # This is call_output_list result — flatten into output_dests output_dests.extend(item) elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[0], str): # named_arg: (name, value) input_args.append(item) elif isinstance(item, dict) and "name" in item: # positional_arg (qualified_ref) input_args.append((None, item)) # Store as (None, ref_dict) for positional elif isinstance(item, int): # literal value input_args.append((None, item)) else: # Fallback: treat as positional value input_args.append((None, item)) call_site = CallSiteResult( func_name=func_name, input_args=tuple(input_args), output_dests=tuple(output_dests), loc=loc, ) return CallSiteResultStatement(call_site) def call_output_list(self, args: list) -> list: """Process call output list — returns list of output dests.""" return [a for a in args if a is not None] @v_args(inline=True) def named_output(self, name_tok, ref) -> dict: """Process named output: name=@dest. Returns {"name": str, "ref": ref_dict} so the expand pass can map @ret_name return markers to the specified call-site destination. """ # name_tok could be a LarkToken if isinstance(name_tok, LarkToken): name_str = str(name_tok) else: name_str = name_tok return {"name": name_str, "ref": ref} @v_args(inline=True) def positional_output(self, ref) -> dict: """Process positional output: bare @dest or &ref.""" return ref def macro_ref(self, args: list) -> dict: """Process macro reference (#name).""" token = args[0] return {"name": f"#{token}"} def scoped_ref(self, args: list) -> dict: """Process dot-notation scope reference ($func.&label or #macro.&label).""" args_list = _filter_args(args) scope_dict = args_list[0] # func_ref or macro_ref dict inner_dict = args_list[1] # label_ref or node_ref dict scope_name = scope_dict["name"] inner_name = inner_dict["name"] return {"name": f"{scope_name}.{inner_name}"} @v_args(inline=True, meta=True) def data_def(self, meta, *args) -> StatementResult: """Process data definition.""" loc = self._extract_loc(meta) args_list = _filter_args(args) qualified_ref_dict = args_list[0] value_data = args_list[1] if len(args_list) > 1 else None name = qualified_ref_dict["name"] # Extract SM ID from placement sm_id = None if "placement" in qualified_ref_dict and qualified_ref_dict["placement"]: placement_val = qualified_ref_dict["placement"] if isinstance(placement_val, str) and placement_val.startswith("sm"): try: sm_id = int(placement_val[2:]) except ValueError: pass # Extract cell address from port # The port value from qualified_ref can be: # - Port.L/Port.R (for plain edge context) # - raw int (for data_def context, e.g., :0, :1, :2, etc.) cell_addr = None if "port" in qualified_ref_dict and qualified_ref_dict["port"] is not None: port_val = qualified_ref_dict["port"] # Extract the numeric value regardless of type if isinstance(port_val, Port): cell_addr = int(port_val) elif isinstance(port_val, int): cell_addr = port_val # Handle value_data value = 0 if isinstance(value_data, list): # value_list: pack values if all(isinstance(v, int) for v in value_data): # Integer values or char values if len(value_data) == 1: value = value_data[0] else: # Multiple values: only valid if all are bytes (0-255) if any(v > 255 for v in value_data): self._errors.append(AssemblyError( loc=loc, category=ErrorCategory.VALUE, message=f"Multi-value data definition cannot contain values > 255. " f"Data defs support either a single 16-bit value OR multiple byte-values packed into one word.", )) value = value_data[0] # Use first value as fallback else: # All bytes: take the already-packed value from value_list value = value_data[0] # value_list already packs consecutive pairs else: value = value_data data_def = IRDataDef( name=name, sm_id=sm_id, cell_addr=cell_addr, value=value, loc=loc, ) return DataDefResult([data_def]) @v_args(inline=True, meta=True) def location_dir(self, meta, *args) -> StatementResult: """Process location directive.""" loc = self._extract_loc(meta) args_list = _filter_args(args) qualified_ref_dict = args_list[0] tag = qualified_ref_dict["name"] # Create region for location region = IRRegion( tag=tag, kind=RegionKind.LOCATION, body=IRGraph(), loc=loc, ) return LocationResult(region) @v_args(inline=True, meta=True) def system_pragma(self, meta, *params) -> Optional[StatementResult]: """Process @system pragma.""" loc = self._extract_loc(meta) # Filter out tokens params_list = _filter_args(params) # Check for duplicate @system pragma if self._system is not None: self._errors.append(AssemblyError( loc=loc, category=ErrorCategory.PARSE, message="Duplicate @system pragma", )) return None # params are (name, value) tuples from system_param config_dict = {} for param_tuple in params_list: if isinstance(param_tuple, tuple): param_name, param_value = param_tuple config_dict[param_name] = param_value # Map parameter names pe_count = config_dict.get("pe") sm_count = config_dict.get("sm") iram_capacity = config_dict.get("iram", 256) frame_count = config_dict.get("frames", 8) if pe_count is None or sm_count is None: self._errors.append(AssemblyError( loc=loc, category=ErrorCategory.PARSE, message="@system pragma requires at least 'pe' and 'sm' parameters", )) return None self._system = SystemConfig( pe_count=pe_count, sm_count=sm_count, iram_capacity=iram_capacity, frame_count=frame_count, loc=loc, ) return None # Don't return a StatementResult for pragmas @v_args(inline=True) def system_param(self, param_name: LarkToken, value) -> tuple[str, int]: """Process @system parameter.""" # value can be a token (DEC_LIT or HEX_LIT) or already an int if isinstance(value, LarkToken): value = int(str(value), 0) # 0 base handles both decimal and 0x hex return (str(param_name), value) @v_args(inline=True) def opcode(self, token) -> Optional[Union[ALUOp, MemOp, ParamRef]]: """Map opcode token to ALUOp/MemOp enum, ParamRef, or None if invalid.""" if isinstance(token, ParamRef): return token mnemonic = str(token) if mnemonic not in MNEMONIC_TO_OP: self._errors.append(AssemblyError( loc=SourceLoc(line=token.line, column=token.column), category=ErrorCategory.PARSE, message=f"Unknown opcode '{mnemonic}'", )) return None return MNEMONIC_TO_OP[mnemonic] @v_args(inline=True) def qualified_ref(self, *args) -> dict: """Collect qualified reference components into a dict.""" ref_type = None placement = None act_slot = None port = None for arg in args: if isinstance(arg, PlacementRef): placement = arg elif isinstance(arg, PortRef): port = arg elif isinstance(arg, (ActSlotRef, ActSlotRange)): act_slot = arg elif isinstance(arg, (Port, int)): port = arg elif isinstance(arg, ParamRef): ref_type = {"name": arg} elif isinstance(arg, dict): ref_type = arg elif isinstance(arg, str) and (arg.startswith("pe") or arg.startswith("sm")): placement = arg result = ref_type.copy() if ref_type else {} if placement is not None: result["placement"] = placement if act_slot is not None: result["act_slot"] = act_slot if port is not None: result["port"] = port return result @v_args(inline=True) def node_ref(self, token: LarkToken) -> dict: """Process @name reference.""" return {"name": f"@{token}"} @v_args(inline=True) def label_ref(self, token: LarkToken) -> dict: """Process &name reference.""" return {"name": f"&{token}"} @v_args(inline=True) def func_ref(self, token: LarkToken) -> dict: """Process $name reference.""" return {"name": f"${token}"} def param_ref(self, args: list) -> Union[ParamRef, dict]: """Process ${name} macro parameter reference. Returns ParamRef directly. When used in qualified_ref context, the qualified_ref handler wraps it in a dict. """ name = str(args[-1]) return ParamRef(param=name) @v_args(inline=True) def placement(self, token) -> Union[str, PlacementRef]: """Extract placement specifier.""" if isinstance(token, ParamRef): return PlacementRef(param=token) return str(token) def ctx_slot(self, args: list): """Extract context slot specifier. Always returns a typed wrapper (ActSlotRef, ActSlotRange) so qualified_ref can distinguish ctx_slot ints from port ints. """ if len(args) == 1: arg = args[0] if isinstance(arg, ParamRef): return ActSlotRef(param=arg) if isinstance(arg, ActSlotRange): return arg n = int(str(arg)) return ActSlotRange(start=n, end=n) return args[0] def ctx_range(self, args: list) -> ActSlotRange: """Extract context slot range (start..end).""" return ActSlotRange(start=int(str(args[0])), end=int(str(args[1]))) @v_args(inline=True) def port(self, token) -> Union[Port, int, PortRef]: """Convert port specifier to Port enum, raw int, or PortRef. Returns: Port.L for "L" Port.R for "R" Raw int for numeric values (e.g., cell address in data_def) PortRef for param_ref """ if isinstance(token, ParamRef): return PortRef(param=token) spec = str(token) if spec == "L": return Port.L elif spec == "R": return Port.R else: try: return int(spec) except ValueError: return Port.L @v_args(inline=True) def hex_literal(self, token: LarkToken) -> int: """Parse hexadecimal literal.""" return int(str(token), 16) @v_args(inline=True) def dec_literal(self, token: LarkToken) -> int: """Parse decimal literal.""" return int(str(token)) def _process_escape_sequences(self, s: str) -> list[int]: """Process escape sequences in a string. Handles: \\n, \\t, \\r, \\0, \\\\, \\\', \\x## Args: s: String with potential escape sequences Returns: List of character codes """ result = [] i = 0 while i < len(s): if i + 1 < len(s) and s[i] == "\\": next_char = s[i + 1] if next_char == "n": result.append(ord("\n")) i += 2 elif next_char == "t": result.append(ord("\t")) i += 2 elif next_char == "r": result.append(ord("\r")) i += 2 elif next_char == "0": result.append(0) i += 2 elif next_char == "\\": result.append(ord("\\")) i += 2 elif next_char == "'": result.append(ord("'")) i += 2 elif next_char == '"': result.append(ord('"')) i += 2 elif next_char == "x" and i + 3 < len(s): # Hex escape: \xHH hex_str = s[i + 2:i + 4] try: result.append(int(hex_str, 16)) i += 4 except ValueError: # Invalid hex, just include the character result.append(ord(s[i])) i += 1 else: # Unknown escape, just include the character result.append(ord(s[i])) i += 1 else: result.append(ord(s[i])) i += 1 return result @v_args(inline=True) def char_literal(self, token: LarkToken) -> int: """Parse character literal.""" s = str(token) # Remove surrounding quotes s = s[1:-1] # Handle escape sequences if s == "\\n": return ord("\n") elif s == "\\t": return ord("\t") elif s == "\\r": return ord("\r") elif s == "\\0": return 0 elif s == "\\\\": return ord("\\") elif s == "\\'": return ord("'") elif s.startswith("\\x"): return int(s[2:], 16) else: return ord(s[0]) @v_args(inline=True) def string_literal(self, token: LarkToken) -> list[int]: """Parse string literal (returns list of character codes).""" s = str(token)[1:-1] # Remove quotes return self._process_escape_sequences(s) @v_args(inline=True) def raw_string_literal(self, token: LarkToken) -> list[int]: """Parse raw string literal (no escape processing).""" s = str(token)[2:-1] # Remove r" and " return [ord(c) for c in s] @v_args(inline=True) def byte_string_literal(self, token: LarkToken) -> list[int]: """Parse byte string literal.""" s = str(token)[2:-1] # Remove b" and " return self._process_escape_sequences(s) @v_args(inline=True) def named_arg(self, arg_name: LarkToken, value: Any) -> tuple[str, Any]: """Process named argument.""" return (str(arg_name), value) @v_args(inline=True) def ref_list(self, *refs) -> list[dict]: """Collect reference list.""" return list(refs) @v_args(inline=True) def value_list(self, *values) -> list[int]: """Collect value list and pack multi-char values big-endian. - Hex/dec literals: returned as single values (not packed) - Multiple char values: packed big-endian into 16-bit words - String/list data: chars extracted and packed """ # Flatten values (strings return lists of char codes) result = [] for value in values: if isinstance(value, list): # String data from string_literal, etc. result.extend(value) else: # Single value (char or hex/dec literal) result.append(value) # Only pack if we have multiple values (char pairs) AND all are bytes if len(result) <= 1: # Single value: return as-is (whether hex literal or single char) return result all_bytes = all(0 <= v <= 255 for v in result) if not all_bytes: # Mixed or large values, return as-is return result # Multiple bytes: pack consecutive pairs big-endian packed = [] i = 0 while i < len(result): if i + 1 < len(result): # Two bytes: big-endian val = (result[i] << 8) | result[i + 1] packed.append(val) i += 2 else: # Single byte: pad with 0 in low byte val = (result[i] << 8) | 0x00 packed.append(val) i += 1 return packed def lower(tree) -> IRGraph: """Lower a parse tree into an IRGraph. Args: tree: A Lark parse tree from parsing dfasm source Returns: An IRGraph with nodes, edges, regions, and any errors encountered """ transformer = LowerTransformer() return transformer.transform(tree)