# OR-1 dataflow CPU sketch
1"""Resource allocation pass for the OR1 assembler. 2 3Allocates IRAM offsets, activation IDs, frame layouts, and resolves symbolic 4destinations to FrameDest values with PE, offset, activation, and port information. 5 6Reference: Phase 5 design doc, Tasks 1-4. 7""" 8 9from __future__ import annotations 10 11from dataclasses import replace 12from collections import defaultdict 13 14from asm.errors import AssemblyError, ErrorCategory, ErrorSeverity 15from asm.ir import IRGraph, IRNode, IREdge, SourceLoc, ResolvedDest, CallSite, collect_all_nodes_and_edges, update_graph_nodes 16from asm.opcodes import is_dyadic, is_monadic 17from cm_inst import ArithOp, LogicOp, MemOp, Port, RoutingOp, OutputStyle, FrameDest, TokenKind 18 19 20# Module-level constants 21_SINK_MEMOPS = frozenset({MemOp.WRITE, MemOp.CLEAR, MemOp.FREE, MemOp.SET_PAGE, MemOp.WRITE_IMM}) 22 23 24def _group_nodes_by_pe(nodes: dict[str, IRNode]) -> dict[int, list[IRNode]]: 25 """Group nodes by their PE assignment. 26 27 Args: 28 nodes: Dictionary of all nodes 29 30 Returns: 31 Dictionary mapping PE ID to list of nodes on that PE 32 """ 33 groups: dict[int, list[IRNode]] = defaultdict(list) 34 for node in nodes.values(): 35 if node.pe is not None: 36 groups[node.pe].append(node) 37 return groups 38 39 40def _extract_function_scope(node_name: str) -> str: 41 """Extract function scope from qualified node name. 42 43 Strips macro scope segments (starting with #) before extracting the function scope. 44 Macro scopes are for name uniqueness only — they don't allocate context slots. 45 46 Examples: 47 "$main.&add" -> "$main" 48 "$main.#loop_0.&counter" -> "$main" (macro segment stripped) 49 "#loop_0.&counter" -> "" (macro at root scope) 50 "$func.#outer_1.#inner_2.&label" -> "$func" (all macro segments stripped) 51 "&top_level" -> "" (root scope) 52 53 Args: 54 node_name: Qualified node name 55 56 Returns: 57 Function scope name, or empty string for root scope 58 """ 59 if "." 
not in node_name: 60 return "" 61 62 # Split by dots and filter out segments starting with # 63 segments = node_name.split(".") 64 filtered = [seg for seg in segments if not seg.startswith("#")] 65 66 if not filtered: 67 # All segments were macro scopes 68 return "" 69 70 # Return the first non-macro segment if it starts with $, else root scope 71 first_segment = filtered[0] 72 if first_segment.startswith("$"): 73 return first_segment 74 return "" 75 76 77def _assign_iram_offsets( 78 nodes_on_pe: list[IRNode], 79 all_nodes: dict[str, IRNode], 80 iram_capacity: int, 81 pe_id: int, 82) -> tuple[dict[str, IRNode], list[AssemblyError]]: 83 """Assign provisional IRAM offsets to nodes on a PE. 84 85 Dyadic instructions get offsets 0..D-1, monadic get D..M-1. 86 All instructions cost 1 slot (per Phase 4 change). 87 88 Deduplication happens later via _deduplicate_iram() after 89 frame layouts and modes are computed. 90 91 Args: 92 nodes_on_pe: List of nodes on this PE 93 all_nodes: All nodes (for name lookup) 94 iram_capacity: Maximum IRAM slots for this PE 95 pe_id: The PE ID (for error messages) 96 97 Returns: 98 Tuple of (updated_nodes dict, errors list) 99 """ 100 errors = [] 101 updated_nodes = {} 102 103 # Partition into dyadic and monadic, preserving order within each partition 104 # Seed nodes are excluded — they don't occupy IRAM slots 105 dyadic_nodes = [] 106 monadic_nodes = [] 107 108 for node in nodes_on_pe: 109 if node.seed: 110 continue 111 if is_dyadic(node.opcode, node.const): 112 dyadic_nodes.append(node) 113 else: 114 monadic_nodes.append(node) 115 116 # Assign offsets 117 total = len(dyadic_nodes) + len(monadic_nodes) 118 if total > iram_capacity: 119 # Generate overflow error 120 error_msg = f"PE{pe_id} IRAM overflow: {total} instructions but only {iram_capacity} slots.\n" 121 if dyadic_nodes: 122 dyadic_names = ", ".join([n.name.split('.')[-1] for n in dyadic_nodes]) 123 error_msg += f" Dyadic: {dyadic_names} ({len(dyadic_nodes)} instructions)\n" 124 
if monadic_nodes: 125 monadic_names = ", ".join([n.name.split('.')[-1] for n in monadic_nodes]) 126 error_msg += f" Monadic: {monadic_names} ({len(monadic_nodes)} instructions)" 127 128 error = AssemblyError( 129 loc=SourceLoc(0, 0), 130 category=ErrorCategory.RESOURCE, 131 message=error_msg, 132 ) 133 errors.append(error) 134 return {}, errors 135 136 # Assign offsets 137 for offset, node in enumerate(dyadic_nodes): 138 updated_nodes[node.name] = replace(node, iram_offset=offset) 139 140 for offset, node in enumerate(monadic_nodes): 141 updated_nodes[node.name] = replace(node, iram_offset=len(dyadic_nodes) + offset) 142 143 return updated_nodes, errors 144 145 146def _deduplicate_iram( 147 nodes_on_pe: dict[str, IRNode], 148 pe_id: int, 149) -> dict[str, IRNode]: 150 """Deduplicate IRAM entries for nodes that produce identical Instruction templates. 151 152 Two nodes share an IRAM offset when they have identical: 153 opcode (type and value), output (OutputStyle), has_const, dest_count, wide, fref. 
154 155 Args: 156 nodes_on_pe: Dictionary of nodes on this PE 157 pe_id: The PE ID (for diagnostics) 158 159 Returns: 160 Updated nodes dictionary with deduplicated IRAM offsets 161 """ 162 template_to_offset: dict[tuple, int] = {} 163 updated = {} 164 165 for name, node in nodes_on_pe.items(): 166 if node.seed or node.iram_offset is None: 167 updated[name] = node 168 continue 169 170 # Build template key from the fields that make an Instruction 171 mode = node.mode # (OutputStyle, has_const, dest_count) 172 if mode is None: 173 updated[name] = node 174 continue 175 176 # Include opcode type to distinguish between ArithOp.ADD and MemOp.READ 177 # which both have value 0 178 template_key = ( 179 type(node.opcode).__name__, # Include type name to disambiguate IntEnums 180 int(node.opcode), # Include numeric value 181 mode[0].name, # Include OutputStyle name string 182 mode[1], # has_const 183 mode[2], # dest_count 184 node.wide, 185 node.fref, 186 ) 187 188 if template_key in template_to_offset: 189 # Reuse existing offset 190 updated[name] = replace(node, iram_offset=template_to_offset[template_key]) 191 else: 192 template_to_offset[template_key] = node.iram_offset 193 updated[name] = node 194 195 return updated 196 197 198def _compute_modes( 199 nodes_on_pe: dict[str, IRNode], 200 edges_by_source: dict[str, list[IREdge]], 201) -> dict[str, IRNode]: 202 """Compute (OutputStyle, has_const, dest_count) for each node from edge topology. 
203 204 Args: 205 nodes_on_pe: Dictionary of nodes on this PE 206 edges_by_source: Edges indexed by source node name 207 208 Returns: 209 Updated nodes dictionary with mode field set 210 """ 211 updated = {} 212 for name, node in nodes_on_pe.items(): 213 if node.seed: 214 updated[name] = node 215 continue 216 217 out_edges = edges_by_source.get(name, []) 218 dest_count = len(out_edges) 219 has_const = node.const is not None 220 221 # Determine OutputStyle from opcode and edge topology 222 if isinstance(node.opcode, MemOp): 223 # SM instructions: mode depends on MemOp semantics 224 # WRITE, CLEAR, FREE, SET_PAGE, WRITE_IMM → no return value → SINK 225 # READ, RD_INC, RD_DEC, CMP_SW, RAW_READ, EXT → return value → INHERIT 226 if node.opcode in _SINK_MEMOPS: 227 output = OutputStyle.SINK 228 dest_count = 0 229 else: 230 output = OutputStyle.INHERIT 231 elif node.opcode == RoutingOp.FREE_FRAME: 232 output = OutputStyle.SINK 233 dest_count = 0 234 elif node.opcode == RoutingOp.EXTRACT_TAG: 235 output = OutputStyle.INHERIT 236 else: 237 # Check if any outgoing edge has ctx_override=True 238 # ctx_override=True on an edge means the source node is a 239 # cross-function return — its left operand carries a packed 240 # flit 1 (from EXTRACT_TAG) that determines the destination. 241 # In the frame model, this maps to OutputStyle.CHANGE_TAG. 
242 has_ctx_override = any(e.ctx_override for e in out_edges) 243 if has_ctx_override: 244 output = OutputStyle.CHANGE_TAG 245 dest_count = 1 246 elif dest_count == 0: 247 output = OutputStyle.SINK 248 else: 249 output = OutputStyle.INHERIT 250 251 mode = (output, has_const, dest_count) 252 updated[name] = replace(node, mode=mode) 253 254 return updated 255 256 257def _compute_frame_layouts( 258 nodes_on_pe: dict[str, IRNode], 259 edges_by_source: dict[str, list[IREdge]], 260 edges_by_dest: dict[str, list[IREdge]], 261 all_nodes: dict[str, IRNode], 262 frame_slots: int, 263 matchable_offsets: int, 264 pe_id: int, 265) -> tuple[dict[str, IRNode], list[AssemblyError]]: 266 """Compute frame slot layouts per activation. 267 268 Slot assignment order: 269 0 to matchable_offsets-1: match operands (one pair per dyadic instruction) 270 then: constants (deduplicated by value) 271 then: destinations (deduplicated by FrameDest identity) 272 then: sinks and SM parameters 273 274 All activations of the same function share the canonical layout. 
275 276 Args: 277 nodes_on_pe: Dictionary of nodes on this PE 278 edges_by_source: Edges indexed by source node name 279 edges_by_dest: Edges indexed by destination node name 280 all_nodes: All nodes in graph 281 frame_slots: Total slots per frame 282 matchable_offsets: Number of hardware match slots (const/dest slots start here) 283 pe_id: The PE ID (for error messages) 284 285 Returns: 286 Tuple of (updated_nodes dict, errors list) 287 """ 288 from asm.ir import FrameLayout, FrameSlotMap 289 290 errors = [] 291 updated = {} 292 293 # Group nodes by activation ID 294 nodes_by_act_id = defaultdict(list) 295 for name, node in nodes_on_pe.items(): 296 if not node.seed and node.act_id is not None: 297 nodes_by_act_id[node.act_id].append(node) 298 299 # Compute frame layout for each activation 300 act_id_to_layout = {} # act_id -> FrameLayout 301 302 for act_id, nodes_in_act in nodes_by_act_id.items(): 303 # Count dyadic instructions (match operands) 304 dyadic_count = sum(1 for n in nodes_in_act if is_dyadic(n.opcode, n.const)) 305 306 # Warn if dyadic_count exceeds hardware match slots 307 if dyadic_count > matchable_offsets: 308 warning = AssemblyError( 309 loc=SourceLoc(0, 0), 310 category=ErrorCategory.FRAME, 311 severity=ErrorSeverity.WARNING, 312 message=( 313 f"PE{pe_id} activation {act_id}: {dyadic_count} dyadic instructions " 314 f"but only {matchable_offsets} hardware match slots. " 315 f"Liveness-based slot sharing required (see AC5.4 note)." 
316 ), 317 ) 318 errors.append(warning) 319 320 # Collect constants (deduplicated by value) 321 unique_const_values = set() 322 for n in nodes_in_act: 323 if n.const is not None and not isinstance(n.const, (str, type(None))): 324 unique_const_values.add(n.const) 325 const_count = len(unique_const_values) 326 327 # Collect destination slots: each node needs dest_count slots for its destinations 328 # These are not deduplicated - each node gets its own slot(s) 329 dest_count = sum(n.mode[2] for n in nodes_in_act if n.mode is not None) 330 331 # Count slots needed 332 # Match slots reserved at 0 to matchable_offsets-1 (regardless of dyadic_count) 333 match_slot_count = matchable_offsets 334 const_slot_count = const_count 335 dest_slot_count = dest_count 336 337 # SM params and sinks (only count actual sink MemOps) 338 sink_slot_count = len([n for n in nodes_in_act if isinstance(n.opcode, MemOp) and n.opcode in _SINK_MEMOPS]) 339 340 total_slots = match_slot_count + const_slot_count + dest_slot_count + sink_slot_count 341 if total_slots > frame_slots: 342 error_msg = ( 343 f"Frame slot overflow on PE{pe_id} activation {act_id}: " 344 f"{total_slots} slots needed, {frame_slots} available.\n" 345 f" Match region (reserved): {match_slot_count}\n" 346 f" Constants: {const_slot_count}\n" 347 f" Destinations: {dest_slot_count}\n" 348 f" Sinks/SM params: {sink_slot_count}" 349 ) 350 error = AssemblyError( 351 loc=SourceLoc(0, 0), 352 category=ErrorCategory.FRAME, 353 message=error_msg, 354 ) 355 errors.append(error) 356 continue 357 358 # Build frame layout 359 # NOTE: With interleaved const+dest allocation, const_slots and dest_slots are not 360 # contiguous separate regions. They're interleaved per node. The slot_map below is 361 # therefore approximate for documentation; the actual slot allocation comes from node frefs. 
362 match_slots = tuple(range(match_slot_count)) 363 364 # All non-match slots are either const or dest type 365 # For documentation: mark const slots and dest slots 366 # This is approximate since they're interleaved 367 const_start = match_slot_count 368 const_slots = tuple(range(const_start, const_start + const_slot_count)) 369 dest_start = const_start + const_slot_count 370 dest_slots = tuple(range(dest_start, dest_start + dest_slot_count)) 371 sink_start = dest_start + dest_slot_count 372 sink_slots = tuple(range(sink_start, sink_start + sink_slot_count)) 373 374 slot_map = FrameSlotMap( 375 match_slots=match_slots, 376 const_slots=const_slots, 377 dest_slots=dest_slots, 378 sink_slots=sink_slots, 379 ) 380 layout = FrameLayout(slot_map=slot_map, total_slots=total_slots) 381 act_id_to_layout[act_id] = layout 382 383 # Assign frame layouts and frefs to nodes 384 # First, build per-activation node-to-fref mapping 385 act_id_to_node_frefs = {} # act_id -> {node_name -> fref} 386 387 for act_id, nodes_in_act in nodes_by_act_id.items(): 388 if act_id not in act_id_to_layout: 389 continue 390 391 layout = act_id_to_layout[act_id] 392 393 # Assign fref to each node in order 394 # const_nodes need: [const, dest1, dest2, ...] layout 395 # no-const,dest_nodes need: [dest1, dest2, ...] 
layout 396 # Allocation order: process nodes in sorted order, mixing const and dest allocations 397 # to ensure that const nodes get const_slot at fref with dests at fref+1 onward 398 node_frefs = {} 399 400 # Separate nodes by type 401 const_nodes = [] 402 no_const_dest_nodes = [] 403 no_const_no_dest_nodes = [] 404 sink_nodes = [] 405 406 for node in sorted(nodes_in_act, key=lambda n: n.name): 407 if node.seed or node.mode is None: 408 continue 409 output_style, has_const, dest_count = node.mode 410 if output_style == OutputStyle.SINK: 411 sink_nodes.append(node) 412 elif has_const: 413 const_nodes.append(node) 414 elif dest_count > 0: 415 no_const_dest_nodes.append(node) 416 else: 417 # No const, no destinations - still need a slot for frame matching 418 no_const_no_dest_nodes.append(node) 419 420 # Allocate const nodes first (they get const slot + dest slots) 421 slot_counter = matchable_offsets # Start after match region 422 for node in const_nodes: 423 _, has_const, dest_count = node.mode 424 # Assign fref to const slot 425 node_frefs[node.name] = slot_counter 426 slot_counter += 1 + dest_count # const + dests 427 428 # Then allocate no-const,dest nodes 429 for node in no_const_dest_nodes: 430 _, has_const, dest_count = node.mode 431 # Assign fref to first dest slot 432 node_frefs[node.name] = slot_counter 433 slot_counter += dest_count 434 435 # Then allocate no-const,no-dest nodes (still need fref slot for result writeback in SINK mode) 436 for node in no_const_no_dest_nodes: 437 node_frefs[node.name] = slot_counter 438 slot_counter += 1 439 440 # Finally allocate sink nodes 441 for node in sink_nodes: 442 node_frefs[node.name] = slot_counter 443 slot_counter += 1 444 445 act_id_to_node_frefs[act_id] = node_frefs 446 447 # Now assign frame layouts and frefs to nodes 448 for name, node in nodes_on_pe.items(): 449 if node.seed: 450 updated[name] = node 451 continue 452 453 if node.act_id is not None and node.act_id in act_id_to_layout: 454 layout = 
act_id_to_layout[node.act_id] 455 456 # Get fref from the per-activation mapping 457 node_frefs = act_id_to_node_frefs.get(node.act_id, {}) 458 fref = node_frefs.get(name) 459 460 if fref is not None: 461 updated[name] = replace(node, frame_layout=layout, fref=fref) 462 else: 463 updated[name] = replace(node, frame_layout=layout) 464 else: 465 updated[name] = node 466 467 return updated, errors 468 469 470def _assign_act_ids( 471 nodes_on_pe: list[IRNode], 472 all_nodes: dict[str, IRNode], 473 frame_count: int, 474 pe_id: int, 475 call_sites: list[CallSite] | None = None, 476) -> tuple[dict[str, IRNode], list[AssemblyError]]: 477 """Assign activation IDs (0 to frame_count-1) per function scope per PE. 478 479 Implements per-call-site activation allocation: 480 - Root scope always gets act_id=0 481 - Functions without call sites get one act_id by the existing scope rule 482 - Each call site allocates a fresh act_id on the PE(s) where the callee lives 483 484 Args: 485 nodes_on_pe: List of nodes on this PE 486 all_nodes: All nodes (for name lookup) 487 frame_count: Maximum activation IDs for this PE (default 8) 488 pe_id: The PE ID (for error messages) 489 call_sites: Optional list of CallSite objects for per-call-site allocation 490 491 Returns: 492 Tuple of (updated_nodes dict, errors list) 493 """ 494 if call_sites is None: 495 call_sites = [] 496 497 errors = [] 498 updated_nodes = {} 499 500 # Build global mapping of which nodes belong to which call sites 501 # Trampoline and free_frame nodes get the call site's act_id 502 callsite_for_node = {} # node_name -> CallSite 503 for call_site in call_sites: 504 for tramp_node in call_site.trampoline_nodes: 505 callsite_for_node[tramp_node] = call_site 506 for free_node in call_site.free_frame_nodes: 507 callsite_for_node[free_node] = call_site 508 509 # Build mapping: function scope -> call site (for function body nodes) 510 func_scope_to_callsite = {} # func_name -> CallSite 511 for call_site in call_sites: 512 
func_scope_to_callsite[call_site.func_name] = call_site 513 514 # Allocate activation IDs for this PE 515 next_act_id = 0 516 act_breakdown = {} # For overflow error reporting 517 scope_to_act_id = {} 518 root_act_id = 0 # Default root scope activation 519 520 # Check if there are any root-scope nodes on this PE 521 has_root_scope_nodes = any( 522 not node.seed and _extract_function_scope(node.name) == "" 523 for node in nodes_on_pe 524 ) 525 526 # Root scope always gets act_id=0 if it has nodes on this PE 527 if has_root_scope_nodes: 528 scope_to_act_id[""] = root_act_id 529 act_breakdown["root"] = 1 530 next_act_id = 1 531 532 for node in nodes_on_pe: 533 if node.seed: 534 continue 535 scope = _extract_function_scope(node.name) 536 # Only process function scopes not already assigned 537 if scope and scope not in scope_to_act_id: 538 # Check if this function has any call sites 539 has_call_sites = any(cs.func_name == scope for cs in call_sites) 540 if not has_call_sites: 541 # No call sites, assign one slot (per-scope per-PE) 542 if next_act_id >= frame_count: 543 # Overflow 544 error_msg = _build_activation_overflow_message( 545 pe_id, frame_count, next_act_id, act_breakdown 546 ) 547 error = AssemblyError( 548 loc=SourceLoc(0, 0), 549 category=ErrorCategory.FRAME, 550 message=error_msg, 551 ) 552 errors.append(error) 553 return {}, errors 554 scope_to_act_id[scope] = next_act_id 555 act_breakdown[scope] = 1 556 next_act_id += 1 557 558 # Now allocate per-call-site activation IDs (one per call site per PE) 559 # Build mapping: call_site -> act_id on this PE 560 call_site_to_act_id_on_pe = {} 561 for call_site in call_sites: 562 # Check if any trampoline or free_frame node for this call site is on this PE 563 has_node_on_pe = False 564 for tramp_node in call_site.trampoline_nodes: 565 if tramp_node in all_nodes and all_nodes[tramp_node].pe == pe_id: 566 has_node_on_pe = True 567 break 568 if not has_node_on_pe: 569 for free_node in call_site.free_frame_nodes: 570 
if free_node in all_nodes and all_nodes[free_node].pe == pe_id: 571 has_node_on_pe = True 572 break 573 574 if has_node_on_pe: 575 # This call site has nodes on this PE, allocate an activation ID 576 if next_act_id >= frame_count: 577 # Overflow 578 error_msg = _build_activation_overflow_message( 579 pe_id, frame_count, next_act_id, act_breakdown 580 ) 581 error = AssemblyError( 582 loc=SourceLoc(0, 0), 583 category=ErrorCategory.FRAME, 584 message=error_msg, 585 ) 586 errors.append(error) 587 return {}, errors 588 589 call_site_to_act_id_on_pe[call_site] = next_act_id 590 act_breakdown[f"{call_site.func_name} call site #{call_site.call_id}"] = 1 591 next_act_id += 1 592 593 # Check budget warning (75%) 594 if frame_count > 0: 595 utilisation = next_act_id / frame_count 596 if utilisation >= 0.75: 597 percent = int(utilisation * 100) 598 warning = AssemblyError( 599 loc=SourceLoc(0, 0), 600 category=ErrorCategory.FRAME, 601 severity=ErrorSeverity.WARNING, 602 message=f"PE{pe_id}: {next_act_id}/{frame_count} activation IDs used ({percent}%)", 603 ) 604 errors.append(warning) 605 606 # Assign activation IDs to nodes 607 for node in nodes_on_pe: 608 if node.seed: 609 continue 610 611 # Check if this node is a trampoline or free_frame node for a call site 612 act_id_value = None 613 if node.name in callsite_for_node: 614 call_site = callsite_for_node[node.name] 615 act_id_value = call_site_to_act_id_on_pe.get(call_site) 616 617 # If not part of a call site, check if it's a function body node 618 if act_id_value is None: 619 scope = _extract_function_scope(node.name) 620 if scope in func_scope_to_callsite: 621 # Function body node — gets the call site's act_id 622 cs = func_scope_to_callsite[scope] 623 act_id_value = call_site_to_act_id_on_pe.get(cs) 624 if act_id_value is None: 625 act_id_value = scope_to_act_id.get(scope, root_act_id) 626 627 updated_nodes[node.name] = replace(node, act_id=act_id_value) 628 629 return updated_nodes, errors 630 631 632def 
_build_activation_overflow_message(pe_id: int, frame_count: int, used: int, breakdown: dict) -> str: 633 """Build a detailed activation ID overflow error message. 634 635 Args: 636 pe_id: The PE ID 637 frame_count: Total available activation IDs 638 used: Number of IDs needed 639 breakdown: Dictionary mapping scope/call site to ID count 640 641 Returns: 642 Formatted error message 643 """ 644 lines = [ 645 f"Activation ID exhaustion on PE{pe_id}: {used} IDs needed, {frame_count} available" 646 ] 647 for scope_name, count in breakdown.items(): 648 if scope_name == "root": 649 lines.append(f" Root scope: {count} ID") 650 else: 651 lines.append(f" {scope_name}: {count} ID") 652 lines.append("Consider inlining frequently-called functions to reduce frame pressure.") 653 return "\n".join(lines) 654 655 656def _assign_sm_ids( 657 all_nodes: dict[str, IRNode], 658 sm_count: int, 659) -> tuple[dict[str, IRNode], list[AssemblyError]]: 660 """Assign SM IDs to MemOp instruction nodes that lack one. 661 662 For single-SM systems, defaults to sm_id=0. For multi-SM systems where 663 the SM target is ambiguous, reports an error. 
664 665 Args: 666 all_nodes: Dictionary of all nodes 667 sm_count: Number of SMs in the system 668 669 Returns: 670 Tuple of (updated nodes dict, list of errors) 671 """ 672 errors: list[AssemblyError] = [] 673 updated: dict[str, IRNode] = {} 674 675 for name, node in all_nodes.items(): 676 if isinstance(node.opcode, MemOp) and node.sm_id is None: 677 if sm_count == 0: 678 errors.append(AssemblyError( 679 loc=node.loc, 680 category=ErrorCategory.RESOURCE, 681 message=f"Node '{name}' uses memory operation '{node.opcode.name}' but system has no SMs.", 682 )) 683 elif sm_count == 1: 684 updated[name] = replace(node, sm_id=0) 685 else: 686 errors.append(AssemblyError( 687 loc=node.loc, 688 category=ErrorCategory.RESOURCE, 689 message=( 690 f"Node '{name}' uses memory operation '{node.opcode.name}' but no SM target specified " 691 f"and system has {sm_count} SMs. Cannot infer target." 692 ), 693 )) 694 695 return updated, errors 696 697 698_COMMUTATIVE_OPS: frozenset = frozenset({ 699 ArithOp.ADD, 700 LogicOp.AND, 701 LogicOp.OR, 702 LogicOp.XOR, 703 LogicOp.EQ, 704}) 705 706 707def _build_edge_index_by_dest(edges: list[IREdge]) -> dict[str, list[IREdge]]: 708 """Build index of edges by destination node name.""" 709 index: dict[str, list[IREdge]] = defaultdict(list) 710 for edge in edges: 711 index[edge.dest].append(edge) 712 return index 713 714 715def _validate_noncommutative_const( 716 all_nodes: dict[str, IRNode], 717 edges_by_dest: dict[str, list[IREdge]], 718) -> list[AssemblyError]: 719 """Warn when non-commutative dyadic ops with IRAM const lack explicit ports. 720 721 When a dyadic instruction has a baked-in constant (e.g., `sub 3`), the 722 incoming token goes to whichever port the edge specifies (default: L). 723 For commutative ops this is irrelevant, but for non-commutative ops 724 (sub, lt, gt, etc.) the port determines operand order. If the user 725 didn't specify the port explicitly, emit a warning so they know the 726 implicit default is in effect. 
727 728 Args: 729 all_nodes: All nodes in graph 730 edges_by_dest: Edges indexed by destination node name 731 732 Returns: 733 List of warnings for implicit port assignments 734 """ 735 warnings: list[AssemblyError] = [] 736 737 for name, node in all_nodes.items(): 738 if node.seed: 739 continue 740 if node.const is None: 741 continue 742 if is_monadic(node.opcode, node.const): 743 continue 744 if node.opcode in _COMMUTATIVE_OPS: 745 continue 746 747 incoming = edges_by_dest.get(name, []) 748 for edge in incoming: 749 if not edge.port_explicit: 750 warnings.append(AssemblyError( 751 loc=edge.loc, 752 category=ErrorCategory.PORT, 753 severity=ErrorSeverity.WARNING, 754 message=( 755 f"Non-commutative op '{node.opcode.name}' on node '{name}' " 756 f"has an IRAM constant but incoming edge from '{edge.source}' " 757 f"has no explicit port (:L or :R). Defaulting to :L — " 758 f"the token will be the left operand and the constant the right." 759 ), 760 )) 761 762 return warnings 763 764 765def _build_edge_index(edges: list[IREdge]) -> dict[str, list[IREdge]]: 766 """Build index of edges by source node name. 767 768 Args: 769 edges: List of all edges 770 771 Returns: 772 Dictionary mapping source name to list of edges from that source 773 """ 774 index: dict[str, list[IREdge]] = defaultdict(list) 775 for edge in edges: 776 index[edge.source].append(edge) 777 return index 778 779 780def _determine_token_kind(dest_node: IRNode) -> TokenKind: 781 """Determine the token kind for a destination node. 
782 783 Token kind is determined by the destination node's opcode: 784 - If dyadic: TokenKind.DYADIC 785 - If monadic: TokenKind.MONADIC 786 - (TokenKind.INLINE is reserved for future use) 787 788 Args: 789 dest_node: The destination IRNode 790 791 Returns: 792 TokenKind enum value 793 """ 794 if is_dyadic(dest_node.opcode, dest_node.const): 795 return TokenKind.DYADIC 796 else: 797 return TokenKind.MONADIC 798 799 800def _resolve_destinations( 801 nodes_on_pe: dict[str, IRNode], 802 all_nodes: dict[str, IRNode], 803 edges_by_source: dict[str, list[IREdge]], 804) -> tuple[dict[str, IRNode], list[AssemblyError]]: 805 """Resolve symbolic destinations to FrameDest objects. 806 807 Uses edge-to-destination mapping rules: 808 - source_port=L -> dest_l 809 - source_port=R -> dest_r 810 - source_port=None: single edge -> dest_l, two edges -> first dest_l, second dest_r 811 812 Each destination is resolved to a FrameDest object containing target PE, 813 IRAM offset, activation ID, port, and token kind. 
814 815 Args: 816 nodes_on_pe: Nodes on this PE (with iram_offset, act_id set) 817 all_nodes: All nodes in graph 818 edges_by_source: Edges indexed by source node name 819 820 Returns: 821 Tuple of (updated_nodes dict, errors list) 822 """ 823 errors = [] 824 updated_nodes = {} 825 826 for node_name, node in nodes_on_pe.items(): 827 updated_node = node 828 source_edges = edges_by_source.get(node_name, []) 829 830 # Validate edge count 831 if len(source_edges) > 2: 832 error = AssemblyError( 833 loc=node.loc, 834 category=ErrorCategory.PORT, 835 message=f"Node '{node_name}' has {len(source_edges)} outgoing edges, but maximum is 2.", 836 ) 837 errors.append(error) 838 continue 839 840 # Validate source_port conflicts 841 source_ports = [e.source_port for e in source_edges] 842 explicit_ports = [p for p in source_ports if p is not None] 843 if len(explicit_ports) != len(set(explicit_ports)): 844 error = AssemblyError( 845 loc=node.loc, 846 category=ErrorCategory.PORT, 847 message=f"Node '{node_name}' has conflicting source_port qualifiers.", 848 ) 849 errors.append(error) 850 continue 851 852 # Validate mixed explicit/implicit 853 if len(explicit_ports) > 0 and len(explicit_ports) < len(source_edges): 854 error = AssemblyError( 855 loc=node.loc, 856 category=ErrorCategory.PORT, 857 message=f"Node '{node_name}' has mixed explicit and implicit source ports.", 858 ) 859 errors.append(error) 860 continue 861 862 # Resolve edges to destinations 863 if len(source_edges) == 0: 864 # No outgoing edges, keep as-is 865 pass 866 elif len(source_edges) == 1: 867 # Single edge -> dest_l 868 edge = source_edges[0] 869 dest_node = all_nodes.get(edge.dest) 870 if dest_node is None: 871 error = AssemblyError( 872 loc=edge.loc, 873 category=ErrorCategory.NAME, 874 message=f"Edge destination '{edge.dest}' not found.", 875 ) 876 errors.append(error) 877 continue 878 879 # Skip resolution if destination lacks required fields 880 if dest_node.iram_offset is None or dest_node.act_id is None 
or dest_node.pe is None: 881 error = AssemblyError( 882 loc=edge.loc, 883 category=ErrorCategory.RESOURCE, 884 message=f"Destination '{edge.dest}' lacks iram_offset, act_id, or PE assignment.", 885 ) 886 errors.append(error) 887 continue 888 889 frame_dest = FrameDest( 890 target_pe=dest_node.pe, 891 offset=dest_node.iram_offset, 892 act_id=dest_node.act_id, 893 port=edge.port, 894 token_kind=_determine_token_kind(dest_node), 895 ) 896 resolved = ResolvedDest(name=edge.dest, addr=None, frame_dest=frame_dest) 897 updated_node = replace(updated_node, dest_l=resolved) 898 899 else: # len(source_edges) == 2 900 # Two edges: map by source_port or order 901 edges = source_edges 902 if explicit_ports: 903 # All explicit: sort by port 904 edges = sorted(edges, key=lambda e: e.source_port) 905 906 for idx, edge in enumerate(edges): 907 dest_node = all_nodes.get(edge.dest) 908 if dest_node is None: 909 error = AssemblyError( 910 loc=edge.loc, 911 category=ErrorCategory.NAME, 912 message=f"Edge destination '{edge.dest}' not found.", 913 ) 914 errors.append(error) 915 continue 916 917 # Skip resolution if destination lacks required fields 918 if dest_node.iram_offset is None or dest_node.act_id is None or dest_node.pe is None: 919 error = AssemblyError( 920 loc=edge.loc, 921 category=ErrorCategory.RESOURCE, 922 message=f"Destination '{edge.dest}' lacks iram_offset, act_id, or PE assignment.", 923 ) 924 errors.append(error) 925 continue 926 927 frame_dest = FrameDest( 928 target_pe=dest_node.pe, 929 offset=dest_node.iram_offset, 930 act_id=dest_node.act_id, 931 port=edge.port, 932 token_kind=_determine_token_kind(dest_node), 933 ) 934 resolved = ResolvedDest(name=edge.dest, addr=None, frame_dest=frame_dest) 935 936 if idx == 0: 937 updated_node = replace(updated_node, dest_l=resolved) 938 else: 939 updated_node = replace(updated_node, dest_r=resolved) 940 941 updated_nodes[node_name] = updated_node 942 943 return updated_nodes, errors 944 945 946def allocate(graph: IRGraph) -> 
IRGraph: 947 """Allocate resources: IRAM offsets, activation IDs, frame layouts, resolve destinations. 948 949 Performs the following operations per PE: 950 1. Assign provisional IRAM offsets (dyadic first, then monadic) 951 2. Assign activation IDs (0 to frame_count-1) 952 3. Compute modes (OutputStyle, has_const, dest_count) from edge topology 953 4. Compute frame layouts (assigns fref, frame_layout) per activation 954 5. Deduplicate IRAM entries by instruction template 955 6. Resolve destinations to FrameDest objects with PE, offset, act_id, port, token_kind 956 957 Args: 958 graph: The IRGraph to allocate 959 960 Returns: 961 New IRGraph with all nodes updated and allocation errors appended 962 """ 963 errors = list(graph.errors) 964 system = graph.system 965 966 if system is None: 967 # Should not happen if place() was called first, but handle gracefully 968 system_errors = [ 969 AssemblyError( 970 loc=SourceLoc(0, 0), 971 category=ErrorCategory.RESOURCE, 972 message="Cannot allocate without SystemConfig. Run place() first.", 973 ) 974 ] 975 return replace(graph, errors=errors + system_errors) 976 977 # Collect all nodes and edges 978 all_nodes, all_edges = collect_all_nodes_and_edges(graph) 979 edges_by_source = _build_edge_index(all_edges) 980 edges_by_dest = _build_edge_index_by_dest(all_edges) 981 982 # Validate non-commutative ops with IRAM constants have explicit ports 983 noncomm_errors = _validate_noncommutative_const(all_nodes, edges_by_dest) 984 errors.extend(noncomm_errors) 985 986 # Assign SM IDs to MemOp nodes that lack explicit SM targets 987 sm_updated, sm_errors = _assign_sm_ids(all_nodes, system.sm_count) 988 errors.extend(sm_errors) 989 all_nodes.update(sm_updated) 990 991 # Group nodes by PE 992 nodes_by_pe = _group_nodes_by_pe(all_nodes) 993 994 # First pass: assign IRAM offsets, activation IDs, compute modes and frame layouts 995 intermediate_nodes = {} 996 for pe_id, nodes_on_pe in sorted(nodes_by_pe.items()): 997 # 1. 
Assign provisional IRAM offsets 998 iram_updated, iram_errors = _assign_iram_offsets( 999 nodes_on_pe, 1000 all_nodes, 1001 system.iram_capacity, 1002 pe_id, 1003 ) 1004 errors.extend(iram_errors) 1005 1006 if iram_errors: 1007 # Skip further processing on this PE if IRAM error 1008 continue 1009 1010 # 2. Assign activation IDs 1011 act_updated, act_errors = _assign_act_ids( 1012 list(iram_updated.values()), 1013 all_nodes, 1014 system.frame_count, 1015 pe_id, 1016 call_sites=graph.call_sites, 1017 ) 1018 errors.extend(act_errors) 1019 1020 if act_errors: 1021 # Skip further processing on this PE if activation error 1022 continue 1023 1024 # 3. Compute modes (OutputStyle, has_const, dest_count) 1025 mode_updated = _compute_modes(act_updated, edges_by_source) 1026 1027 # 4. Compute frame layouts (assigns fref, frame_layout) 1028 frame_updated, frame_errors = _compute_frame_layouts( 1029 mode_updated, 1030 edges_by_source, 1031 edges_by_dest, 1032 all_nodes, 1033 system.frame_slots, 1034 system.matchable_offsets, 1035 pe_id, 1036 ) 1037 errors.extend(frame_errors) 1038 1039 if frame_errors: 1040 # Skip further processing on this PE if frame error 1041 continue 1042 1043 # 5. 
Deduplicate IRAM entries 1044 deduped = _deduplicate_iram(frame_updated, pe_id) 1045 1046 intermediate_nodes.update(deduped) 1047 1048 # Second pass: resolve destinations using intermediate nodes 1049 updated_all_nodes = {} 1050 for pe_id in sorted(nodes_by_pe.keys()): 1051 # Get nodes from this PE that made it through offset/slot assignment 1052 nodes_on_this_pe = { 1053 name: node for name, node in intermediate_nodes.items() 1054 if node.pe == pe_id 1055 } 1056 if not nodes_on_this_pe: 1057 # This PE had errors, skip it 1058 continue 1059 1060 resolved_updated, resolve_errors = _resolve_destinations( 1061 nodes_on_this_pe, 1062 intermediate_nodes, # Use intermediate nodes for lookups, not original 1063 edges_by_source, 1064 ) 1065 errors.extend(resolve_errors) 1066 1067 updated_all_nodes.update(resolved_updated) 1068 1069 # Merge updated_all_nodes with intermediate_nodes (for nodes that didn't get resolved destinations) 1070 # This ensures nodes from PEs with errors still get their offsets/slots 1071 final_nodes = dict(intermediate_nodes) 1072 final_nodes.update(updated_all_nodes) 1073 1074 # Reconstruct the graph with updated nodes 1075 # Need to preserve the tree structure (regions, etc.) 1076 result_graph = update_graph_nodes(graph, final_nodes) 1077 return replace(result_graph, errors=errors)