OR-1 dataflow CPU sketch
1"""Placement validation and auto-placement pass for the OR1 assembler.
2
3Validates user-provided PE placements and performs auto-placement for unplaced nodes.
4Uses a greedy bin-packing algorithm with locality heuristic to assign unplaced nodes
5to PEs while respecting IRAM capacity and context slot limits.
6
7Reference: Phase 4 and Phase 7 design docs.
8"""
9
10from __future__ import annotations
11
12from collections import Counter, defaultdict
13from dataclasses import replace
14
15from asm.errors import AssemblyError, ErrorCategory, ErrorSeverity
16from asm.ir import (
17 IRGraph, IRNode, IRRegion, RegionKind, SystemConfig, SourceLoc, collect_all_nodes,
18 update_graph_nodes, DEFAULT_IRAM_CAPACITY, DEFAULT_FRAME_COUNT
19)
20from asm.opcodes import is_dyadic
21
22
def _infer_system_config(graph: IRGraph) -> SystemConfig:
    """Build a fallback SystemConfig from the placements found in *graph*.

    The PE count is one more than the highest ``pe`` value carried by any
    node, looking at the top-level nodes and at the nodes of every region
    (including nested regions). When no node carries a placement the result
    describes a single-PE system. All other fields are fixed: one SM,
    ``DEFAULT_IRAM_CAPACITY``, ``DEFAULT_FRAME_COUNT`` and ``SourceLoc(0, 0)``.

    Args:
        graph: The IRGraph whose node placements are inspected.

    Returns:
        SystemConfig with the inferred pe_count and default capacity values.
    """

    def _placed_pe_ids(scope):
        # Yield the PE id of every placed node in this scope, then descend
        # into each region body in declaration order.
        for node in scope.nodes.values():
            if node.pe is not None:
                yield node.pe
        for region in scope.regions:
            yield from _placed_pe_ids(region.body)

    # Seeding with -1 makes "no placements at all" collapse to pe_count == 1.
    highest_pe = max([-1, *_placed_pe_ids(graph)])

    return SystemConfig(
        pe_count=max(1, highest_pe + 1),  # never fewer than one PE
        sm_count=1,  # a single SM by default
        iram_capacity=DEFAULT_IRAM_CAPACITY,
        frame_count=DEFAULT_FRAME_COUNT,
        loc=SourceLoc(0, 0),
    )
62
63
64
65
def _find_node_scope(graph: IRGraph, node_name: str) -> str | None:
    """Return the tag of the function region that directly holds a node.

    Top-level nodes, and nodes that sit directly inside a non-function
    region, report ``None``. ``None`` is also the answer when the name is
    not found anywhere.

    Args:
        graph: The IRGraph to search.
        node_name: Name of the node to look up.

    Returns:
        The owning function region's tag, or None as described above.
    """
    if node_name in graph.nodes:
        return None

    def _scan(regions: list[IRRegion]) -> str | None:
        for region in regions:
            body = region.body
            if node_name in body.nodes:
                # Only a function region contributes a scope tag; any other
                # region kind is treated as top-level.
                if region.kind == RegionKind.FUNCTION:
                    return region.tag
                return None
            # Not directly here: look through nested regions and stop at the
            # first truthy tag found.
            nested_tag = _scan(body.regions)
            if nested_tag:
                return nested_tag
        return None

    return _scan(graph.regions)
104
105
def _build_adjacency(graph: IRGraph, all_nodes: dict[str, IRNode]) -> dict[str, set[str]]:
    """Build an undirected neighbour map from every edge in the graph.

    Edges are gathered from the top level first and then from each region
    body, parents before children. Each edge links its source and its
    destination in both directions.

    Args:
        graph: The IRGraph whose edges are walked.
        all_nodes: Dictionary of all nodes (not consulted by this function).

    Returns:
        Mapping from node name to the set of names it shares an edge with.
    """

    def _edges_in(scope):
        # A scope is anything exposing ``edges`` and ``regions``: the graph
        # itself or a region body.
        yield from scope.edges
        for region in scope.regions:
            yield from _edges_in(region.body)

    neighbours_of: dict[str, set[str]] = defaultdict(set)
    for edge in _edges_in(graph):
        neighbours_of[edge.source].add(edge.dest)
        neighbours_of[edge.dest].add(edge.source)
    return neighbours_of
