OR1 dataflow CPU sketch
1"""Resource allocation pass for the OR1 assembler.
2
3Allocates IRAM offsets, activation IDs, frame layouts, and resolves symbolic
4destinations to FrameDest values with PE, offset, activation, and port information.
5
6Reference: Phase 5 design doc, Tasks 1-4.
7"""
8
9from __future__ import annotations
10
11from dataclasses import replace
12from collections import defaultdict
13
14from asm.errors import AssemblyError, ErrorCategory, ErrorSeverity
15from asm.ir import IRGraph, IRNode, IREdge, SourceLoc, ResolvedDest, CallSite, collect_all_nodes_and_edges, update_graph_nodes
16from asm.opcodes import is_dyadic, is_monadic
17from cm_inst import ArithOp, LogicOp, MemOp, Port, RoutingOp, OutputStyle, FrameDest, TokenKind
18
19
# Module-level constants

# MemOps that consume their token without producing a return value.
# _compute_modes maps these to OutputStyle.SINK with zero destinations,
# and _compute_frame_layouts charges one sink slot per such node.
_SINK_MEMOPS = frozenset({MemOp.WRITE, MemOp.CLEAR, MemOp.FREE, MemOp.SET_PAGE, MemOp.WRITE_IMM})
22
23
24def _group_nodes_by_pe(nodes: dict[str, IRNode]) -> dict[int, list[IRNode]]:
25 """Group nodes by their PE assignment.
26
27 Args:
28 nodes: Dictionary of all nodes
29
30 Returns:
31 Dictionary mapping PE ID to list of nodes on that PE
32 """
33 groups: dict[int, list[IRNode]] = defaultdict(list)
34 for node in nodes.values():
35 if node.pe is not None:
36 groups[node.pe].append(node)
37 return groups
38
39
40def _extract_function_scope(node_name: str) -> str:
41 """Extract function scope from qualified node name.
42
43 Strips macro scope segments (starting with #) before extracting the function scope.
44 Macro scopes are for name uniqueness only — they don't allocate context slots.
45
46 Examples:
47 "$main.&add" -> "$main"
48 "$main.#loop_0.&counter" -> "$main" (macro segment stripped)
49 "#loop_0.&counter" -> "" (macro at root scope)
50 "$func.#outer_1.#inner_2.&label" -> "$func" (all macro segments stripped)
51 "&top_level" -> "" (root scope)
52
53 Args:
54 node_name: Qualified node name
55
56 Returns:
57 Function scope name, or empty string for root scope
58 """
59 if "." not in node_name:
60 return ""
61
62 # Split by dots and filter out segments starting with #
63 segments = node_name.split(".")
64 filtered = [seg for seg in segments if not seg.startswith("#")]
65
66 if not filtered:
67 # All segments were macro scopes
68 return ""
69
70 # Return the first non-macro segment if it starts with $, else root scope
71 first_segment = filtered[0]
72 if first_segment.startswith("$"):
73 return first_segment
74 return ""
75
76
77def _assign_iram_offsets(
78 nodes_on_pe: list[IRNode],
79 all_nodes: dict[str, IRNode],
80 iram_capacity: int,
81 pe_id: int,
82) -> tuple[dict[str, IRNode], list[AssemblyError]]:
83 """Assign provisional IRAM offsets to nodes on a PE.
84
85 Dyadic instructions get offsets 0..D-1, monadic get D..M-1.
86 All instructions cost 1 slot (per Phase 4 change).
87
88 Deduplication happens later via _deduplicate_iram() after
89 frame layouts and modes are computed.
90
91 Args:
92 nodes_on_pe: List of nodes on this PE
93 all_nodes: All nodes (for name lookup)
94 iram_capacity: Maximum IRAM slots for this PE
95 pe_id: The PE ID (for error messages)
96
97 Returns:
98 Tuple of (updated_nodes dict, errors list)
99 """
100 errors = []
101 updated_nodes = {}
102
103 # Partition into dyadic and monadic, preserving order within each partition
104 # Seed nodes are excluded — they don't occupy IRAM slots
105 dyadic_nodes = []
106 monadic_nodes = []
107
108 for node in nodes_on_pe:
109 if node.seed:
110 continue
111 if is_dyadic(node.opcode, node.const):
112 dyadic_nodes.append(node)
113 else:
114 monadic_nodes.append(node)
115
116 # Assign offsets
117 total = len(dyadic_nodes) + len(monadic_nodes)
118 if total > iram_capacity:
119 # Generate overflow error
120 error_msg = f"PE{pe_id} IRAM overflow: {total} instructions but only {iram_capacity} slots.\n"
121 if dyadic_nodes:
122 dyadic_names = ", ".join([n.name.split('.')[-1] for n in dyadic_nodes])
123 error_msg += f" Dyadic: {dyadic_names} ({len(dyadic_nodes)} instructions)\n"
124 if monadic_nodes:
125 monadic_names = ", ".join([n.name.split('.')[-1] for n in monadic_nodes])
126 error_msg += f" Monadic: {monadic_names} ({len(monadic_nodes)} instructions)"
127
128 error = AssemblyError(
129 loc=SourceLoc(0, 0),
130 category=ErrorCategory.RESOURCE,
131 message=error_msg,
132 )
133 errors.append(error)
134 return {}, errors
135
136 # Assign offsets
137 for offset, node in enumerate(dyadic_nodes):
138 updated_nodes[node.name] = replace(node, iram_offset=offset)
139
140 for offset, node in enumerate(monadic_nodes):
141 updated_nodes[node.name] = replace(node, iram_offset=len(dyadic_nodes) + offset)
142
143 return updated_nodes, errors
144
145
146def _deduplicate_iram(
147 nodes_on_pe: dict[str, IRNode],
148 pe_id: int,
149) -> dict[str, IRNode]:
150 """Deduplicate IRAM entries for nodes that produce identical Instruction templates.
151
152 Two nodes share an IRAM offset when they have identical:
153 opcode (type and value), output (OutputStyle), has_const, dest_count, wide, fref.
154
155 Args:
156 nodes_on_pe: Dictionary of nodes on this PE
157 pe_id: The PE ID (for diagnostics)
158
159 Returns:
160 Updated nodes dictionary with deduplicated IRAM offsets
161 """
162 template_to_offset: dict[tuple, int] = {}
163 updated = {}
164
165 for name, node in nodes_on_pe.items():
166 if node.seed or node.iram_offset is None:
167 updated[name] = node
168 continue
169
170 # Build template key from the fields that make an Instruction
171 mode = node.mode # (OutputStyle, has_const, dest_count)
172 if mode is None:
173 updated[name] = node
174 continue
175
176 # Include opcode type to distinguish between ArithOp.ADD and MemOp.READ
177 # which both have value 0
178 template_key = (
179 type(node.opcode).__name__, # Include type name to disambiguate IntEnums
180 int(node.opcode), # Include numeric value
181 mode[0].name, # Include OutputStyle name string
182 mode[1], # has_const
183 mode[2], # dest_count
184 node.wide,
185 node.fref,
186 )
187
188 if template_key in template_to_offset:
189 # Reuse existing offset
190 updated[name] = replace(node, iram_offset=template_to_offset[template_key])
191 else:
192 template_to_offset[template_key] = node.iram_offset
193 updated[name] = node
194
195 return updated
196
197
198def _compute_modes(
199 nodes_on_pe: dict[str, IRNode],
200 edges_by_source: dict[str, list[IREdge]],
201) -> dict[str, IRNode]:
202 """Compute (OutputStyle, has_const, dest_count) for each node from edge topology.
203
204 Args:
205 nodes_on_pe: Dictionary of nodes on this PE
206 edges_by_source: Edges indexed by source node name
207
208 Returns:
209 Updated nodes dictionary with mode field set
210 """
211 updated = {}
212 for name, node in nodes_on_pe.items():
213 if node.seed:
214 updated[name] = node
215 continue
216
217 out_edges = edges_by_source.get(name, [])
218 dest_count = len(out_edges)
219 has_const = node.const is not None
220
221 # Determine OutputStyle from opcode and edge topology
222 if isinstance(node.opcode, MemOp):
223 # SM instructions: mode depends on MemOp semantics
224 # WRITE, CLEAR, FREE, SET_PAGE, WRITE_IMM → no return value → SINK
225 # READ, RD_INC, RD_DEC, CMP_SW, RAW_READ, EXT → return value → INHERIT
226 if node.opcode in _SINK_MEMOPS:
227 output = OutputStyle.SINK
228 dest_count = 0
229 else:
230 output = OutputStyle.INHERIT
231 elif node.opcode == RoutingOp.FREE_FRAME:
232 output = OutputStyle.SINK
233 dest_count = 0
234 elif node.opcode == RoutingOp.EXTRACT_TAG:
235 output = OutputStyle.INHERIT
236 else:
237 # Check if any outgoing edge has ctx_override=True
238 # ctx_override=True on an edge means the source node is a
239 # cross-function return — its left operand carries a packed
240 # flit 1 (from EXTRACT_TAG) that determines the destination.
241 # In the frame model, this maps to OutputStyle.CHANGE_TAG.
242 has_ctx_override = any(e.ctx_override for e in out_edges)
243 if has_ctx_override:
244 output = OutputStyle.CHANGE_TAG
245 dest_count = 1
246 elif dest_count == 0:
247 output = OutputStyle.SINK
248 else:
249 output = OutputStyle.INHERIT
250
251 mode = (output, has_const, dest_count)
252 updated[name] = replace(node, mode=mode)
253
254 return updated
255
256
def _compute_frame_layouts(
    nodes_on_pe: dict[str, IRNode],
    edges_by_source: dict[str, list[IREdge]],
    edges_by_dest: dict[str, list[IREdge]],
    all_nodes: dict[str, IRNode],
    frame_slots: int,
    matchable_offsets: int,
    pe_id: int,
) -> tuple[dict[str, IRNode], list[AssemblyError]]:
    """Compute frame slot layouts per activation and assign each node a fref.

    Slot assignment order:
    0 to matchable_offsets-1: match operands (one pair per dyadic instruction)
    then: constants (deduplicated by value)
    then: destinations (deduplicated by FrameDest identity)
    then: sinks and SM parameters

    All activations of the same function share the canonical layout.

    Note: edges_by_source, edges_by_dest, and all_nodes are currently unused
    in this body; they are kept for interface symmetry with sibling passes.

    NOTE(review): the overflow check below deduplicates constants by value
    and does not count no-const/no-dest non-sink nodes, while the fref
    allocation gives every const node its own slot and no-const/no-dest
    nodes one slot each — so slot_counter may exceed total_slots without
    tripping the overflow error. Confirm this is intended.

    Args:
        nodes_on_pe: Dictionary of nodes on this PE
        edges_by_source: Edges indexed by source node name
        edges_by_dest: Edges indexed by destination node name
        all_nodes: All nodes in graph
        frame_slots: Total slots per frame
        matchable_offsets: Number of hardware match slots (const/dest slots start here)
        pe_id: The PE ID (for error messages)

    Returns:
        Tuple of (updated_nodes dict, errors list). Activations that overflow
        report an error and their nodes pass through without a layout/fref.
    """
    # Imported lazily rather than at module top with the other asm.ir
    # names — reason unclear from here (NOTE(review): possibly historical).
    from asm.ir import FrameLayout, FrameSlotMap

    errors: list[AssemblyError] = []
    updated: dict[str, IRNode] = {}

    # Group nodes by activation ID (seeds and unassigned nodes excluded).
    nodes_by_act_id = defaultdict(list)
    for name, node in nodes_on_pe.items():
        if not node.seed and node.act_id is not None:
            nodes_by_act_id[node.act_id].append(node)

    # Compute frame layout for each activation
    act_id_to_layout = {}  # act_id -> FrameLayout

    for act_id, nodes_in_act in nodes_by_act_id.items():
        # Count dyadic instructions (match operands)
        dyadic_count = sum(1 for n in nodes_in_act if is_dyadic(n.opcode, n.const))

        # Warn (non-fatal) if dyadic_count exceeds hardware match slots
        if dyadic_count > matchable_offsets:
            warning = AssemblyError(
                loc=SourceLoc(0, 0),
                category=ErrorCategory.FRAME,
                severity=ErrorSeverity.WARNING,
                message=(
                    f"PE{pe_id} activation {act_id}: {dyadic_count} dyadic instructions "
                    f"but only {matchable_offsets} hardware match slots. "
                    f"Liveness-based slot sharing required (see AC5.4 note)."
                ),
            )
            errors.append(warning)

        # Collect constants (deduplicated by value). String constants are
        # excluded here — only numeric-like consts occupy frame slots.
        unique_const_values = set()
        for n in nodes_in_act:
            if n.const is not None and not isinstance(n.const, (str, type(None))):
                unique_const_values.add(n.const)
        const_count = len(unique_const_values)

        # Collect destination slots: each node needs dest_count slots for its destinations
        # These are not deduplicated - each node gets its own slot(s)
        dest_count = sum(n.mode[2] for n in nodes_in_act if n.mode is not None)

        # Count slots needed
        # Match slots reserved at 0 to matchable_offsets-1 (regardless of dyadic_count)
        match_slot_count = matchable_offsets
        const_slot_count = const_count
        dest_slot_count = dest_count

        # SM params and sinks (only count actual sink MemOps)
        sink_slot_count = len([n for n in nodes_in_act if isinstance(n.opcode, MemOp) and n.opcode in _SINK_MEMOPS])

        total_slots = match_slot_count + const_slot_count + dest_slot_count + sink_slot_count
        if total_slots > frame_slots:
            error_msg = (
                f"Frame slot overflow on PE{pe_id} activation {act_id}: "
                f"{total_slots} slots needed, {frame_slots} available.\n"
                f" Match region (reserved): {match_slot_count}\n"
                f" Constants: {const_slot_count}\n"
                f" Destinations: {dest_slot_count}\n"
                f" Sinks/SM params: {sink_slot_count}"
            )
            error = AssemblyError(
                loc=SourceLoc(0, 0),
                category=ErrorCategory.FRAME,
                message=error_msg,
            )
            errors.append(error)
            # No layout recorded; this activation's nodes pass through below.
            continue

        # Build frame layout
        # NOTE: With interleaved const+dest allocation, const_slots and dest_slots are not
        # contiguous separate regions. They're interleaved per node. The slot_map below is
        # therefore approximate for documentation; the actual slot allocation comes from node frefs.
        match_slots = tuple(range(match_slot_count))

        # All non-match slots are either const or dest type
        # For documentation: mark const slots and dest slots
        # This is approximate since they're interleaved
        const_start = match_slot_count
        const_slots = tuple(range(const_start, const_start + const_slot_count))
        dest_start = const_start + const_slot_count
        dest_slots = tuple(range(dest_start, dest_start + dest_slot_count))
        sink_start = dest_start + dest_slot_count
        sink_slots = tuple(range(sink_start, sink_start + sink_slot_count))

        slot_map = FrameSlotMap(
            match_slots=match_slots,
            const_slots=const_slots,
            dest_slots=dest_slots,
            sink_slots=sink_slots,
        )
        layout = FrameLayout(slot_map=slot_map, total_slots=total_slots)
        act_id_to_layout[act_id] = layout

    # Assign frame layouts and frefs to nodes
    # First, build per-activation node-to-fref mapping
    act_id_to_node_frefs = {}  # act_id -> {node_name -> fref}

    for act_id, nodes_in_act in nodes_by_act_id.items():
        if act_id not in act_id_to_layout:
            # Activation overflowed above — skip fref assignment entirely.
            continue

        layout = act_id_to_layout[act_id]

        # Assign fref to each node in order
        # const_nodes need: [const, dest1, dest2, ...] layout
        # no-const,dest_nodes need: [dest1, dest2, ...] layout
        # Allocation order: process nodes in sorted order, mixing const and dest allocations
        # to ensure that const nodes get const_slot at fref with dests at fref+1 onward
        node_frefs = {}

        # Separate nodes by type. Sorting by name makes the allocation
        # deterministic across runs regardless of dict iteration order.
        const_nodes = []
        no_const_dest_nodes = []
        no_const_no_dest_nodes = []
        sink_nodes = []

        for node in sorted(nodes_in_act, key=lambda n: n.name):
            if node.seed or node.mode is None:
                continue
            output_style, has_const, dest_count = node.mode
            if output_style == OutputStyle.SINK:
                # Checked first: a SINK node with a const is still a sink.
                sink_nodes.append(node)
            elif has_const:
                const_nodes.append(node)
            elif dest_count > 0:
                no_const_dest_nodes.append(node)
            else:
                # No const, no destinations - still need a slot for frame matching
                no_const_no_dest_nodes.append(node)

        # Allocate const nodes first (they get const slot + dest slots)
        slot_counter = matchable_offsets  # Start after match region
        for node in const_nodes:
            _, has_const, dest_count = node.mode
            # Assign fref to const slot
            node_frefs[node.name] = slot_counter
            slot_counter += 1 + dest_count  # const + dests
        # Then allocate no-const,dest nodes
        for node in no_const_dest_nodes:
            _, has_const, dest_count = node.mode
            # Assign fref to first dest slot
            node_frefs[node.name] = slot_counter
            slot_counter += dest_count

        # Then allocate no-const,no-dest nodes (still need fref slot for result writeback in SINK mode)
        for node in no_const_no_dest_nodes:
            node_frefs[node.name] = slot_counter
            slot_counter += 1

        # Finally allocate sink nodes
        for node in sink_nodes:
            node_frefs[node.name] = slot_counter
            slot_counter += 1

        act_id_to_node_frefs[act_id] = node_frefs

    # Now assign frame layouts and frefs to nodes
    for name, node in nodes_on_pe.items():
        if node.seed:
            updated[name] = node
            continue

        if node.act_id is not None and node.act_id in act_id_to_layout:
            layout = act_id_to_layout[node.act_id]

            # Get fref from the per-activation mapping
            node_frefs = act_id_to_node_frefs.get(node.act_id, {})
            fref = node_frefs.get(name)

            if fref is not None:
                updated[name] = replace(node, frame_layout=layout, fref=fref)
            else:
                # Node was filtered during fref assignment (e.g. mode is None):
                # it still records the activation's layout.
                updated[name] = replace(node, frame_layout=layout)
        else:
            # Overflowed or unassigned activation: pass through unchanged.
            updated[name] = node

    return updated, errors
