"""Resource allocation pass for the OR1 assembler.

Allocates IRAM offsets, activation IDs, frame layouts, and resolves symbolic
destinations to FrameDest values with PE, offset, activation, and port
information.

Reference: Phase 5 design doc, Tasks 1-4.
"""

from __future__ import annotations

from dataclasses import replace
from collections import defaultdict

from asm.errors import AssemblyError, ErrorCategory, ErrorSeverity
from asm.ir import (
    IRGraph,
    IRNode,
    IREdge,
    SourceLoc,
    ResolvedDest,
    CallSite,
    collect_all_nodes_and_edges,
    update_graph_nodes,
)
from asm.opcodes import is_dyadic, is_monadic
from cm_inst import ArithOp, LogicOp, MemOp, Port, RoutingOp, OutputStyle, FrameDest, TokenKind

# Module-level constants
_SINK_MEMOPS = frozenset({MemOp.WRITE, MemOp.CLEAR, MemOp.FREE, MemOp.SET_PAGE, MemOp.WRITE_IMM})


def _has_blocking_errors(errors: list[AssemblyError]) -> bool:
    """Return True if any diagnostic in *errors* is more severe than a warning.

    Allocation sub-passes return warnings and errors in the same list; only
    real errors should stop further processing of a PE.
    """
    return any(e.severity != ErrorSeverity.WARNING for e in errors)


def _group_nodes_by_pe(nodes: dict[str, IRNode]) -> dict[int, list[IRNode]]:
    """Group nodes by their PE assignment.

    Args:
        nodes: Dictionary of all nodes

    Returns:
        Dictionary mapping PE ID to list of nodes on that PE (nodes with no PE
        assignment are omitted)
    """
    groups: dict[int, list[IRNode]] = defaultdict(list)
    for node in nodes.values():
        if node.pe is not None:
            groups[node.pe].append(node)
    return groups


def _extract_function_scope(node_name: str) -> str:
    """Extract function scope from qualified node name.

    Strips macro scope segments (starting with #) before extracting the
    function scope. Macro scopes are for name uniqueness only — they don't
    allocate context slots.

    Examples:
        "$main.&add" -> "$main"
        "$main.#loop_0.&counter" -> "$main" (macro segment stripped)
        "#loop_0.&counter" -> "" (macro at root scope)
        "$func.#outer_1.#inner_2.&label" -> "$func" (all macro segments stripped)
        "&top_level" -> "" (root scope)

    Args:
        node_name: Qualified node name

    Returns:
        Function scope name, or empty string for root scope
    """
    if "." not in node_name:
        return ""

    # Split by dots and filter out segments starting with #
    segments = node_name.split(".")
    filtered = [seg for seg in segments if not seg.startswith("#")]

    if not filtered:
        # All segments were macro scopes
        return ""

    # Return the first non-macro segment if it starts with $, else root scope
    first_segment = filtered[0]
    if first_segment.startswith("$"):
        return first_segment
    return ""


def _assign_iram_offsets(
    nodes_on_pe: list[IRNode],
    all_nodes: dict[str, IRNode],
    iram_capacity: int,
    pe_id: int,
) -> tuple[dict[str, IRNode], list[AssemblyError]]:
    """Assign provisional IRAM offsets to nodes on a PE.

    Dyadic instructions get offsets 0..D-1, monadic get D..M-1.
    All instructions cost 1 slot (per Phase 4 change).
    Deduplication happens later via _deduplicate_iram() after frame layouts
    and modes are computed.

    NOTE(review): seed nodes are skipped and are NOT included in the returned
    dict, so they never reach the later sub-passes or ``allocate``'s final
    node map. Confirm that ``update_graph_nodes`` keeps nodes absent from its
    mapping and that seeds need no resolved destinations here.

    Args:
        nodes_on_pe: List of nodes on this PE
        all_nodes: All nodes (for name lookup; currently unused)
        iram_capacity: Maximum IRAM slots for this PE
        pe_id: The PE ID (for error messages)

    Returns:
        Tuple of (updated_nodes dict, errors list); the dict is empty on overflow
    """
    errors = []
    updated_nodes = {}

    # Partition into dyadic and monadic, preserving order within each partition
    # Seed nodes are excluded — they don't occupy IRAM slots
    dyadic_nodes = []
    monadic_nodes = []
    for node in nodes_on_pe:
        if node.seed:
            continue
        if is_dyadic(node.opcode, node.const):
            dyadic_nodes.append(node)
        else:
            monadic_nodes.append(node)

    total = len(dyadic_nodes) + len(monadic_nodes)
    if total > iram_capacity:
        # Generate overflow error listing the instructions competing for space
        error_msg = f"PE{pe_id} IRAM overflow: {total} instructions but only {iram_capacity} slots.\n"
        if dyadic_nodes:
            dyadic_names = ", ".join([n.name.split('.')[-1] for n in dyadic_nodes])
            error_msg += f" Dyadic: {dyadic_names} ({len(dyadic_nodes)} instructions)\n"
        if monadic_nodes:
            monadic_names = ", ".join([n.name.split('.')[-1] for n in monadic_nodes])
            error_msg += f" Monadic: {monadic_names} ({len(monadic_nodes)} instructions)"

        error = AssemblyError(
            loc=SourceLoc(0, 0),
            category=ErrorCategory.RESOURCE,
            message=error_msg,
        )
        errors.append(error)
        return {}, errors

    # Dyadic instructions occupy the low offsets, monadic follow immediately after
    for offset, node in enumerate(dyadic_nodes):
        updated_nodes[node.name] = replace(node, iram_offset=offset)
    for offset, node in enumerate(monadic_nodes):
        updated_nodes[node.name] = replace(node, iram_offset=len(dyadic_nodes) + offset)

    return updated_nodes, errors


