OR-1 dataflow CPU sketch
at main 348 lines 12 kB view raw
"""Name resolution pass for the OR1 assembler.

Resolves all symbolic references in an IRGraph to concrete nodes. Implements:
- Flattening of nested nodes (from regions) into a unified namespace
- Edge validation (all edges reference existing nodes)
- Scope violation detection (cross-function label references)
- Levenshtein distance-based "did you mean" suggestions
- Error accumulation: every invalid edge is reported instead of stopping at
  the first one (each edge contributes at most one error; its source end is
  checked before its dest end)

Reference: Phase 3 design doc.
"""

from __future__ import annotations

from collections.abc import Iterable
from dataclasses import replace
from typing import Optional

from asm.errors import AssemblyError, ErrorCategory
from asm.ir import IRGraph, IRNode, IREdge, IRRegion, SourceLoc, collect_all_nodes

# Names within this edit distance are always suggested; if none are this close,
# only the closest match(es) are suggested.
_DEFAULT_SUGGESTION_DISTANCE = 3


def _levenshtein(a: str, b: str) -> int:
    """Compute the Levenshtein (edit) distance between two strings.

    Uses the two-row dynamic-programming formulation, so working memory is
    proportional to the length of the shorter string.

    Args:
        a: First string
        b: Second string

    Returns:
        Minimum number of single-character insertions, deletions, or
        substitutions needed to turn ``a`` into ``b``.
    """
    # Keep ``b`` as the shorter string so the DP rows stay small.
    if len(a) < len(b):
        a, b = b, a
    if not b:
        return len(a)

    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a):
        curr = [i + 1]
        for j, cb in enumerate(b):
            curr.append(min(
                prev[j + 1] + 1,        # deletion
                curr[j] + 1,            # insertion
                prev[j] + (ca != cb),   # substitution
            ))
        prev = curr
    return prev[-1]


def _build_scope_map(graph: IRGraph) -> dict[str, str]:
    """Build a map from node name to the scope that defines it.

    Top-level nodes map to the empty string. Nodes in a region body map to that
    region's tag (e.g. ``"$foo"``); nodes in a nested region's body map to the
    nested region's own tag.

    Args:
        graph: The IRGraph

    Returns:
        Dictionary mapping node name -> scope tag ("" for top-level)
    """
    scope_map: dict[str, str] = dict.fromkeys(graph.nodes, "")

    def _walk_regions(regions: list[IRRegion]) -> None:
        for region in regions:
            for name in region.body.nodes:
                scope_map[name] = region.tag
            _walk_regions(region.body.regions)

    _walk_regions(graph.regions)
    return scope_map


def _find_in_other_scope(
    simple_name: str,
    scope_map: dict[str, str],
    exclude: Optional[str] = None,
) -> Optional[tuple[str, str]]:
    """Find a function-scoped node whose name ends in ``"." + simple_name``.

    Used to turn a plain "not found" into a more helpful scope error when a node
    with the same simple name exists inside some function.

    Args:
        simple_name: Unqualified name to look for (e.g. "&data")
        scope_map: Scope map from _build_scope_map
        exclude: Qualified name to skip, if any

    Returns:
        ``(full_name, scope)`` for the first match in scope_map order, or None.
    """
    suffix = "." + simple_name
    for full_name, scope in scope_map.items():
        if scope != "" and full_name.endswith(suffix) and full_name != exclude:
            return full_name, scope
    return None


def _scope_error(
    loc: SourceLoc,
    name: str,
    full_name: str,
    scope: str,
    *,
    in_this_scope: bool = False,
) -> AssemblyError:
    """Build a SCOPE error pointing at a same-named node in another function.

    Args:
        loc: Source location of the offending reference
        name: The name as written in the reference
        full_name: Qualified name of the node found in another scope
        scope: Tag of the function that defines ``full_name``
        in_this_scope: Say "not found in this scope" rather than "not found"

    Returns:
        AssemblyError with category SCOPE and no suggestions
    """
    where = " in this scope" if in_this_scope else ""
    message = (
        f"Reference to '{name}' not found{where}. "
        f"Did you mean '{full_name}'? (defined in function '{scope}')"
    )
    return AssemblyError(
        loc=loc,
        category=ErrorCategory.SCOPE,
        message=message,
        suggestions=[],
    )


def _resolves(name: str, flattened: dict[str, IRNode], scope: str) -> bool:
    """Return True if ``name`` names a node, directly or qualified by ``scope``.

    An unqualified name (no ".") used inside a function is also tried as
    ``f"{scope}.{name}"``. With an empty scope, no qualification is attempted.
    """
    if name in flattened:
        return True
    return "." not in name and bool(scope) and f"{scope}.{name}" in flattened