135
136
def _count_iram_cost(node: IRNode) -> int:
    """Return how many IRAM slots a node occupies.

    Under the frame model every instruction takes exactly one IRAM slot;
    operand matching lives in frame SRAM and adds no IRAM entries.

    Args:
        node: The IRNode being costed (its contents do not affect the result).

    Returns:
        The IRAM slot count, which is always 1.
    """
    slots_per_instruction = 1
    return slots_per_instruction
150
151
def _auto_place_nodes(
    graph: IRGraph,
    system: SystemConfig,
    all_nodes: dict[str, IRNode],
    adjacency: dict[str, set[str]],
) -> tuple[dict[str, IRNode], list[AssemblyError]]:
    """Assign a PE to every unplaced node, greedily and with a locality bias.

    Steps:
        1. Charge each PE for the nodes that already carry a placement:
           one IRAM cost per node, one frame per distinct function scope
           seen on that PE (the top-level scope ``None`` counts as a scope),
           and one matchable offset per dyadic node.
        2. Visit the unplaced nodes in the iteration order of ``all_nodes``.
           Rank PEs by the number of already-placed neighbours (most first),
           then by free IRAM (most first); remaining ties keep ascending PE
           order. The node goes to the first ranked PE that has both IRAM
           and frame room.
        3. A node that fits nowhere stays unplaced and yields an overflow
           error.
        4. For every PE whose dyadic count exceeds
           ``system.matchable_offsets``, append a WARNING located at the
           first dyadic node found on that PE.

    Args:
        graph: The IRGraph, used to resolve each node's function scope.
        system: SystemConfig supplying pe_count, iram_capacity, frame_count
            and matchable_offsets.
        all_nodes: Every node by name; entries with ``pe is None`` are placed.
        adjacency: Neighbour map used for the locality ranking.

    Returns:
        Tuple of (node dict with placements filled in, errors and warnings).
    """
    pe_ids = range(system.pe_count)

    # Per-PE bookkeeping, indexed by PE id.
    iram_load = [0 for _ in pe_ids]
    frame_load = [0 for _ in pe_ids]
    dyadic_load = [0 for _ in pe_ids]
    # Function scopes that already own a frame on each PE.
    scopes_on_pe: dict[int, set[str | None]] = {pe_id: set() for pe_id in pe_ids}

    # Work on a copy so all_nodes can be iterated while placements change.
    placements = dict(all_nodes)
    problems: list[AssemblyError] = []

    # Step 1: account for the explicitly placed nodes.
    for name, node in placements.items():
        if node.pe is None:
            continue
        home = node.pe
        iram_load[home] += _count_iram_cost(node)
        scope = _find_node_scope(graph, name)
        if scope not in scopes_on_pe[home]:
            # First node of this scope on this PE: it claims one frame.
            scopes_on_pe[home].add(scope)
            frame_load[home] += 1
        if is_dyadic(node.opcode, node.const):
            dyadic_load[home] += 1

    # Steps 2 and 3: place the remaining nodes one at a time.
    for name, node in all_nodes.items():
        if node.pe is not None:
            continue

        # Count neighbours per PE, reading from `placements` so that nodes
        # placed earlier in this loop already attract their neighbours.
        neighbour_votes = Counter(
            peer.pe
            for peer in map(placements.get, adjacency.get(name, set()))
            if peer and peer.pe is not None
        )

        # sorted() is stable, so equal keys keep ascending PE order.
        ranked_pes = sorted(
            pe_ids,
            key=lambda pe: (
                -neighbour_votes.get(pe, 0),
                -(system.iram_capacity - iram_load[pe]),
            ),
        )

        cost = _count_iram_cost(node)
        scope = _find_node_scope(graph, name)
        for pe in ranked_pes:
            needs_frame = scope not in scopes_on_pe[pe]
            if iram_load[pe] + cost > system.iram_capacity:
                continue
            if frame_load[pe] + (1 if needs_frame else 0) > system.frame_count:
                continue

            placements[name] = replace(node, pe=pe)
            iram_load[pe] += cost
            if needs_frame:
                scopes_on_pe[pe].add(scope)
                frame_load[pe] += 1
            if is_dyadic(node.opcode, node.const):
                dyadic_load[pe] += 1
            break
        else:
            # Every PE was rejected: report with the utilisation breakdown.
            problems.append(
                _format_placement_overflow_error(node, system, iram_load, frame_load)
            )

    # Step 4: warn about PEs that exceed the matchable-offset limit.
    for pe_id in pe_ids:
        offsets_in_use = dyadic_load[pe_id]
        if offsets_in_use <= system.matchable_offsets:
            continue
        # Any dyadic node on the PE serves as the source location.
        witness = next(
            (
                candidate
                for candidate in placements.values()
                if candidate.pe == pe_id
                and is_dyadic(candidate.opcode, candidate.const)
            ),
            None,
        )
        if witness:
            problems.append(AssemblyError(
                loc=witness.loc,
                category=ErrorCategory.FRAME,
                message=f"PE {pe_id} uses {offsets_in_use} matchable offsets "
                        f"(limit: {system.matchable_offsets})",
                severity=ErrorSeverity.WARNING,
            ))

    return placements, problems