def _deduplicate_iram(
    nodes_on_pe: dict[str, IRNode],
    pe_id: int,
) -> dict[str, IRNode]:
    """Deduplicate IRAM entries for nodes that produce identical Instruction templates.

    Two nodes share an IRAM offset when they have identical:
    opcode (type and value), output (OutputStyle), has_const, dest_count, wide, fref.
    The first node seen with a given template keeps its offset; later
    matching nodes are redirected to it.

    Args:
        nodes_on_pe: Dictionary of nodes on this PE
        pe_id: The PE ID (for diagnostics; currently unused)

    Returns:
        Updated nodes dictionary with deduplicated IRAM offsets
    """
    template_to_offset: dict[tuple, int] = {}
    updated = {}

    for name, node in nodes_on_pe.items():
        if node.seed or node.iram_offset is None:
            updated[name] = node
            continue

        # Build template key from the fields that make an Instruction
        mode = node.mode  # (OutputStyle, has_const, dest_count)
        if mode is None:
            updated[name] = node
            continue

        # Include opcode type to distinguish between ArithOp.ADD and MemOp.READ
        # which both have value 0
        template_key = (
            type(node.opcode).__name__,  # Include type name to disambiguate IntEnums
            int(node.opcode),            # Include numeric value
            mode[0].name,                # Include OutputStyle name string
            mode[1],                     # has_const
            mode[2],                     # dest_count
            node.wide,
            node.fref,
        )

        if template_key in template_to_offset:
            # Reuse existing offset
            updated[name] = replace(node, iram_offset=template_to_offset[template_key])
        else:
            template_to_offset[template_key] = node.iram_offset
            updated[name] = node

    return updated


def _compute_modes(
    nodes_on_pe: dict[str, IRNode],
    edges_by_source: dict[str, list[IREdge]],
) -> dict[str, IRNode]:
    """Compute (OutputStyle, has_const, dest_count) for each node from edge topology.

    Opcode checks use identity (``is``) rather than ``==``: the opcode enums
    are IntEnums whose numeric values overlap across types, so ``==`` could
    match e.g. an ArithOp against RoutingOp.FREE_FRAME.

    Args:
        nodes_on_pe: Dictionary of nodes on this PE
        edges_by_source: Edges indexed by source node name

    Returns:
        Updated nodes dictionary with mode field set
    """
    updated = {}

    for name, node in nodes_on_pe.items():
        if node.seed:
            updated[name] = node
            continue

        out_edges = edges_by_source.get(name, [])
        dest_count = len(out_edges)
        has_const = node.const is not None

        # Determine OutputStyle from opcode and edge topology
        if isinstance(node.opcode, MemOp):
            # SM instructions: mode depends on MemOp semantics
            # WRITE, CLEAR, FREE, SET_PAGE, WRITE_IMM → no return value → SINK
            # READ, RD_INC, RD_DEC, CMP_SW, RAW_READ, EXT → return value → INHERIT
            if node.opcode in _SINK_MEMOPS:
                output = OutputStyle.SINK
                dest_count = 0
            else:
                output = OutputStyle.INHERIT
        elif node.opcode is RoutingOp.FREE_FRAME:
            output = OutputStyle.SINK
            dest_count = 0
        elif node.opcode is RoutingOp.EXTRACT_TAG:
            output = OutputStyle.INHERIT
        else:
            # Check if any outgoing edge has ctx_override=True
            # ctx_override=True on an edge means the source node is a
            # cross-function return — its left operand carries a packed
            # flit 1 (from EXTRACT_TAG) that determines the destination.
            # In the frame model, this maps to OutputStyle.CHANGE_TAG.
            has_ctx_override = any(e.ctx_override for e in out_edges)
            if has_ctx_override:
                output = OutputStyle.CHANGE_TAG
                dest_count = 1
            elif dest_count == 0:
                output = OutputStyle.SINK
            else:
                output = OutputStyle.INHERIT

        mode = (output, has_const, dest_count)
        updated[name] = replace(node, mode=mode)

    return updated