468
469
def _assign_act_ids(
    nodes_on_pe: list[IRNode],
    all_nodes: dict[str, IRNode],
    frame_count: int,
    pe_id: int,
    call_sites: list[CallSite] | None = None,
) -> tuple[dict[str, IRNode], list[AssemblyError]]:
    """Assign activation IDs (0 to frame_count-1) per function scope per PE.

    Implements per-call-site activation allocation:
    - Root scope always gets act_id=0
    - Functions without call sites get one act_id by the existing scope rule
    - Each call site allocates a fresh act_id on the PE(s) where the callee lives

    On overflow the returned node dict is empty and a FRAME error is reported.
    A non-fatal WARNING is appended to the same errors list when utilisation
    reaches 75% — callers must filter by severity rather than treat any
    returned entry as fatal.

    Args:
        nodes_on_pe: List of nodes on this PE
        all_nodes: All nodes (used to look up the PE of call-site nodes)
        frame_count: Maximum activation IDs for this PE (default 8)
        pe_id: The PE ID (for error messages)
        call_sites: Optional list of CallSite objects for per-call-site allocation

    Returns:
        Tuple of (updated_nodes dict, errors list)
    """
    if call_sites is None:
        call_sites = []

    errors: list[AssemblyError] = []
    updated_nodes: dict[str, IRNode] = {}

    # Build global mapping of which nodes belong to which call sites
    # Trampoline and free_frame nodes get the call site's act_id
    callsite_for_node = {}  # node_name -> CallSite
    for call_site in call_sites:
        for tramp_node in call_site.trampoline_nodes:
            callsite_for_node[tramp_node] = call_site
        for free_node in call_site.free_frame_nodes:
            callsite_for_node[free_node] = call_site

    # Build mapping: function scope -> call site (for function body nodes)
    # NOTE(review): if a function has several call sites, only the last one
    # survives in this dict — confirm the single-call-site-per-function
    # assumption holds upstream.
    func_scope_to_callsite = {}  # func_name -> CallSite
    for call_site in call_sites:
        func_scope_to_callsite[call_site.func_name] = call_site

    # Allocate activation IDs for this PE
    next_act_id = 0
    act_breakdown = {}  # scope/call-site label -> ID count (always 1), for overflow reporting
    scope_to_act_id = {}
    root_act_id = 0  # Default root scope activation

    # Check if there are any root-scope nodes on this PE
    has_root_scope_nodes = any(
        not node.seed and _extract_function_scope(node.name) == ""
        for node in nodes_on_pe
    )

    # Root scope always gets act_id=0 if it has nodes on this PE
    if has_root_scope_nodes:
        scope_to_act_id[""] = root_act_id
        act_breakdown["root"] = 1
        next_act_id = 1

    for node in nodes_on_pe:
        if node.seed:
            continue
        scope = _extract_function_scope(node.name)
        # Only process function scopes not already assigned
        if scope and scope not in scope_to_act_id:
            # Check if this function has any call sites
            has_call_sites = any(cs.func_name == scope for cs in call_sites)
            if not has_call_sites:
                # No call sites, assign one slot (per-scope per-PE)
                if next_act_id >= frame_count:
                    # Overflow
                    error_msg = _build_activation_overflow_message(
                        pe_id, frame_count, next_act_id, act_breakdown
                    )
                    error = AssemblyError(
                        loc=SourceLoc(0, 0),
                        category=ErrorCategory.FRAME,
                        message=error_msg,
                    )
                    errors.append(error)
                    return {}, errors
                scope_to_act_id[scope] = next_act_id
                act_breakdown[scope] = 1
                next_act_id += 1

    # Now allocate per-call-site activation IDs (one per call site per PE)
    # Build mapping: call_site -> act_id on this PE
    call_site_to_act_id_on_pe = {}
    for call_site in call_sites:
        # Check if any trampoline or free_frame node for this call site is on this PE
        has_node_on_pe = False
        for tramp_node in call_site.trampoline_nodes:
            if tramp_node in all_nodes and all_nodes[tramp_node].pe == pe_id:
                has_node_on_pe = True
                break
        if not has_node_on_pe:
            for free_node in call_site.free_frame_nodes:
                if free_node in all_nodes and all_nodes[free_node].pe == pe_id:
                    has_node_on_pe = True
                    break

        if has_node_on_pe:
            # This call site has nodes on this PE, allocate an activation ID
            if next_act_id >= frame_count:
                # Overflow
                error_msg = _build_activation_overflow_message(
                    pe_id, frame_count, next_act_id, act_breakdown
                )
                error = AssemblyError(
                    loc=SourceLoc(0, 0),
                    category=ErrorCategory.FRAME,
                    message=error_msg,
                )
                errors.append(error)
                return {}, errors

            call_site_to_act_id_on_pe[call_site] = next_act_id
            act_breakdown[f"{call_site.func_name} call site #{call_site.call_id}"] = 1
            next_act_id += 1

    # Check budget warning (75%)
    if frame_count > 0:
        utilisation = next_act_id / frame_count
        if utilisation >= 0.75:
            percent = int(utilisation * 100)
            warning = AssemblyError(
                loc=SourceLoc(0, 0),
                category=ErrorCategory.FRAME,
                severity=ErrorSeverity.WARNING,
                message=f"PE{pe_id}: {next_act_id}/{frame_count} activation IDs used ({percent}%)",
            )
            errors.append(warning)

    # Assign activation IDs to nodes
    for node in nodes_on_pe:
        if node.seed:
            continue

        # Check if this node is a trampoline or free_frame node for a call site
        act_id_value = None
        if node.name in callsite_for_node:
            call_site = callsite_for_node[node.name]
            act_id_value = call_site_to_act_id_on_pe.get(call_site)

        # If not part of a call site, check if it's a function body node
        if act_id_value is None:
            scope = _extract_function_scope(node.name)
            if scope in func_scope_to_callsite:
                # Function body node — gets the call site's act_id
                cs = func_scope_to_callsite[scope]
                act_id_value = call_site_to_act_id_on_pe.get(cs)
            if act_id_value is None:
                # Fall back to the per-scope ID; defaults to act 0.
                # NOTE(review): the root default applies even when the root
                # scope allocated no ID on this PE — confirm intended.
                act_id_value = scope_to_act_id.get(scope, root_act_id)

        updated_nodes[node.name] = replace(node, act_id=act_id_value)

    return updated_nodes, errors
