OR-1 dataflow CPU sketch
at main 373 lines 13 kB view raw
1"""Placement validation and auto-placement pass for the OR1 assembler. 2 3Validates user-provided PE placements and performs auto-placement for unplaced nodes. 4Uses a greedy bin-packing algorithm with locality heuristic to assign unplaced nodes 5to PEs while respecting IRAM capacity and context slot limits. 6 7Reference: Phase 4 and Phase 7 design docs. 8""" 9 10from __future__ import annotations 11 12from collections import Counter, defaultdict 13from dataclasses import replace 14 15from asm.errors import AssemblyError, ErrorCategory, ErrorSeverity 16from asm.ir import ( 17 IRGraph, IRNode, IRRegion, RegionKind, SystemConfig, SourceLoc, collect_all_nodes, 18 update_graph_nodes, DEFAULT_IRAM_CAPACITY, DEFAULT_FRAME_COUNT 19) 20from asm.opcodes import is_dyadic 21 22 23def _infer_system_config(graph: IRGraph) -> SystemConfig: 24 """Infer a SystemConfig from node placements if none is provided. 25 26 Determines pe_count from the maximum PE ID referenced in node placements + 1. 27 Uses default capacity values (iram_capacity=256, frame_count=8) matching SystemConfig defaults. 28 29 Args: 30 graph: The IRGraph (may have system=None) 31 32 Returns: 33 SystemConfig with inferred pe_count and default capacity values 34 """ 35 max_pe_id = -1 36 37 # Check all nodes recursively 38 def _find_max_pe(nodes: dict[str, IRNode]) -> None: 39 nonlocal max_pe_id 40 for node in nodes.values(): 41 if node.pe is not None and node.pe > max_pe_id: 42 max_pe_id = node.pe 43 44 _find_max_pe(graph.nodes) 45 46 # Check nodes in regions 47 def _check_regions(regions: list[IRRegion]) -> None: 48 for region in regions: 49 _find_max_pe(region.body.nodes) 50 _check_regions(region.body.regions) 51 52 _check_regions(graph.regions) 53 54 pe_count = max(1, max_pe_id + 1) # At least 1 PE 55 return SystemConfig( 56 pe_count=pe_count, 57 sm_count=1, # Default to 1 SM 58 iram_capacity=DEFAULT_IRAM_CAPACITY, 59 frame_count=DEFAULT_FRAME_COUNT, 60 loc=SourceLoc(0, 0), 61 ) 62 63 64 65 66def _find_node_scope(graph: IRGraph, node_name: str) -> str | None: 67 """Find the function scope of a node. 68 69 Returns the tag of the function region containing this node, or None if top-level. 70 71 Args: 72 graph: The IRGraph 73 node_name: Name of the node to find 74 75 Returns: 76 Function region tag if node is in a function, None if top-level 77 """ 78 # Check if node is in top-level nodes 79 if node_name in graph.nodes: 80 return None 81 82 # Search in regions recursively 83 def _search_regions(regions: list[IRRegion]) -> str | None: 84 for region in regions: 85 if region.kind == RegionKind.FUNCTION: 86 # Check if node is in this function's body 87 if node_name in region.body.nodes: 88 return region.tag 89 # Recursively search nested regions (shouldn't happen with current design) 90 result = _search_regions(region.body.regions) 91 if result: 92 return result 93 else: 94 # For LOCATION regions, nodes are still top-level conceptually 95 if node_name in region.body.nodes: 96 return None 97 # Search nested regions 98 result = _search_regions(region.body.regions) 99 if result: 100 return result 101 return None 102 103 return _search_regions(graph.regions) 104 105 106def _build_adjacency(graph: IRGraph, all_nodes: dict[str, IRNode]) -> dict[str, set[str]]: 107 """Build adjacency map from edges: node -> set of connected neighbours. 108 109 Args: 110 graph: The IRGraph 111 all_nodes: Dictionary of all nodes 112 113 Returns: 114 Dictionary mapping node names to sets of connected node names 115 """ 116 adjacency: dict[str, set[str]] = defaultdict(set) 117 118 def _process_edges(edges: list) -> None: 119 for edge in edges: 120 # Both source and dest are neighbours 121 adjacency[edge.source].add(edge.dest) 122 adjacency[edge.dest].add(edge.source) 123 124 _process_edges(graph.edges) 125 126 # Also process edges in regions 127 def _process_regions_edges(regions: list[IRRegion]) -> None: 128 for region in regions: 129 _process_edges(region.body.edges) 130 _process_regions_edges(region.body.regions) 131 132 _process_regions_edges(graph.regions) 133 134 return adjacency 135 136 137def _count_iram_cost(node: IRNode) -> int: 138 """Count IRAM slots used by a node. 139 140 In the frame model, all instructions use 1 IRAM slot. 141 Matching is handled by frame SRAM, not IRAM entries. 142 143 Args: 144 node: The IRNode 145 146 Returns: 147 Number of IRAM slots used (always 1) 148 """ 149 return 1 150 151 152def _auto_place_nodes( 153 graph: IRGraph, 154 system: SystemConfig, 155 all_nodes: dict[str, IRNode], 156 adjacency: dict[str, set[str]], 157) -> tuple[dict[str, IRNode], list[AssemblyError]]: 158 """Auto-place unplaced nodes using greedy bin-packing with locality heuristic. 159 160 Algorithm: 161 1. Identify unplaced nodes (pe=None) 162 2. For each unplaced node in order: 163 a. Find PE of connected neighbours (use updated_nodes for current placements) 164 b. Prefer PE with most neighbours (locality) 165 c. Tie-break by remaining IRAM capacity 166 d. If no PE has room, record error and continue 167 3. Return updated nodes and any placement errors 168 169 Args: 170 graph: The IRGraph 171 system: SystemConfig with pe_count, iram_capacity, frame_count 172 all_nodes: Dictionary of all nodes 173 adjacency: Adjacency map 174 175 Returns: 176 Tuple of (updated nodes dict, list of placement errors) 177 """ 178 errors: list[AssemblyError] = [] 179 180 # Track resource usage per PE: (iram_used, frames_used) 181 iram_used = [0] * system.pe_count 182 frames_used = [0] * system.pe_count 183 # Track dyadic offset usage per PE for matchable offset warnings 184 dyadic_offsets_per_pe = [0] * system.pe_count 185 186 # Copy nodes so we can update placement as we go 187 updated_nodes = dict(all_nodes) 188 189 # Initialize PE resource usage from explicitly placed nodes 190 # Track which function scopes have been counted per PE to avoid double-counting 191 act_scopes_per_pe: dict[int, set[str | None]] = {pe_id: set() for pe_id in range(system.pe_count)} 192 193 for node_name, node in updated_nodes.items(): 194 if node.pe is not None: 195 iram_cost = _count_iram_cost(node) 196 iram_used[node.pe] += iram_cost 197 # Count frames per function scope, not per node 198 scope = _find_node_scope(graph, node_name) 199 if scope not in act_scopes_per_pe[node.pe]: 200 act_scopes_per_pe[node.pe].add(scope) 201 frames_used[node.pe] += 1 202 # Track dyadic offsets for matchable offset warning 203 if is_dyadic(node.opcode, node.const): 204 dyadic_offsets_per_pe[node.pe] += 1 205 206 # For unplaced nodes, we'll track scopes similarly 207 act_scopes_updated: dict[int, set[str | None]] = {pe_id: set(scopes) for pe_id, scopes in act_scopes_per_pe.items()} 208 209 # Process nodes in insertion order 210 for node_name, node in all_nodes.items(): 211 if node.pe is not None: 212 # Already placed, skip 213 continue 214 215 # Find neighbours and their PEs (from updated_nodes to include newly placed nodes) 216 neighbours = adjacency.get(node_name, set()) 217 neighbour_pes: list[int] = [] 218 for neighbour_name in neighbours: 219 neighbour = updated_nodes.get(neighbour_name) 220 if neighbour and neighbour.pe is not None: 221 neighbour_pes.append(neighbour.pe) 222 223 # Count PE occurrences among neighbours (for locality heuristic) 224 pe_counts: dict[int, int] = Counter(neighbour_pes) 225 226 # Sort PEs by: most neighbours first, then most remaining IRAM 227 candidate_pes = list(range(system.pe_count)) 228 candidate_pes.sort( 229 key=lambda pe: ( 230 -pe_counts.get(pe, 0), # Negative so most neighbours come first 231 -(system.iram_capacity - iram_used[pe]), # Then most room 232 ), 233 ) 234 235 # Find first PE with room 236 iram_cost = _count_iram_cost(node) 237 node_scope = _find_node_scope(graph, node_name) 238 placed = False 239 for pe in candidate_pes: 240 # Check if this scope is new to this PE 241 scope_is_new = node_scope not in act_scopes_updated[pe] 242 frames_needed = 1 if scope_is_new else 0 243 244 if ( 245 iram_used[pe] + iram_cost <= system.iram_capacity 246 and frames_used[pe] + frames_needed <= system.frame_count 247 ): 248 # Place node on this PE 249 updated_nodes[node_name] = replace(node, pe=pe) 250 iram_used[pe] += iram_cost 251 if scope_is_new: 252 act_scopes_updated[pe].add(node_scope) 253 frames_used[pe] += 1 254 # Track dyadic offsets 255 if is_dyadic(node.opcode, node.const): 256 dyadic_offsets_per_pe[pe] += 1 257 placed = True 258 break 259 260 if not placed: 261 # No PE has room - generate error with utilization breakdown 262 error = _format_placement_overflow_error(node, system, iram_used, frames_used) 263 errors.append(error) 264 265 # Check matchable offset limits and emit warnings if exceeded 266 for pe_id in range(system.pe_count): 267 if dyadic_offsets_per_pe[pe_id] > system.matchable_offsets: 268 # Find a node on this PE to use for location 269 warning_node = None 270 for node_name, node in updated_nodes.items(): 271 if node.pe == pe_id and is_dyadic(node.opcode, node.const): 272 warning_node = node 273 break 274 275 if warning_node: 276 errors.append(AssemblyError( 277 loc=warning_node.loc, 278 category=ErrorCategory.FRAME, 279 message=f"PE {pe_id} uses {dyadic_offsets_per_pe[pe_id]} matchable offsets " 280 f"(limit: {system.matchable_offsets})", 281 severity=ErrorSeverity.WARNING, 282 )) 283 284 return updated_nodes, errors 285 286 287def _format_placement_overflow_error( 288 node: IRNode, 289 system: SystemConfig, 290 iram_used: list[int], 291 frames_used: list[int], 292) -> AssemblyError: 293 """Format a placement overflow error with per-PE utilization breakdown. 294 295 Args: 296 node: The node that couldn't be placed 297 system: SystemConfig 298 iram_used: List of IRAM slots used per PE 299 frames_used: List of frames used per PE 300 301 Returns: 302 AssemblyError with detailed breakdown 303 """ 304 breakdown_lines = [] 305 for pe_id in range(system.pe_count): 306 breakdown_lines.append( 307 f" PE{pe_id}: {iram_used[pe_id]}/{system.iram_capacity} IRAM slots, " 308 f"{frames_used[pe_id]}/{system.frame_count} frames" 309 ) 310 311 breakdown = "\n".join(breakdown_lines) 312 message = f"Cannot place node '{node.name}': all PEs are full.\n{breakdown}" 313 314 return AssemblyError( 315 loc=node.loc, 316 category=ErrorCategory.PLACEMENT, 317 message=message, 318 suggestions=[], 319 ) 320 321 322def place(graph: IRGraph) -> IRGraph: 323 """Placement pass: validate explicit placements and auto-place unplaced nodes. 324 325 Process: 326 1. Infer or use provided SystemConfig 327 2. Validate explicitly placed nodes (pe is not None) 328 3. Auto-place any unplaced nodes using greedy bin-packing + locality 329 4. Validate all PE IDs are < pe_count 330 331 Args: 332 graph: The IRGraph to place 333 334 Returns: 335 New IRGraph with all nodes placed and placement errors appended 336 """ 337 # Determine system config 338 system = graph.system if graph.system is not None else _infer_system_config(graph) 339 340 errors = list(graph.errors) 341 342 # Collect all nodes 343 all_nodes = collect_all_nodes(graph) 344 345 # First pass: validate explicitly placed nodes (reject invalid PE IDs) 346 valid_nodes = {} 347 for node_name, node in all_nodes.items(): 348 if node.pe is not None and node.pe >= system.pe_count: 349 error = AssemblyError( 350 loc=node.loc, 351 category=ErrorCategory.PLACEMENT, 352 message=f"Node '{node_name}' placed on PE{node.pe} but system only has {system.pe_count} PEs (0-{system.pe_count - 1}).", 353 suggestions=[], 354 ) 355 errors.append(error) 356 else: 357 valid_nodes[node_name] = node 358 359 all_nodes = valid_nodes 360 361 # Check if any nodes are unplaced 362 unplaced_nodes = [node for node in all_nodes.values() if node.pe is None] 363 364 if unplaced_nodes: 365 # Auto-place unplaced nodes 366 adjacency = _build_adjacency(graph, all_nodes) 367 all_nodes, placement_errors = _auto_place_nodes(graph, system, all_nodes, adjacency) 368 errors.extend(placement_errors) 369 370 # Update graph with placed nodes 371 # This ensures nodes inside function scopes receive updated PE assignments 372 result_graph = update_graph_nodes(graph, all_nodes) 373 return replace(result_graph, system=system, errors=errors)