def _allocate_activation_slots(
    nodes_in_act: list[IRNode],
    matchable_offsets: int,
) -> tuple[dict[str, int], tuple[int, ...], tuple[int, ...], tuple[int, ...], int]:
    """Assign frame slots (frefs) to the nodes of one activation.

    Slots below ``matchable_offsets`` are reserved for match operands. Nodes
    are processed in name order and grouped as:
      1. const nodes: [const, dest1, dest2, ...] starting at fref
      2. no-const nodes with destinations: [dest1, dest2, ...] starting at fref
      3. no-const, no-dest (non-SINK) nodes: one slot each
      4. SINK nodes: one slot each

    Args:
        nodes_in_act: Nodes belonging to the activation
        matchable_offsets: Number of reserved match slots

    Returns:
        Tuple of (node_name -> fref, const slots, dest slots, sink/other slots,
        total slots used including the match region)
    """
    const_nodes = []
    no_const_dest_nodes = []
    no_const_no_dest_nodes = []
    sink_nodes = []

    for node in sorted(nodes_in_act, key=lambda n: n.name):
        if node.seed or node.mode is None:
            continue
        output_style, has_const, dest_count = node.mode
        if output_style == OutputStyle.SINK:
            sink_nodes.append(node)
        elif has_const:
            const_nodes.append(node)
        elif dest_count > 0:
            no_const_dest_nodes.append(node)
        else:
            # No const, no destinations - still need a slot for frame matching
            no_const_no_dest_nodes.append(node)

    node_frefs: dict[str, int] = {}
    const_slots: list[int] = []
    dest_slots: list[int] = []
    sink_slots: list[int] = []
    slot_counter = matchable_offsets  # Start after match region

    for node in const_nodes:
        dest_count = node.mode[2]
        # fref points at the const slot; destinations follow it
        node_frefs[node.name] = slot_counter
        const_slots.append(slot_counter)
        dest_slots.extend(range(slot_counter + 1, slot_counter + 1 + dest_count))
        slot_counter += 1 + dest_count

    for node in no_const_dest_nodes:
        dest_count = node.mode[2]
        # fref points at the first destination slot
        node_frefs[node.name] = slot_counter
        dest_slots.extend(range(slot_counter, slot_counter + dest_count))
        slot_counter += dest_count

    # No-const/no-dest nodes still need an fref slot (result writeback);
    # sink nodes come last. Both use a single slot each.
    for node in no_const_no_dest_nodes + sink_nodes:
        node_frefs[node.name] = slot_counter
        sink_slots.append(slot_counter)
        slot_counter += 1

    return node_frefs, tuple(const_slots), tuple(dest_slots), tuple(sink_slots), slot_counter