630
631
632def _build_activation_overflow_message(pe_id: int, frame_count: int, used: int, breakdown: dict) -> str:
633 """Build a detailed activation ID overflow error message.
634
635 Args:
636 pe_id: The PE ID
637 frame_count: Total available activation IDs
638 used: Number of IDs needed
639 breakdown: Dictionary mapping scope/call site to ID count
640
641 Returns:
642 Formatted error message
643 """
644 lines = [
645 f"Activation ID exhaustion on PE{pe_id}: {used} IDs needed, {frame_count} available"
646 ]
647 for scope_name, count in breakdown.items():
648 if scope_name == "root":
649 lines.append(f" Root scope: {count} ID")
650 else:
651 lines.append(f" {scope_name}: {count} ID")
652 lines.append("Consider inlining frequently-called functions to reduce frame pressure.")
653 return "\n".join(lines)
654
655
656def _assign_sm_ids(
657 all_nodes: dict[str, IRNode],
658 sm_count: int,
659) -> tuple[dict[str, IRNode], list[AssemblyError]]:
660 """Assign SM IDs to MemOp instruction nodes that lack one.
661
662 For single-SM systems, defaults to sm_id=0. For multi-SM systems where
663 the SM target is ambiguous, reports an error.
664
665 Args:
666 all_nodes: Dictionary of all nodes
667 sm_count: Number of SMs in the system
668
669 Returns:
670 Tuple of (updated nodes dict, list of errors)
671 """
672 errors: list[AssemblyError] = []
673 updated: dict[str, IRNode] = {}
674
675 for name, node in all_nodes.items():
676 if isinstance(node.opcode, MemOp) and node.sm_id is None:
677 if sm_count == 0:
678 errors.append(AssemblyError(
679 loc=node.loc,
680 category=ErrorCategory.RESOURCE,
681 message=f"Node '{name}' uses memory operation '{node.opcode.name}' but system has no SMs.",
682 ))
683 elif sm_count == 1:
684 updated[name] = replace(node, sm_id=0)
685 else:
686 errors.append(AssemblyError(
687 loc=node.loc,
688 category=ErrorCategory.RESOURCE,
689 message=(
690 f"Node '{name}' uses memory operation '{node.opcode.name}' but no SM target specified "
691 f"and system has {sm_count} SMs. Cannot infer target."
692 ),
693 ))
694
695 return updated, errors
696
697
# Dyadic ops for which operand order is irrelevant. Used by
# _validate_noncommutative_const: ops in this set never need an explicit
# :L/:R port when paired with an IRAM constant.
# NOTE(review): other commutative ops (e.g. a multiply, if one exists in
# ArithOp) may be missing — confirm against the cm_inst enum definitions.
_COMMUTATIVE_OPS: frozenset = frozenset({
    ArithOp.ADD,
    LogicOp.AND,
    LogicOp.OR,
    LogicOp.XOR,
    LogicOp.EQ,
})
705
706
707def _build_edge_index_by_dest(edges: list[IREdge]) -> dict[str, list[IREdge]]:
708 """Build index of edges by destination node name."""
709 index: dict[str, list[IREdge]] = defaultdict(list)
710 for edge in edges:
711 index[edge.dest].append(edge)
712 return index
713
714
715def _validate_noncommutative_const(
716 all_nodes: dict[str, IRNode],
717 edges_by_dest: dict[str, list[IREdge]],
718) -> list[AssemblyError]:
719 """Warn when non-commutative dyadic ops with IRAM const lack explicit ports.
720
721 When a dyadic instruction has a baked-in constant (e.g., `sub 3`), the
722 incoming token goes to whichever port the edge specifies (default: L).
723 For commutative ops this is irrelevant, but for non-commutative ops
724 (sub, lt, gt, etc.) the port determines operand order. If the user
725 didn't specify the port explicitly, emit a warning so they know the
726 implicit default is in effect.
727
728 Args:
729 all_nodes: All nodes in graph
730 edges_by_dest: Edges indexed by destination node name
731
732 Returns:
733 List of warnings for implicit port assignments
734 """
735 warnings: list[AssemblyError] = []
736
737 for name, node in all_nodes.items():
738 if node.seed:
739 continue
740 if node.const is None:
741 continue
742 if is_monadic(node.opcode, node.const):
743 continue
744 if node.opcode in _COMMUTATIVE_OPS:
745 continue
746
747 incoming = edges_by_dest.get(name, [])
748 for edge in incoming:
749 if not edge.port_explicit:
750 warnings.append(AssemblyError(
751 loc=edge.loc,
752 category=ErrorCategory.PORT,
753 severity=ErrorSeverity.WARNING,
754 message=(
755 f"Non-commutative op '{node.opcode.name}' on node '{name}' "
756 f"has an IRAM constant but incoming edge from '{edge.source}' "
757 f"has no explicit port (:L or :R). Defaulting to :L — "
758 f"the token will be the left operand and the constant the right."
759 ),
760 ))
761
762 return warnings
763
764
765def _build_edge_index(edges: list[IREdge]) -> dict[str, list[IREdge]]:
766 """Build index of edges by source node name.
767
768 Args:
769 edges: List of all edges
770
771 Returns:
772 Dictionary mapping source name to list of edges from that source
773 """
774 index: dict[str, list[IREdge]] = defaultdict(list)
775 for edge in edges:
776 index[edge.source].append(edge)
777 return index
778
779
def _determine_token_kind(dest_node: IRNode) -> TokenKind:
    """Pick the token kind a destination node expects.

    Dyadic destinations receive DYADIC tokens; everything else receives
    MONADIC. (TokenKind.INLINE is reserved for future use.)

    Args:
        dest_node: The destination IRNode

    Returns:
        TokenKind.DYADIC or TokenKind.MONADIC
    """
    dyadic = is_dyadic(dest_node.opcode, dest_node.const)
    return TokenKind.DYADIC if dyadic else TokenKind.MONADIC
