"""Name resolution pass for the OR1 assembler. Resolves all symbolic references in an IRGraph to concrete nodes. Implements: - Flattening of nested nodes (from regions) into a unified namespace - Edge validation (all edges reference existing nodes) - Scope violation detection (cross-function label references) - Levenshtein distance-based "did you mean" suggestions - Error accumulation (all issues reported, not fail-fast) Reference: Phase 3 design doc. """ from __future__ import annotations from collections.abc import Iterable from dataclasses import replace from typing import Optional from asm.errors import AssemblyError, ErrorCategory from asm.ir import IRGraph, IRNode, IREdge, IRRegion, SourceLoc, collect_all_nodes def _levenshtein(a: str, b: str) -> int: """Compute Levenshtein (edit) distance between two strings. Args: a: First string b: Second string Returns: Minimum edit distance (number of single-character edits) """ if len(a) < len(b): return _levenshtein(b, a) if not b: return len(a) prev = list(range(len(b) + 1)) for i, ca in enumerate(a): curr = [i + 1] for j, cb in enumerate(b): curr.append(min( prev[j + 1] + 1, # deletion curr[j] + 1, # insertion prev[j] + (ca != cb), # substitution )) prev = curr return prev[-1] def _build_scope_map(graph: IRGraph) -> dict[str, str]: """Build a map of node names to their defining scope. For top-level nodes, scope is None (empty string in map). For function-scoped nodes, scope is the function name (e.g., "$foo"). Args: graph: The IRGraph Returns: Dictionary mapping qualified name -> scope tag (or "" for top-level) """ scope_map = {} # Top-level nodes have empty scope for name in graph.nodes: scope_map[name] = "" # Walk regions to find function-scoped nodes def _walk_regions(regions: list[IRRegion], parent_scope: str = "") -> None: for region in regions: for name in region.body.nodes: # Scope is the region tag (e.g., "$foo") scope_map[name] = region.tag # Recursively walk nested regions _walk_regions(region.body.regions, region.tag) _walk_regions(graph.regions) return scope_map def _check_edge_resolved( edge: IREdge, flattened: dict[str, IRNode], scope_map: dict[str, str], source_scope: str = "", ) -> Optional[AssemblyError]: """Validate that an edge's source and dest exist in the flattened namespace. If either end is missing, generate an appropriate error: - NAME error if name doesn't exist anywhere - SCOPE error if name exists but in a different function scope - Includes "did you mean" suggestions via Levenshtein distance Edges can be either: 1. Already qualified by Lower pass (e.g., "$bar.&data") 2. Simple names that need qualification (older style) Args: edge: The IREdge to validate flattened: Flattened node dictionary scope_map: Scope map from _build_scope_map source_scope: The scope context where this edge was defined (e.g., "$foo") Returns: AssemblyError if validation fails, None if passes """ # Resolve source source_name = edge.source if source_name not in flattened: # Try with scope qualification if not already qualified if "." not in source_name and source_scope: qualified_source = f"{source_scope}.{source_name}" if qualified_source not in flattened: return _generate_unresolved_error( source_name, edge.loc, flattened, scope_map, ) source_name = qualified_source else: return _generate_unresolved_error( source_name, edge.loc, flattened, scope_map, ) # Resolve dest dest_name = edge.dest if dest_name not in flattened: # Try with scope qualification if not already qualified if "." not in dest_name and source_scope: qualified_dest = f"{source_scope}.{dest_name}" if qualified_dest not in flattened: # Check if dest exists in a different scope if dest_name.startswith("&"): for full_name, scope in scope_map.items(): if scope != "" and full_name.endswith("." + dest_name): # Found in different scope message = ( f"Reference to '{dest_name}' not found in this scope. " f"Did you mean '{full_name}'? (defined in function '{scope}')" ) return AssemblyError( loc=edge.loc, category=ErrorCategory.SCOPE, message=message, suggestions=[], ) return _generate_unresolved_error( dest_name, edge.loc, flattened, scope_map, ) dest_name = qualified_dest else: # dest_name is already qualified or there's no scope context # Check if it's a cross-scope reference if "." in dest_name: # Already qualified, extract the simple name simple_name = dest_name.split(".")[-1] # Check if this simple name exists in any other scope for full_name, scope in scope_map.items(): if scope != "" and full_name.endswith("." + simple_name) and full_name != dest_name: # Found in different scope message = ( f"Reference to '{dest_name}' not found. " f"Did you mean '{full_name}'? (defined in function '{scope}')" ) return AssemblyError( loc=edge.loc, category=ErrorCategory.SCOPE, message=message, suggestions=[], ) return _generate_unresolved_error( dest_name, edge.loc, flattened, scope_map, ) return None def _generate_unresolved_error( name: str, loc: SourceLoc, flattened: dict[str, IRNode], scope_map: dict[str, str], ) -> AssemblyError: """Generate an error for an unresolved name reference. Determines whether it's a NAME error (not found) or SCOPE error (found in different scope), and generates "did you mean" suggestions. Args: name: The unresolved name loc: Source location of the reference flattened: Flattened node dictionary scope_map: Scope map from _build_scope_map Returns: AssemblyError with appropriate category and suggestions """ # Check if this name exists in a different scope # For now, we only need to check if it's a label reference (starts with &) # and exists in some function scope if name.startswith("&"): # Look for this label in any function scope for full_name, scope in scope_map.items(): if scope != "" and full_name.endswith("." + name): # Found the label in a function scope message = ( f"Reference to '{name}' not found. " f"Did you mean '{full_name}'? (defined in function '{scope}')" ) return AssemblyError( loc=loc, category=ErrorCategory.SCOPE, message=message, suggestions=[], ) # Not found anywhere - generate NAME error with suggestions suggestions = _suggest_names(name, flattened.keys()) message = f"undefined reference to '{name}'" return AssemblyError( loc=loc, category=ErrorCategory.NAME, message=message, suggestions=suggestions, ) def _suggest_names(unresolved: str, available_names: Iterable[str]) -> list[str]: """Generate "did you mean" suggestions via Levenshtein distance. Compares unresolved name against all available names, returning suggestions with distance <= 3, or the closest match if all distances are > 3. Args: unresolved: The unresolved name available_names: Iterable of available node names Returns: List of suggestion strings (may be empty) """ if not available_names: return [] # Compute distances candidates = [] for name in available_names: dist = _levenshtein(unresolved, name) candidates.append((dist, name)) # Sort by distance candidates.sort(key=lambda x: x[0]) # Return suggestions with distance <= 3, or best if all > 3 suggestions = [] best_distance = candidates[0][0] for dist, name in candidates: if dist <= 3 or dist == best_distance: suggestions.append(f"Did you mean '{name}'?") else: break return suggestions def _check_edges_recursive( graph: IRGraph, flattened: dict[str, IRNode], scope_map: dict[str, str], source_scope: str = "", ) -> list[AssemblyError]: """Recursively validate all edges in the graph and its regions. Args: graph: The IRGraph to check flattened: Flattened node dictionary scope_map: Scope map source_scope: The scope context for this graph (e.g., "$foo" for region bodies) Returns: List of AssemblyErrors found """ errors = [] # Check edges at this level with the current scope context for edge in graph.edges: error = _check_edge_resolved(edge, flattened, scope_map, source_scope) if error: errors.append(error) # Check edges in nested regions, passing the region's scope for region in graph.regions: errors.extend( _check_edges_recursive(region.body, flattened, scope_map, region.tag) ) return errors def resolve(graph: IRGraph) -> IRGraph: """Resolve all symbolic references in an IRGraph. Returns a new IRGraph with all name resolution errors appended to graph.errors. If there are no errors, the returned graph is structurally identical to the input (immutable pass pattern). The resolution process: 1. Flattens all nodes (from graph and nested regions) 2. Builds a scope map (top-level vs function-scoped) 3. Validates all edges reference existing nodes 4. Accumulates errors (all issues found, not fail-fast) 5. Returns new IRGraph with errors appended Args: graph: The IRGraph to resolve Returns: New IRGraph with resolution errors appended to graph.errors """ # Skip if already has errors from earlier phases if graph.errors: return graph # Flatten nodes and build scope map flattened = collect_all_nodes(graph) scope_map = _build_scope_map(graph) # Check all edges resolution_errors = _check_edges_recursive(graph, flattened, scope_map) # Return new graph with errors appended if resolution_errors: new_errors = list(graph.errors) + resolution_errors return replace(graph, errors=new_errors) return graph