OR-1 dataflow CPU sketch
1"""Name resolution pass for the OR1 assembler.
2
3Resolves all symbolic references in an IRGraph to concrete nodes. Implements:
4- Flattening of nested nodes (from regions) into a unified namespace
5- Edge validation (all edges reference existing nodes)
6- Scope violation detection (cross-function label references)
7- Levenshtein distance-based "did you mean" suggestions
8- Error accumulation (all issues reported, not fail-fast)
9
10Reference: Phase 3 design doc.
11"""
12
13from __future__ import annotations
14
15from collections.abc import Iterable
16from dataclasses import replace
17from typing import Optional
18
19from asm.errors import AssemblyError, ErrorCategory
20from asm.ir import IRGraph, IRNode, IREdge, IRRegion, SourceLoc, collect_all_nodes
21
22
23def _levenshtein(a: str, b: str) -> int:
24 """Compute Levenshtein (edit) distance between two strings.
25
26 Args:
27 a: First string
28 b: Second string
29
30 Returns:
31 Minimum edit distance (number of single-character edits)
32 """
33 if len(a) < len(b):
34 return _levenshtein(b, a)
35 if not b:
36 return len(a)
37
38 prev = list(range(len(b) + 1))
39 for i, ca in enumerate(a):
40 curr = [i + 1]
41 for j, cb in enumerate(b):
42 curr.append(min(
43 prev[j + 1] + 1, # deletion
44 curr[j] + 1, # insertion
45 prev[j] + (ca != cb), # substitution
46 ))
47 prev = curr
48 return prev[-1]
49
50
51def _build_scope_map(graph: IRGraph) -> dict[str, str]:
52 """Build a map of node names to their defining scope.
53
54 For top-level nodes, scope is None (empty string in map).
55 For function-scoped nodes, scope is the function name (e.g., "$foo").
56
57 Args:
58 graph: The IRGraph
59
60 Returns:
61 Dictionary mapping qualified name -> scope tag (or "" for top-level)
62 """
63 scope_map = {}
64
65 # Top-level nodes have empty scope
66 for name in graph.nodes:
67 scope_map[name] = ""
68
69 # Walk regions to find function-scoped nodes
70 def _walk_regions(regions: list[IRRegion], parent_scope: str = "") -> None:
71 for region in regions:
72 for name in region.body.nodes:
73 # Scope is the region tag (e.g., "$foo")
74 scope_map[name] = region.tag
75 # Recursively walk nested regions
76 _walk_regions(region.body.regions, region.tag)
77
78 _walk_regions(graph.regions)
79 return scope_map
80
81
82def _check_edge_resolved(
83 edge: IREdge,
84 flattened: dict[str, IRNode],
85 scope_map: dict[str, str],
86 source_scope: str = "",
87) -> Optional[AssemblyError]:
88 """Validate that an edge's source and dest exist in the flattened namespace.
89
90 If either end is missing, generate an appropriate error:
91 - NAME error if name doesn't exist anywhere
92 - SCOPE error if name exists but in a different function scope
93 - Includes "did you mean" suggestions via Levenshtein distance
94
95 Edges can be either:
96 1. Already qualified by Lower pass (e.g., "$bar.&data")
97 2. Simple names that need qualification (older style)
98
99 Args:
100 edge: The IREdge to validate
101 flattened: Flattened node dictionary
102 scope_map: Scope map from _build_scope_map
103 source_scope: The scope context where this edge was defined (e.g., "$foo")
104
105 Returns:
106 AssemblyError if validation fails, None if passes
107 """
108 # Resolve source
109 source_name = edge.source
110 if source_name not in flattened:
111 # Try with scope qualification if not already qualified
112 if "." not in source_name and source_scope:
113 qualified_source = f"{source_scope}.{source_name}"
114 if qualified_source not in flattened:
115 return _generate_unresolved_error(
116 source_name,
117 edge.loc,
118 flattened,
119 scope_map,
120 )
121 source_name = qualified_source
122 else:
123 return _generate_unresolved_error(
124 source_name,
125 edge.loc,
126 flattened,
127 scope_map,
128 )
129
130 # Resolve dest
131 dest_name = edge.dest
132 if dest_name not in flattened:
133 # Try with scope qualification if not already qualified
134 if "." not in dest_name and source_scope:
135 qualified_dest = f"{source_scope}.{dest_name}"
136 if qualified_dest not in flattened:
137 # Check if dest exists in a different scope
138 if dest_name.startswith("&"):
139 for full_name, scope in scope_map.items():
140 if scope != "" and full_name.endswith("." + dest_name):
141 # Found in different scope
142 message = (
143 f"Reference to '{dest_name}' not found in this scope. "
144 f"Did you mean '{full_name}'? (defined in function '{scope}')"
145 )
146 return AssemblyError(
147 loc=edge.loc,
148 category=ErrorCategory.SCOPE,
149 message=message,
150 suggestions=[],
151 )
152 return _generate_unresolved_error(
153 dest_name,
154 edge.loc,
155 flattened,
156 scope_map,
157 )
158 dest_name = qualified_dest
159 else:
160 # dest_name is already qualified or there's no scope context
161 # Check if it's a cross-scope reference
162 if "." in dest_name:
163 # Already qualified, extract the simple name
164 simple_name = dest_name.split(".")[-1]
165 # Check if this simple name exists in any other scope
166 for full_name, scope in scope_map.items():
167 if scope != "" and full_name.endswith("." + simple_name) and full_name != dest_name:
168 # Found in different scope
169 message = (
170 f"Reference to '{dest_name}' not found. "
171 f"Did you mean '{full_name}'? (defined in function '{scope}')"
172 )
173 return AssemblyError(
174 loc=edge.loc,
175 category=ErrorCategory.SCOPE,
176 message=message,
177 suggestions=[],
178 )
179 return _generate_unresolved_error(
180 dest_name,
181 edge.loc,
182 flattened,
183 scope_map,
184 )
185
186 return None
187
188
def _generate_unresolved_error(
    name: str,
    loc: SourceLoc,
    flattened: dict[str, IRNode],
    scope_map: dict[str, str],
) -> AssemblyError:
    """Build the error reported for a name that could not be resolved.

    A label reference (a name starting with "&") that matches the tail of
    some function-scoped entry in the scope map yields a SCOPE error pointing
    at the first such entry. Any other name yields a NAME error carrying
    "did you mean" suggestions drawn from the flattened node names.

    Args:
        name: The unresolved name
        loc: Source location of the reference
        flattened: Flattened node dictionary
        scope_map: Scope map from _build_scope_map

    Returns:
        AssemblyError with appropriate category and suggestions
    """
    # Only label references are checked against other scopes.
    if name.startswith("&"):
        hit = next(
            (
                (candidate, owner)
                for candidate, owner in scope_map.items()
                if owner != "" and candidate.endswith("." + name)
            ),
            None,
        )
        if hit is not None:
            full_name, scope = hit
            return AssemblyError(
                loc=loc,
                category=ErrorCategory.SCOPE,
                message=(
                    f"Reference to '{name}' not found. "
                    f"Did you mean '{full_name}'? (defined in function '{scope}')"
                ),
                suggestions=[],
            )

    # Not found in any function scope: plain undefined reference.
    near_matches = _suggest_names(name, flattened.keys())
    return AssemblyError(
        loc=loc,
        category=ErrorCategory.NAME,
        message=f"undefined reference to '{name}'",
        suggestions=near_matches,
    )