798
799
800def _resolve_destinations(
801 nodes_on_pe: dict[str, IRNode],
802 all_nodes: dict[str, IRNode],
803 edges_by_source: dict[str, list[IREdge]],
804) -> tuple[dict[str, IRNode], list[AssemblyError]]:
805 """Resolve symbolic destinations to FrameDest objects.
806
807 Uses edge-to-destination mapping rules:
808 - source_port=L -> dest_l
809 - source_port=R -> dest_r
810 - source_port=None: single edge -> dest_l, two edges -> first dest_l, second dest_r
811
812 Each destination is resolved to a FrameDest object containing target PE,
813 IRAM offset, activation ID, port, and token kind.
814
815 Args:
816 nodes_on_pe: Nodes on this PE (with iram_offset, act_id set)
817 all_nodes: All nodes in graph
818 edges_by_source: Edges indexed by source node name
819
820 Returns:
821 Tuple of (updated_nodes dict, errors list)
822 """
823 errors = []
824 updated_nodes = {}
825
826 for node_name, node in nodes_on_pe.items():
827 updated_node = node
828 source_edges = edges_by_source.get(node_name, [])
829
830 # Validate edge count
831 if len(source_edges) > 2:
832 error = AssemblyError(
833 loc=node.loc,
834 category=ErrorCategory.PORT,
835 message=f"Node '{node_name}' has {len(source_edges)} outgoing edges, but maximum is 2.",
836 )
837 errors.append(error)
838 continue
839
840 # Validate source_port conflicts
841 source_ports = [e.source_port for e in source_edges]
842 explicit_ports = [p for p in source_ports if p is not None]
843 if len(explicit_ports) != len(set(explicit_ports)):
844 error = AssemblyError(
845 loc=node.loc,
846 category=ErrorCategory.PORT,
847 message=f"Node '{node_name}' has conflicting source_port qualifiers.",
848 )
849 errors.append(error)
850 continue
851
852 # Validate mixed explicit/implicit
853 if len(explicit_ports) > 0 and len(explicit_ports) < len(source_edges):
854 error = AssemblyError(
855 loc=node.loc,
856 category=ErrorCategory.PORT,
857 message=f"Node '{node_name}' has mixed explicit and implicit source ports.",
858 )
859 errors.append(error)
860 continue
861
862 # Resolve edges to destinations
863 if len(source_edges) == 0:
864 # No outgoing edges, keep as-is
865 pass
866 elif len(source_edges) == 1:
867 # Single edge -> dest_l
868 edge = source_edges[0]
869 dest_node = all_nodes.get(edge.dest)
870 if dest_node is None:
871 error = AssemblyError(
872 loc=edge.loc,
873 category=ErrorCategory.NAME,
874 message=f"Edge destination '{edge.dest}' not found.",
875 )
876 errors.append(error)
877 continue
878
879 # Skip resolution if destination lacks required fields
880 if dest_node.iram_offset is None or dest_node.act_id is None or dest_node.pe is None:
881 error = AssemblyError(
882 loc=edge.loc,
883 category=ErrorCategory.RESOURCE,
884 message=f"Destination '{edge.dest}' lacks iram_offset, act_id, or PE assignment.",
885 )
886 errors.append(error)
887 continue
888
889 frame_dest = FrameDest(
890 target_pe=dest_node.pe,
891 offset=dest_node.iram_offset,
892 act_id=dest_node.act_id,
893 port=edge.port,
894 token_kind=_determine_token_kind(dest_node),
895 )
896 resolved = ResolvedDest(name=edge.dest, addr=None, frame_dest=frame_dest)
897 updated_node = replace(updated_node, dest_l=resolved)
898
899 else: # len(source_edges) == 2
900 # Two edges: map by source_port or order
901 edges = source_edges
902 if explicit_ports:
903 # All explicit: sort by port
904 edges = sorted(edges, key=lambda e: e.source_port)
905
906 for idx, edge in enumerate(edges):
907 dest_node = all_nodes.get(edge.dest)
908 if dest_node is None:
909 error = AssemblyError(
910 loc=edge.loc,
911 category=ErrorCategory.NAME,
912 message=f"Edge destination '{edge.dest}' not found.",
913 )
914 errors.append(error)
915 continue
916
917 # Skip resolution if destination lacks required fields
918 if dest_node.iram_offset is None or dest_node.act_id is None or dest_node.pe is None:
919 error = AssemblyError(
920 loc=edge.loc,
921 category=ErrorCategory.RESOURCE,
922 message=f"Destination '{edge.dest}' lacks iram_offset, act_id, or PE assignment.",
923 )
924 errors.append(error)
925 continue
926
927 frame_dest = FrameDest(
928 target_pe=dest_node.pe,
929 offset=dest_node.iram_offset,
930 act_id=dest_node.act_id,
931 port=edge.port,
932 token_kind=_determine_token_kind(dest_node),
933 )
934 resolved = ResolvedDest(name=edge.dest, addr=None, frame_dest=frame_dest)
935
936 if idx == 0:
937 updated_node = replace(updated_node, dest_l=resolved)
938 else:
939 updated_node = replace(updated_node, dest_r=resolved)
940
941 updated_nodes[node_name] = updated_node
942
943 return updated_nodes, errors
944
945
def _has_fatal_errors(diagnostics: list[AssemblyError]) -> bool:
    """True when any diagnostic is an actual error rather than a WARNING.

    The per-PE sub-passes return warnings (e.g. the 75% activation-budget
    notice from _assign_act_ids, or the match-slot warning from
    _compute_frame_layouts) through the same list as fatal errors. Only
    genuine errors should abort further processing of a PE.
    """
    return any(d.severity != ErrorSeverity.WARNING for d in diagnostics)