285
286
def _format_placement_overflow_error(
    node: IRNode,
    system: SystemConfig,
    iram_used: list[int],
    frames_used: list[int],
) -> AssemblyError:
    """Build the "all PEs are full" error for a node that could not be placed.

    The message names the node and then lists, one line per PE, how many
    IRAM slots and frames are in use against the configured limits.

    Args:
        node: The node that couldn't be placed; supplies name and location.
        system: SystemConfig providing pe_count and the two capacity limits.
        iram_used: IRAM slots used, indexed by PE id.
        frames_used: Frames used, indexed by PE id.

    Returns:
        AssemblyError in the PLACEMENT category with an empty suggestion list.
    """
    usage_report = "\n".join(
        f" PE{pe_id}: {iram_used[pe_id]}/{system.iram_capacity} IRAM slots, "
        f"{frames_used[pe_id]}/{system.frame_count} frames"
        for pe_id in range(system.pe_count)
    )
    return AssemblyError(
        loc=node.loc,
        category=ErrorCategory.PLACEMENT,
        message=f"Cannot place node '{node.name}': all PEs are full.\n{usage_report}",
        suggestions=[],
    )
320
321
def place(graph: IRGraph) -> IRGraph:
    """Placement pass: validate explicit placements and auto-place the rest.

    Process:
        1. Use ``graph.system`` when present, otherwise infer a SystemConfig
           from the placements already in the graph.
        2. Reject every explicitly placed node whose PE id is outside
           ``0 .. pe_count - 1``. Each rejection appends a PLACEMENT error and
           the node is left out of the node set handed to the later steps.
        3. If any remaining node is unplaced, auto-place using greedy
           bin-packing with a locality heuristic and collect its errors.
        4. Write the resulting nodes back into the graph.

    Negative PE ids are rejected as well as ids that are too large. The
    auto-placer indexes its per-PE lists with ``node.pe``, so a negative id
    would wrap around to another PE's counters and then fail on the per-PE
    scope lookup.

    Args:
        graph: The IRGraph to place.

    Returns:
        New IRGraph carrying the system config in use, the updated nodes, and
        the input graph's errors followed by any placement errors.
    """
    # Step 1: pick the system description to place against.
    system = graph.system if graph.system is not None else _infer_system_config(graph)

    # Copy so the input graph's error list is never mutated.
    errors = list(graph.errors)

    # Step 2: keep only nodes that are unplaced or placed on an existing PE.
    valid_nodes: dict[str, IRNode] = {}
    for node_name, node in collect_all_nodes(graph).items():
        if node.pe is not None and not 0 <= node.pe < system.pe_count:
            errors.append(AssemblyError(
                loc=node.loc,
                category=ErrorCategory.PLACEMENT,
                message=f"Node '{node_name}' placed on PE{node.pe} but system only has {system.pe_count} PEs (0-{system.pe_count - 1}).",
                suggestions=[],
            ))
        else:
            valid_nodes[node_name] = node

    all_nodes = valid_nodes

    # Step 3: the auto-placer only runs when something is left to place.
    if any(node.pe is None for node in all_nodes.values()):
        adjacency = _build_adjacency(graph, all_nodes)
        all_nodes, placement_errors = _auto_place_nodes(graph, system, all_nodes, adjacency)
        errors.extend(placement_errors)

    # Step 4: push the (possibly updated) nodes back into the graph structure.
    result_graph = update_graph_nodes(graph, all_nodes)
    return replace(result_graph, system=system, errors=errors)