238
239
240def _suggest_names(unresolved: str, available_names: Iterable[str]) -> list[str]:
241 """Generate "did you mean" suggestions via Levenshtein distance.
242
243 Compares unresolved name against all available names, returning suggestions
244 with distance <= 3, or the closest match if all distances are > 3.
245
246 Args:
247 unresolved: The unresolved name
248 available_names: Iterable of available node names
249
250 Returns:
251 List of suggestion strings (may be empty)
252 """
253 if not available_names:
254 return []
255
256 # Compute distances
257 candidates = []
258 for name in available_names:
259 dist = _levenshtein(unresolved, name)
260 candidates.append((dist, name))
261
262 # Sort by distance
263 candidates.sort(key=lambda x: x[0])
264
265 # Return suggestions with distance <= 3, or best if all > 3
266 suggestions = []
267 best_distance = candidates[0][0]
268
269 for dist, name in candidates:
270 if dist <= 3 or dist == best_distance:
271 suggestions.append(f"Did you mean '{name}'?")
272 else:
273 break
274
275 return suggestions
276
277
278def _check_edges_recursive(
279 graph: IRGraph,
280 flattened: dict[str, IRNode],
281 scope_map: dict[str, str],
282 source_scope: str = "",
283) -> list[AssemblyError]:
284 """Recursively validate all edges in the graph and its regions.
285
286 Args:
287 graph: The IRGraph to check
288 flattened: Flattened node dictionary
289 scope_map: Scope map
290 source_scope: The scope context for this graph (e.g., "$foo" for region bodies)
291
292 Returns:
293 List of AssemblyErrors found
294 """
295 errors = []
296
297 # Check edges at this level with the current scope context
298 for edge in graph.edges:
299 error = _check_edge_resolved(edge, flattened, scope_map, source_scope)
300 if error:
301 errors.append(error)
302
303 # Check edges in nested regions, passing the region's scope
304 for region in graph.regions:
305 errors.extend(
306 _check_edges_recursive(region.body, flattened, scope_map, region.tag)
307 )
308
309 return errors
310
311
def resolve(graph: IRGraph) -> IRGraph:
    """Run name resolution over an IRGraph.

    A graph that already carries errors is returned untouched. Otherwise the
    pass:
    1. Flattens all nodes via collect_all_nodes
    2. Builds the scope map (top-level vs region-scoped names)
    3. Validates every edge, at top level and inside regions
    4. Collects every failure rather than stopping at the first

    If no edge fails, the input graph object itself is returned. If any edge
    fails, a copy is returned with the new errors appended to ``errors``.

    Args:
        graph: The IRGraph to resolve

    Returns:
        The input graph if nothing was reported, otherwise a new IRGraph
        with the resolution errors appended to graph.errors
    """
    # Errors from earlier phases short-circuit this pass.
    if graph.errors:
        return graph

    all_nodes = collect_all_nodes(graph)
    scopes = _build_scope_map(graph)
    found = _check_edges_recursive(graph, all_nodes, scopes)

    if not found:
        return graph

    # Copy rather than mutate: the input graph keeps its own error list.
    return replace(graph, errors=[*graph.errors, *found])
348 return graph