def _check_edge_resolved(
    edge: IREdge,
    flattened: dict[str, IRNode],
    scope_map: dict[str, str],
    source_scope: str = "",
) -> Optional[AssemblyError]:
    """Validate that an edge's source and dest exist in the flattened namespace.

    If either end is missing, generate an appropriate error:
    - NAME error if name doesn't exist anywhere
    - SCOPE error if name exists but in a different function scope
    - Includes "did you mean" suggestions via Levenshtein distance

    The source end is checked first; if it fails, the dest end is not checked,
    so at most one error is returned per edge.

    Edges can be either:
    1. Already qualified by Lower pass (e.g., "$bar.&data")
    2. Simple names that need qualification (older style)

    Args:
        edge: The IREdge to validate
        flattened: Flattened node dictionary
        scope_map: Scope map from _build_scope_map
        source_scope: The scope context where this edge was defined (e.g., "$foo")

    Returns:
        AssemblyError if validation fails, None if passes
    """
    # Source end: every failure goes through the generic unresolved-name path.
    if not _resolves(edge.source, flattened, source_scope):
        return _generate_unresolved_error(edge.source, edge.loc, flattened, scope_map)

    dest = edge.dest
    if _resolves(dest, flattened, source_scope):
        return None

    if "." not in dest and source_scope:
        # Unqualified label used inside a function but defined in another one:
        # report a scope violation rather than an unknown name.
        if dest.startswith("&"):
            match = _find_in_other_scope(dest, scope_map)
            if match is not None:
                full_name, scope = match
                return _scope_error(edge.loc, dest, full_name, scope, in_this_scope=True)
    elif "." in dest:
        # Already-qualified reference: if the same simple name lives in another
        # function, the qualification most likely names the wrong function.
        simple_name = dest.split(".")[-1]
        match = _find_in_other_scope(simple_name, scope_map, exclude=dest)
        if match is not None:
            full_name, scope = match
            return _scope_error(edge.loc, dest, full_name, scope)

    return _generate_unresolved_error(dest, edge.loc, flattened, scope_map)


def _generate_unresolved_error(
    name: str,
    loc: SourceLoc,
    flattened: dict[str, IRNode],
    scope_map: dict[str, str],
) -> AssemblyError:
    """Generate an error for an unresolved name reference.

    A label reference (starting with "&") that exists in some function scope
    yields a SCOPE error naming that definition. Anything else yields a NAME
    error with "did you mean" suggestions drawn from all known node names.

    Args:
        name: The unresolved name
        loc: Source location of the reference
        flattened: Flattened node dictionary
        scope_map: Scope map from _build_scope_map

    Returns:
        AssemblyError with appropriate category and suggestions
    """
    if name.startswith("&"):
        match = _find_in_other_scope(name, scope_map)
        if match is not None:
            full_name, scope = match
            return _scope_error(loc, name, full_name, scope)

    return AssemblyError(
        loc=loc,
        category=ErrorCategory.NAME,
        message=f"undefined reference to '{name}'",
        suggestions=_suggest_names(name, flattened.keys()),
    )


def _suggest_names(
    unresolved: str,
    available_names: Iterable[str],
    max_distance: int = _DEFAULT_SUGGESTION_DISTANCE,
) -> list[str]:
    """Generate "did you mean" suggestions via Levenshtein distance.

    Every available name within ``max_distance`` edits is suggested. If none is
    that close, all names tied for the smallest distance are suggested instead.
    Suggestions are ordered by distance (ties keep input order).

    Args:
        unresolved: The unresolved name
        available_names: Any iterable of available node names (including
            one-shot iterators; it is consumed exactly once)
        max_distance: Edit distance at or below which names are always suggested

    Returns:
        List of suggestion strings (empty if no names are available)
    """
    # Materialize first: truthiness of a generator/iterator says nothing about
    # emptiness, so the emptiness check must run on the collected candidates.
    candidates = sorted(
        ((_levenshtein(unresolved, name), name) for name in available_names),
        key=lambda pair: pair[0],
    )
    if not candidates:
        return []

    best_distance = candidates[0][0]
    suggestions = []
    for dist, name in candidates:
        # Sorted ascending, so the first name past both limits ends the scan.
        if dist > max_distance and dist != best_distance:
            break
        suggestions.append(f"Did you mean '{name}'?")
    return suggestions


def _check_edges_recursive(
    graph: IRGraph,
    flattened: dict[str, IRNode],
    scope_map: dict[str, str],
    source_scope: str = "",
) -> list[AssemblyError]:
    """Recursively validate all edges in the graph and its regions.

    Args:
        graph: The IRGraph to check
        flattened: Flattened node dictionary
        scope_map: Scope map
        source_scope: The scope context for this graph (e.g., "$foo" for region bodies)

    Returns:
        List of AssemblyErrors found (edges at this level first, then each
        region's edges in region order)
    """
    errors: list[AssemblyError] = []

    for edge in graph.edges:
        error = _check_edge_resolved(edge, flattened, scope_map, source_scope)
        if error is not None:
            errors.append(error)

    # Edges inside a region are resolved relative to that region's tag.
    for region in graph.regions:
        errors.extend(
            _check_edges_recursive(region.body, flattened, scope_map, region.tag)
        )

    return errors


def resolve(graph: IRGraph) -> IRGraph:
    """Resolve all symbolic references in an IRGraph.

    If ``graph.errors`` is already non-empty (errors from earlier phases), the
    input graph is returned unchanged and no resolution is attempted.
    Otherwise, any name resolution errors are appended to ``graph.errors`` in a
    new IRGraph; if there are none, the input graph itself is returned
    (immutable pass pattern).

    The resolution process:
    1. Flattens all nodes (from graph and nested regions)
    2. Builds a scope map (top-level vs function-scoped)
    3. Validates all edges reference existing nodes
    4. Accumulates errors (every bad edge reported, not fail-fast)
    5. Returns new IRGraph with errors appended

    Args:
        graph: The IRGraph to resolve

    Returns:
        New IRGraph with resolution errors appended to graph.errors, or the
        input graph when there is nothing to add
    """
    # Skip if already has errors from earlier phases
    if graph.errors:
        return graph

    flattened = collect_all_nodes(graph)
    scope_map = _build_scope_map(graph)

    resolution_errors = _check_edges_recursive(graph, flattened, scope_map)
    if resolution_errors:
        return replace(graph, errors=list(graph.errors) + resolution_errors)

    return graph