def _compute_frame_layouts(
    nodes_on_pe: dict[str, IRNode],
    edges_by_source: dict[str, list[IREdge]],
    edges_by_dest: dict[str, list[IREdge]],
    all_nodes: dict[str, IRNode],
    frame_slots: int,
    matchable_offsets: int,
    pe_id: int,
) -> tuple[dict[str, IRNode], list[AssemblyError]]:
    """Compute frame slot layouts per activation and assign frefs.

    Slot assignment order:
      0 to matchable_offsets-1: match operands (reserved)
      then: per const node, its const slot followed by its destination slots
      then: destination slots of no-const nodes
      then: single slots for no-const/no-dest nodes and sinks

    The overflow check and the resulting FrameLayout are derived from the
    actual fref allocation, so a node's fref is never outside the frame.
    Constants are allocated per node (not deduplicated by value).

    Args:
        nodes_on_pe: Dictionary of nodes on this PE
        edges_by_source: Edges indexed by source node name (currently unused)
        edges_by_dest: Edges indexed by destination node name (currently unused)
        all_nodes: All nodes in graph (currently unused)
        frame_slots: Total slots per frame
        matchable_offsets: Number of hardware match slots (const/dest slots start here)
        pe_id: The PE ID (for error messages)

    Returns:
        Tuple of (updated_nodes dict, errors list). The list may contain
        WARNING-severity entries alongside real errors.
    """
    from asm.ir import FrameLayout, FrameSlotMap

    errors = []
    updated = {}

    # Group nodes by activation ID
    nodes_by_act_id = defaultdict(list)
    for node in nodes_on_pe.values():
        if not node.seed and node.act_id is not None:
            nodes_by_act_id[node.act_id].append(node)

    act_id_to_layout = {}      # act_id -> FrameLayout
    act_id_to_node_frefs = {}  # act_id -> {node_name -> fref}

    for act_id, nodes_in_act in nodes_by_act_id.items():
        # Count dyadic instructions (match operands)
        dyadic_count = sum(1 for n in nodes_in_act if is_dyadic(n.opcode, n.const))

        # Warn if dyadic_count exceeds hardware match slots
        if dyadic_count > matchable_offsets:
            warning = AssemblyError(
                loc=SourceLoc(0, 0),
                category=ErrorCategory.FRAME,
                severity=ErrorSeverity.WARNING,
                message=(
                    f"PE{pe_id} activation {act_id}: {dyadic_count} dyadic instructions "
                    f"but only {matchable_offsets} hardware match slots. "
                    f"Liveness-based slot sharing required (see AC5.4 note)."
                ),
            )
            errors.append(warning)

        node_frefs, const_slots, dest_slots, sink_slots, total_slots = (
            _allocate_activation_slots(nodes_in_act, matchable_offsets)
        )

        if total_slots > frame_slots:
            error_msg = (
                f"Frame slot overflow on PE{pe_id} activation {act_id}: "
                f"{total_slots} slots needed, {frame_slots} available.\n"
                f" Match region (reserved): {matchable_offsets}\n"
                f" Constants: {len(const_slots)}\n"
                f" Destinations: {len(dest_slots)}\n"
                f" Sinks/SM params: {len(sink_slots)}"
            )
            error = AssemblyError(
                loc=SourceLoc(0, 0),
                category=ErrorCategory.FRAME,
                message=error_msg,
            )
            errors.append(error)
            continue

        slot_map = FrameSlotMap(
            match_slots=tuple(range(matchable_offsets)),
            const_slots=const_slots,
            dest_slots=dest_slots,
            sink_slots=sink_slots,
        )
        act_id_to_layout[act_id] = FrameLayout(slot_map=slot_map, total_slots=total_slots)
        act_id_to_node_frefs[act_id] = node_frefs

    # Attach frame layouts and frefs to nodes
    for name, node in nodes_on_pe.items():
        if node.seed:
            updated[name] = node
            continue

        if node.act_id is not None and node.act_id in act_id_to_layout:
            layout = act_id_to_layout[node.act_id]
            fref = act_id_to_node_frefs.get(node.act_id, {}).get(name)
            if fref is not None:
                updated[name] = replace(node, frame_layout=layout, fref=fref)
            else:
                updated[name] = replace(node, frame_layout=layout)
        else:
            updated[name] = node

    return updated, errors