def allocate(graph: IRGraph) -> IRGraph:
    """Allocate resources: IRAM offsets, activation IDs, frame layouts, resolve destinations.

    Performs the following operations per PE:
    1. Assign provisional IRAM offsets (dyadic first, then monadic)
    2. Assign activation IDs (0 to frame_count-1)
    3. Compute modes (OutputStyle, has_const, dest_count) from edge topology
    4. Compute frame layouts (assigns fref, frame_layout) per activation
    5. Deduplicate IRAM entries by instruction template
    6. Resolve destinations to FrameDest objects with PE, offset, act_id, port, token_kind

    Fix: a PE's processing is skipped only when a sub-pass reports a
    diagnostic of error severity. Previously any returned diagnostic —
    including pure WARNINGs such as the 75% activation-budget notice —
    aborted the PE and silently discarded its allocation results.

    Args:
        graph: The IRGraph to allocate

    Returns:
        New IRGraph with all nodes updated and allocation errors appended
    """
    errors = list(graph.errors)
    system = graph.system

    if system is None:
        # Should not happen if place() was called first, but handle gracefully
        system_errors = [
            AssemblyError(
                loc=SourceLoc(0, 0),
                category=ErrorCategory.RESOURCE,
                message="Cannot allocate without SystemConfig. Run place() first.",
            )
        ]
        return replace(graph, errors=errors + system_errors)

    # Collect all nodes and edges, indexed both ways for the passes below.
    all_nodes, all_edges = collect_all_nodes_and_edges(graph)
    edges_by_source = _build_edge_index(all_edges)
    edges_by_dest = _build_edge_index_by_dest(all_edges)

    # Validate non-commutative ops with IRAM constants have explicit ports
    noncomm_errors = _validate_noncommutative_const(all_nodes, edges_by_dest)
    errors.extend(noncomm_errors)

    # Assign SM IDs to MemOp nodes that lack explicit SM targets
    sm_updated, sm_errors = _assign_sm_ids(all_nodes, system.sm_count)
    errors.extend(sm_errors)
    all_nodes.update(sm_updated)

    # Group nodes by PE
    nodes_by_pe = _group_nodes_by_pe(all_nodes)

    # First pass: assign IRAM offsets, activation IDs, compute modes and frame layouts
    intermediate_nodes = {}
    for pe_id, nodes_on_pe in sorted(nodes_by_pe.items()):
        # 1. Assign provisional IRAM offsets
        iram_updated, iram_errors = _assign_iram_offsets(
            nodes_on_pe,
            all_nodes,
            system.iram_capacity,
            pe_id,
        )
        errors.extend(iram_errors)
        if _has_fatal_errors(iram_errors):
            # IRAM overflow: this PE's nodes cannot be allocated.
            continue

        # 2. Assign activation IDs
        act_updated, act_errors = _assign_act_ids(
            list(iram_updated.values()),
            all_nodes,
            system.frame_count,
            pe_id,
            call_sites=graph.call_sites,
        )
        errors.extend(act_errors)
        if _has_fatal_errors(act_errors):
            # Activation-ID exhaustion (warnings alone don't abort the PE).
            continue

        # 3. Compute modes (OutputStyle, has_const, dest_count)
        mode_updated = _compute_modes(act_updated, edges_by_source)

        # 4. Compute frame layouts (assigns fref, frame_layout)
        frame_updated, frame_errors = _compute_frame_layouts(
            mode_updated,
            edges_by_source,
            edges_by_dest,
            all_nodes,
            system.frame_slots,
            system.matchable_offsets,
            pe_id,
        )
        errors.extend(frame_errors)
        if _has_fatal_errors(frame_errors):
            # Frame slot overflow (warnings alone don't abort the PE).
            continue

        # 5. Deduplicate IRAM entries
        deduped = _deduplicate_iram(frame_updated, pe_id)

        intermediate_nodes.update(deduped)

    # Second pass: resolve destinations using intermediate nodes
    updated_all_nodes = {}
    for pe_id in sorted(nodes_by_pe.keys()):
        # Get nodes from this PE that made it through offset/slot assignment
        nodes_on_this_pe = {
            name: node for name, node in intermediate_nodes.items()
            if node.pe == pe_id
        }
        if not nodes_on_this_pe:
            # This PE had errors, skip it
            continue

        resolved_updated, resolve_errors = _resolve_destinations(
            nodes_on_this_pe,
            intermediate_nodes,  # Use intermediate nodes for lookups, not original
            edges_by_source,
        )
        errors.extend(resolve_errors)

        updated_all_nodes.update(resolved_updated)

    # Merge updated_all_nodes with intermediate_nodes (for nodes that didn't get resolved destinations)
    # This ensures nodes from PEs with errors still get their offsets/slots
    final_nodes = dict(intermediate_nodes)
    final_nodes.update(updated_all_nodes)

    # Reconstruct the graph with updated nodes
    # Need to preserve the tree structure (regions, etc.)
    result_graph = update_graph_nodes(graph, final_nodes)
    return replace(result_graph, errors=errors)