def _assign_act_ids(
    nodes_on_pe: list[IRNode],
    all_nodes: dict[str, IRNode],
    frame_count: int,
    pe_id: int,
    call_sites: list[CallSite] | None = None,
) -> tuple[dict[str, IRNode], list[AssemblyError]]:
    """Assign activation IDs (0 to frame_count-1) per function scope per PE.

    Implements per-call-site activation allocation:
    - Root scope gets act_id=0 when it has nodes on this PE
    - Functions without call sites get one act_id by the existing scope rule
    - Each call site with a trampoline or free_frame node on this PE gets a
      fresh act_id on this PE

    Args:
        nodes_on_pe: List of nodes on this PE
        all_nodes: All nodes (for PE lookup of call-site nodes)
        frame_count: Maximum activation IDs for this PE
        pe_id: The PE ID (for error messages)
        call_sites: Optional list of CallSite objects for per-call-site allocation

    Returns:
        Tuple of (updated_nodes dict, errors list). On exhaustion the dict is
        empty; otherwise the list may contain a utilisation WARNING.
        Seed nodes are not included in the returned dict.
    """
    if call_sites is None:
        call_sites = []

    errors = []
    updated_nodes = {}

    # Trampoline and free_frame nodes get their call site's act_id
    callsite_for_node = {}  # node_name -> CallSite
    for call_site in call_sites:
        for tramp_node in call_site.trampoline_nodes:
            callsite_for_node[tramp_node] = call_site
        for free_node in call_site.free_frame_nodes:
            callsite_for_node[free_node] = call_site

    # function scope -> call site (for function body nodes); last one wins
    func_scope_to_callsite = {}
    for call_site in call_sites:
        func_scope_to_callsite[call_site.func_name] = call_site

    next_act_id = 0
    act_breakdown = {}  # For overflow error reporting
    scope_to_act_id = {}
    root_act_id = 0  # Default root scope activation

    has_root_scope_nodes = any(
        not node.seed and _extract_function_scope(node.name) == ""
        for node in nodes_on_pe
    )

    # Root scope always gets act_id=0 if it has nodes on this PE
    if has_root_scope_nodes:
        scope_to_act_id[""] = root_act_id
        act_breakdown["root"] = 1
        next_act_id = 1

    for node in nodes_on_pe:
        if node.seed:
            continue

        scope = _extract_function_scope(node.name)

        # Only process function scopes not already assigned and without call sites
        if scope and scope not in scope_to_act_id and scope not in func_scope_to_callsite:
            if next_act_id >= frame_count:
                error_msg = _build_activation_overflow_message(
                    pe_id, frame_count, next_act_id, act_breakdown
                )
                error = AssemblyError(
                    loc=SourceLoc(0, 0),
                    category=ErrorCategory.FRAME,
                    message=error_msg,
                )
                errors.append(error)
                return {}, errors

            scope_to_act_id[scope] = next_act_id
            act_breakdown[scope] = 1
            next_act_id += 1

    # Allocate one activation ID per call site that has nodes on this PE
    call_site_to_act_id_on_pe = {}
    for call_site in call_sites:
        has_node_on_pe = any(
            n in all_nodes and all_nodes[n].pe == pe_id
            for n in (*call_site.trampoline_nodes, *call_site.free_frame_nodes)
        )

        if has_node_on_pe:
            if next_act_id >= frame_count:
                error_msg = _build_activation_overflow_message(
                    pe_id, frame_count, next_act_id, act_breakdown
                )
                error = AssemblyError(
                    loc=SourceLoc(0, 0),
                    category=ErrorCategory.FRAME,
                    message=error_msg,
                )
                errors.append(error)
                return {}, errors

            call_site_to_act_id_on_pe[call_site] = next_act_id
            act_breakdown[f"{call_site.func_name} call site #{call_site.call_id}"] = 1
            next_act_id += 1

    # Budget warning at >= 75% utilisation
    if frame_count > 0:
        utilisation = next_act_id / frame_count
        if utilisation >= 0.75:
            percent = int(utilisation * 100)
            warning = AssemblyError(
                loc=SourceLoc(0, 0),
                category=ErrorCategory.FRAME,
                severity=ErrorSeverity.WARNING,
                message=f"PE{pe_id}: {next_act_id}/{frame_count} activation IDs used ({percent}%)",
            )
            errors.append(warning)

    for node in nodes_on_pe:
        if node.seed:
            continue

        act_id_value = None

        # Trampoline / free_frame node of a call site
        if node.name in callsite_for_node:
            call_site = callsite_for_node[node.name]
            act_id_value = call_site_to_act_id_on_pe.get(call_site)

        if act_id_value is None:
            scope = _extract_function_scope(node.name)
            if scope in func_scope_to_callsite:
                # Function body node — gets the call site's act_id
                cs = func_scope_to_callsite[scope]
                act_id_value = call_site_to_act_id_on_pe.get(cs)

            if act_id_value is None:
                # NOTE(review): a function that has call sites, but none of whose
                # trampoline/free_frame nodes are on this PE, lands here and gets
                # root_act_id (0), the same ID as the root scope — confirm intended.
                act_id_value = scope_to_act_id.get(scope, root_act_id)

        updated_nodes[node.name] = replace(node, act_id=act_id_value)

    return updated_nodes, errors


def _build_activation_overflow_message(pe_id: int, frame_count: int, used: int, breakdown: dict) -> str:
    """Build a detailed activation ID overflow error message.

    Args:
        pe_id: The PE ID
        frame_count: Total available activation IDs
        used: Number of IDs needed
        breakdown: Dictionary mapping scope/call site to ID count

    Returns:
        Formatted error message
    """
    lines = [
        f"Activation ID exhaustion on PE{pe_id}: {used} IDs needed, {frame_count} available"
    ]
    for scope_name, count in breakdown.items():
        if scope_name == "root":
            lines.append(f" Root scope: {count} ID")
        else:
            lines.append(f" {scope_name}: {count} ID")
    lines.append("Consider inlining frequently-called functions to reduce frame pressure.")
    return "\n".join(lines)


def _assign_sm_ids(
    all_nodes: dict[str, IRNode],
    sm_count: int,
) -> tuple[dict[str, IRNode], list[AssemblyError]]:
    """Assign SM IDs to MemOp instruction nodes that lack one.

    For single-SM systems, defaults to sm_id=0.
    For multi-SM systems where the SM target is ambiguous, reports an error.

    Args:
        all_nodes: Dictionary of all nodes
        sm_count: Number of SMs in the system

    Returns:
        Tuple of (updated nodes dict containing only changed nodes, list of errors)
    """
    errors: list[AssemblyError] = []
    updated: dict[str, IRNode] = {}

    for name, node in all_nodes.items():
        if isinstance(node.opcode, MemOp) and node.sm_id is None:
            if sm_count == 0:
                errors.append(AssemblyError(
                    loc=node.loc,
                    category=ErrorCategory.RESOURCE,
                    message=f"Node '{name}' uses memory operation '{node.opcode.name}' but system has no SMs.",
                ))
            elif sm_count == 1:
                updated[name] = replace(node, sm_id=0)
            else:
                errors.append(AssemblyError(
                    loc=node.loc,
                    category=ErrorCategory.RESOURCE,
                    message=(
                        f"Node '{name}' uses memory operation '{node.opcode.name}' but no SM target specified "
                        f"and system has {sm_count} SMs. Cannot infer target."
                    ),
                ))

    return updated, errors


_COMMUTATIVE_OPS: frozenset = frozenset({
    ArithOp.ADD,
    LogicOp.AND,
    LogicOp.OR,
    LogicOp.XOR,
    LogicOp.EQ,
})


def _is_commutative(opcode) -> bool:
    """Return True if *opcode* is one of _COMMUTATIVE_OPS.

    Uses identity rather than set membership: the opcode enums are IntEnums
    with overlapping values, so ``opcode in _COMMUTATIVE_OPS`` could match an
    opcode of a different enum type that merely shares a numeric value.
    """
    return any(opcode is op for op in _COMMUTATIVE_OPS)


def _build_edge_index_by_dest(edges: list[IREdge]) -> dict[str, list[IREdge]]:
    """Build index of edges by destination node name."""
    index: dict[str, list[IREdge]] = defaultdict(list)
    for edge in edges:
        index[edge.dest].append(edge)
    return index


def _validate_noncommutative_const(
    all_nodes: dict[str, IRNode],
    edges_by_dest: dict[str, list[IREdge]],
) -> list[AssemblyError]:
    """Warn when non-commutative dyadic ops with IRAM const lack explicit ports.

    When a dyadic instruction has a baked-in constant (e.g., `sub 3`), the
    incoming token goes to whichever port the edge specifies (default: L).
    For commutative ops this is irrelevant, but for non-commutative ops
    (sub, lt, gt, etc.) the port determines operand order. If the user
    didn't specify the port explicitly, emit a warning so they know the
    implicit default is in effect.

    Args:
        all_nodes: All nodes in graph
        edges_by_dest: Edges indexed by destination node name

    Returns:
        List of warnings for implicit port assignments
    """
    warnings: list[AssemblyError] = []
    for name, node in all_nodes.items():
        if node.seed:
            continue
        if node.const is None:
            continue
        if is_monadic(node.opcode, node.const):
            continue
        if _is_commutative(node.opcode):
            continue
        for edge in edges_by_dest.get(name, []):
            if not edge.port_explicit:
                warnings.append(AssemblyError(
                    loc=edge.loc,
                    category=ErrorCategory.PORT,
                    severity=ErrorSeverity.WARNING,
                    message=(
                        f"Non-commutative op '{node.opcode.name}' on node '{name}' "
                        f"has an IRAM constant but incoming edge from '{edge.source}' "
                        f"has no explicit port (:L or :R). Defaulting to :L — "
                        f"the token will be the left operand and the constant the right."
                    ),
                ))
    return warnings


def _build_edge_index(edges: list[IREdge]) -> dict[str, list[IREdge]]:
    """Build index of edges by source node name.

    Args:
        edges: List of all edges

    Returns:
        Dictionary mapping source name to list of edges from that source
    """
    index: dict[str, list[IREdge]] = defaultdict(list)
    for edge in edges:
        index[edge.source].append(edge)
    return index


def _determine_token_kind(dest_node: IRNode) -> TokenKind:
    """Determine the token kind for a destination node.

    Token kind is determined by the destination node's opcode:
    - If dyadic: TokenKind.DYADIC
    - If monadic: TokenKind.MONADIC
    - (TokenKind.INLINE is reserved for future use)

    Args:
        dest_node: The destination IRNode

    Returns:
        TokenKind enum value
    """
    if is_dyadic(dest_node.opcode, dest_node.const):
        return TokenKind.DYADIC
    return TokenKind.MONADIC


def _resolve_edge(
    edge: IREdge,
    all_nodes: dict[str, IRNode],
) -> tuple[ResolvedDest | None, AssemblyError | None]:
    """Resolve one edge to a ResolvedDest, or return the error explaining why not.

    Args:
        edge: The outgoing edge to resolve
        all_nodes: Node lookup table (must carry iram_offset, act_id, pe)

    Returns:
        (resolved, None) on success, (None, error) on failure
    """
    dest_node = all_nodes.get(edge.dest)
    if dest_node is None:
        return None, AssemblyError(
            loc=edge.loc,
            category=ErrorCategory.NAME,
            message=f"Edge destination '{edge.dest}' not found.",
        )

    if dest_node.iram_offset is None or dest_node.act_id is None or dest_node.pe is None:
        return None, AssemblyError(
            loc=edge.loc,
            category=ErrorCategory.RESOURCE,
            message=f"Destination '{edge.dest}' lacks iram_offset, act_id, or PE assignment.",
        )

    frame_dest = FrameDest(
        target_pe=dest_node.pe,
        offset=dest_node.iram_offset,
        act_id=dest_node.act_id,
        port=edge.port,
        token_kind=_determine_token_kind(dest_node),
    )
    return ResolvedDest(name=edge.dest, addr=None, frame_dest=frame_dest), None


def _resolve_destinations(
    nodes_on_pe: dict[str, IRNode],
    all_nodes: dict[str, IRNode],
    edges_by_source: dict[str, list[IREdge]],
) -> tuple[dict[str, IRNode], list[AssemblyError]]:
    """Resolve symbolic destinations to FrameDest objects.

    Edge-to-destination mapping rules:
    - single edge -> dest_l (regardless of its source_port)
    - two edges, both with explicit source_port -> sorted by source_port;
      first -> dest_l, second -> dest_r
    - two edges, no explicit source_port -> in edge order; first -> dest_l,
      second -> dest_r

    NOTE(review): a single edge with an explicit :R source port still lands in
    dest_l — confirm this matches codegen's expectations.

    Each destination is resolved to a FrameDest object containing target PE,
    IRAM offset, activation ID, port, and token kind.

    Args:
        nodes_on_pe: Nodes on this PE (with iram_offset, act_id set)
        all_nodes: All nodes in graph
        edges_by_source: Edges indexed by source node name

    Returns:
        Tuple of (updated_nodes dict, errors list). Nodes rejected by
        validation, or whose single edge fails to resolve, are omitted.
    """
    errors = []
    updated_nodes = {}

    for node_name, node in nodes_on_pe.items():
        updated_node = node
        source_edges = edges_by_source.get(node_name, [])

        if len(source_edges) > 2:
            errors.append(AssemblyError(
                loc=node.loc,
                category=ErrorCategory.PORT,
                message=f"Node '{node_name}' has {len(source_edges)} outgoing edges, but maximum is 2.",
            ))
            continue

        explicit_ports = [e.source_port for e in source_edges if e.source_port is not None]
        if len(explicit_ports) != len(set(explicit_ports)):
            errors.append(AssemblyError(
                loc=node.loc,
                category=ErrorCategory.PORT,
                message=f"Node '{node_name}' has conflicting source_port qualifiers.",
            ))
            continue

        if 0 < len(explicit_ports) < len(source_edges):
            errors.append(AssemblyError(
                loc=node.loc,
                category=ErrorCategory.PORT,
                message=f"Node '{node_name}' has mixed explicit and implicit source ports.",
            ))
            continue

        if len(source_edges) == 1:
            resolved, error = _resolve_edge(source_edges[0], all_nodes)
            if error is not None:
                errors.append(error)
                continue
            updated_node = replace(updated_node, dest_l=resolved)
        elif len(source_edges) == 2:
            edges = source_edges
            if explicit_ports:
                # All explicit: sort by port so L precedes R
                edges = sorted(edges, key=lambda e: e.source_port)
            for idx, edge in enumerate(edges):
                resolved, error = _resolve_edge(edge, all_nodes)
                if error is not None:
                    errors.append(error)
                    continue
                if idx == 0:
                    updated_node = replace(updated_node, dest_l=resolved)
                else:
                    updated_node = replace(updated_node, dest_r=resolved)

        updated_nodes[node_name] = updated_node

    return updated_nodes, errors


def allocate(graph: IRGraph) -> IRGraph:
    """Allocate resources: IRAM offsets, activation IDs, frame layouts, resolve destinations.

    Performs the following operations per PE:
    1. Assign provisional IRAM offsets (dyadic first, then monadic)
    2. Assign activation IDs (0 to frame_count-1)
    3. Compute modes (OutputStyle, has_const, dest_count) from edge topology
    4. Compute frame layouts (assigns fref, frame_layout) per activation
    5. Deduplicate IRAM entries by instruction template
    6. Resolve destinations to FrameDest objects with PE, offset, act_id, port, token_kind

    A PE is skipped after a sub-pass only when that sub-pass reported a real
    error; warning-only diagnostics are recorded but do not stop allocation.

    Args:
        graph: The IRGraph to allocate

    Returns:
        New IRGraph with all nodes updated and allocation errors appended
    """
    errors = list(graph.errors)
    system = graph.system

    if system is None:
        # Should not happen if place() was called first, but handle gracefully
        system_errors = [
            AssemblyError(
                loc=SourceLoc(0, 0),
                category=ErrorCategory.RESOURCE,
                message="Cannot allocate without SystemConfig. Run place() first.",
            )
        ]
        return replace(graph, errors=errors + system_errors)

    all_nodes, all_edges = collect_all_nodes_and_edges(graph)
    edges_by_source = _build_edge_index(all_edges)
    edges_by_dest = _build_edge_index_by_dest(all_edges)

    # Validate non-commutative ops with IRAM constants have explicit ports
    errors.extend(_validate_noncommutative_const(all_nodes, edges_by_dest))

    # Assign SM IDs to MemOp nodes that lack explicit SM targets
    sm_updated, sm_errors = _assign_sm_ids(all_nodes, system.sm_count)
    errors.extend(sm_errors)
    all_nodes.update(sm_updated)

    nodes_by_pe = _group_nodes_by_pe(all_nodes)

    # First pass: IRAM offsets, activation IDs, modes, frame layouts, dedup
    intermediate_nodes = {}
    for pe_id, nodes_on_pe in sorted(nodes_by_pe.items()):
        # 1. Assign provisional IRAM offsets
        iram_updated, iram_errors = _assign_iram_offsets(
            nodes_on_pe, all_nodes, system.iram_capacity, pe_id,
        )
        errors.extend(iram_errors)
        if _has_blocking_errors(iram_errors):
            continue

        # 2. Assign activation IDs
        act_updated, act_errors = _assign_act_ids(
            list(iram_updated.values()), all_nodes, system.frame_count, pe_id,
            call_sites=graph.call_sites,
        )
        errors.extend(act_errors)
        if _has_blocking_errors(act_errors):
            continue

        # 3. Compute modes (OutputStyle, has_const, dest_count)
        mode_updated = _compute_modes(act_updated, edges_by_source)

        # 4. Compute frame layouts (assigns fref, frame_layout)
        frame_updated, frame_errors = _compute_frame_layouts(
            mode_updated, edges_by_source, edges_by_dest, all_nodes,
            system.frame_slots, system.matchable_offsets, pe_id,
        )
        errors.extend(frame_errors)
        if _has_blocking_errors(frame_errors):
            continue

        # 5. Deduplicate IRAM entries
        intermediate_nodes.update(_deduplicate_iram(frame_updated, pe_id))

    # Second pass: resolve destinations against the allocated nodes.
    # Group once by PE instead of re-scanning every node for each PE.
    intermediate_by_pe: dict[int, dict[str, IRNode]] = defaultdict(dict)
    for name, node in intermediate_nodes.items():
        intermediate_by_pe[node.pe][name] = node

    updated_all_nodes = {}
    for pe_id in sorted(intermediate_by_pe):
        resolved_updated, resolve_errors = _resolve_destinations(
            intermediate_by_pe[pe_id],
            intermediate_nodes,  # Use intermediate nodes for lookups, not original
            edges_by_source,
        )
        errors.extend(resolve_errors)
        updated_all_nodes.update(resolved_updated)

    # Nodes rejected during destination resolution keep their offsets/slots
    # from the first pass. Nodes on PEs that failed the first pass are absent.
    final_nodes = dict(intermediate_nodes)
    final_nodes.update(updated_all_nodes)

    # Reconstruct the graph with updated nodes, preserving the region tree
    result_graph = update_graph_nodes(graph, final_nodes)
    return replace(result_graph, errors=errors)