"""Placement validation and auto-placement pass for the OR1 assembler.

Validates user-provided PE placements and performs auto-placement for
unplaced nodes. Uses a greedy bin-packing algorithm with a locality
heuristic to assign unplaced nodes to PEs while respecting IRAM capacity
and frame (context slot) limits.

Reference: Phase 4 and Phase 7 design docs.
"""

from __future__ import annotations

from collections import Counter, defaultdict
from dataclasses import replace

from asm.errors import AssemblyError, ErrorCategory, ErrorSeverity
from asm.ir import (
    IRGraph,
    IRNode,
    IRRegion,
    RegionKind,
    SystemConfig,
    SourceLoc,
    collect_all_nodes,
    update_graph_nodes,
    DEFAULT_IRAM_CAPACITY,
    DEFAULT_FRAME_COUNT,
)
from asm.opcodes import is_dyadic


def _infer_system_config(graph: IRGraph) -> SystemConfig:
    """Infer a SystemConfig from node placements if none is provided.

    ``pe_count`` is the largest PE ID referenced by any node (top-level or
    inside any region, at any nesting depth) plus one, with a minimum of 1.
    Capacity values come from the module-level ``DEFAULT_IRAM_CAPACITY`` and
    ``DEFAULT_FRAME_COUNT`` constants; ``sm_count`` is fixed at 1.

    Args:
        graph: The IRGraph (may have ``system=None``).

    Returns:
        SystemConfig with inferred pe_count and default capacity values.
    """
    max_pe_id = -1

    def _find_max_pe(nodes: dict[str, IRNode]) -> None:
        nonlocal max_pe_id
        for node in nodes.values():
            if node.pe is not None and node.pe > max_pe_id:
                max_pe_id = node.pe

    def _check_regions(regions: list[IRRegion]) -> None:
        # Regions may nest, so walk each region's body recursively.
        for region in regions:
            _find_max_pe(region.body.nodes)
            _check_regions(region.body.regions)

    _find_max_pe(graph.nodes)
    _check_regions(graph.regions)

    pe_count = max(1, max_pe_id + 1)  # At least 1 PE

    return SystemConfig(
        pe_count=pe_count,
        sm_count=1,  # Default to 1 SM
        iram_capacity=DEFAULT_IRAM_CAPACITY,
        frame_count=DEFAULT_FRAME_COUNT,
        loc=SourceLoc(0, 0),
    )


def _find_node_scope(graph: IRGraph, node_name: str) -> str | None:
    """Find the function scope of a node.

    Returns the tag of the FUNCTION region whose body directly contains the
    node, or None if the node is top-level (in ``graph.nodes`` or in the
    body of a non-function region such as a LOCATION region).

    NOTE(review): a node inside a non-function region that is itself nested
    in a FUNCTION region resolves to None rather than the function's tag,
    because a None result from the nested search is treated as "not found".
    Confirm this is intended.

    Args:
        graph: The IRGraph.
        node_name: Name of the node to find.

    Returns:
        Function region tag if the node is in a function, None if top-level.
    """
    if node_name in graph.nodes:
        return None

    def _search_regions(regions: list[IRRegion]) -> str | None:
        for region in regions:
            if region.kind == RegionKind.FUNCTION:
                if node_name in region.body.nodes:
                    return region.tag
                # Nested regions inside a function (not expected with the
                # current design, but searched for completeness).
                result = _search_regions(region.body.regions)
                if result:
                    return result
            else:
                # Nodes in non-function regions (e.g. LOCATION) are
                # conceptually top-level.
                if node_name in region.body.nodes:
                    return None
                result = _search_regions(region.body.regions)
                if result:
                    return result
        return None

    return _search_regions(graph.regions)


def _build_adjacency(graph: IRGraph, all_nodes: dict[str, IRNode]) -> dict[str, set[str]]:
    """Build an undirected adjacency map from the graph's edges.

    Edges are collected from the top-level graph and from every region body
    at any nesting depth. Each edge makes its source and dest neighbours of
    each other.

    Args:
        graph: The IRGraph.
        all_nodes: Dictionary of all nodes. Currently unused; kept for
            interface compatibility.

    Returns:
        A ``defaultdict(set)`` mapping node names to the set of connected
        node names.
    """
    adjacency: dict[str, set[str]] = defaultdict(set)

    def _process_edges(edges: list) -> None:
        for edge in edges:
            adjacency[edge.source].add(edge.dest)
            adjacency[edge.dest].add(edge.source)

    def _process_regions_edges(regions: list[IRRegion]) -> None:
        for region in regions:
            _process_edges(region.body.edges)
            _process_regions_edges(region.body.regions)

    _process_edges(graph.edges)
    _process_regions_edges(graph.regions)

    return adjacency


def _count_iram_cost(node: IRNode) -> int:
    """Count IRAM slots used by a node.

    In the frame model, all instructions use 1 IRAM slot. Matching is
    handled by frame SRAM, not IRAM entries.

    Args:
        node: The IRNode.

    Returns:
        Number of IRAM slots used (always 1).
    """
    return 1


def _auto_place_nodes(
    graph: IRGraph,
    system: SystemConfig,
    all_nodes: dict[str, IRNode],
    adjacency: dict[str, set[str]],
) -> tuple[dict[str, IRNode], list[AssemblyError]]:
    """Auto-place unplaced nodes using greedy bin-packing with locality heuristic.

    Algorithm:
    1. Seed per-PE IRAM, frame and dyadic-offset usage from nodes that are
       already placed. Frames are counted once per (PE, function scope).
    2. For each unplaced node, in ``all_nodes`` insertion order:
       a. Find the PEs of connected neighbours (including nodes placed
          earlier in this pass).
       b. Prefer the PE with the most neighbours (locality).
       c. Tie-break by remaining IRAM capacity, then by lower PE ID.
       d. Take the first candidate with room for both the IRAM slot and,
          if the node's scope is new to that PE, one more frame.
       e. If no PE has room, record an error and leave the node unplaced.
    3. Emit a WARNING for every PE whose dyadic-node count exceeds
       ``system.matchable_offsets``.

    Expects every placed node's ``pe`` to be in ``range(system.pe_count)``;
    ``place`` filters out other nodes before calling this.

    Args:
        graph: The IRGraph (used to resolve function scopes).
        system: SystemConfig with pe_count, iram_capacity, frame_count and
            matchable_offsets.
        all_nodes: Dictionary of all nodes.
        adjacency: Adjacency map from ``_build_adjacency``.

    Returns:
        Tuple of (updated nodes dict, list of placement errors/warnings).
    """
    errors: list[AssemblyError] = []

    # Per-PE resource usage.
    iram_used = [0] * system.pe_count
    frames_used = [0] * system.pe_count
    # Count of dyadic nodes per PE, for the matchable-offset warning.
    dyadic_offsets_per_pe = [0] * system.pe_count

    # Copy nodes so placement can be updated as we go.
    updated_nodes = dict(all_nodes)

    # Function scopes already holding a frame on each PE; avoids counting
    # one frame per node instead of one per scope.
    act_scopes_per_pe: dict[int, set[str | None]] = {
        pe_id: set() for pe_id in range(system.pe_count)
    }

    for node_name, node in updated_nodes.items():
        if node.pe is None:
            continue
        iram_used[node.pe] += _count_iram_cost(node)
        scope = _find_node_scope(graph, node_name)
        if scope not in act_scopes_per_pe[node.pe]:
            act_scopes_per_pe[node.pe].add(scope)
            frames_used[node.pe] += 1
        if is_dyadic(node.opcode, node.const):
            dyadic_offsets_per_pe[node.pe] += 1

    act_scopes_updated: dict[int, set[str | None]] = {
        pe_id: set(scopes) for pe_id, scopes in act_scopes_per_pe.items()
    }

    for node_name, node in all_nodes.items():
        if node.pe is not None:
            continue  # Already placed

        # Neighbour PEs are looked up in updated_nodes so that nodes placed
        # earlier in this loop attract their neighbours.
        neighbour_pes: list[int] = []
        for neighbour_name in adjacency.get(node_name, set()):
            neighbour = updated_nodes.get(neighbour_name)
            if neighbour and neighbour.pe is not None:
                neighbour_pes.append(neighbour.pe)
        pe_counts: dict[int, int] = Counter(neighbour_pes)

        # Most neighbours first, then most remaining IRAM; the sort is
        # stable, so remaining ties keep ascending PE order.
        candidate_pes = list(range(system.pe_count))
        candidate_pes.sort(
            key=lambda pe: (
                -pe_counts.get(pe, 0),
                -(system.iram_capacity - iram_used[pe]),
            ),
        )

        iram_cost = _count_iram_cost(node)
        node_scope = _find_node_scope(graph, node_name)

        for pe in candidate_pes:
            # A scope not yet present on this PE needs one more frame.
            scope_is_new = node_scope not in act_scopes_updated[pe]
            frames_needed = 1 if scope_is_new else 0
            if (
                iram_used[pe] + iram_cost <= system.iram_capacity
                and frames_used[pe] + frames_needed <= system.frame_count
            ):
                updated_nodes[node_name] = replace(node, pe=pe)
                iram_used[pe] += iram_cost
                if scope_is_new:
                    act_scopes_updated[pe].add(node_scope)
                    frames_used[pe] += 1
                if is_dyadic(node.opcode, node.const):
                    dyadic_offsets_per_pe[pe] += 1
                break
        else:
            # No PE has room: report with a per-PE utilization breakdown.
            errors.append(
                _format_placement_overflow_error(node, system, iram_used, frames_used)
            )

    # Warn when a PE uses more matchable offsets than the system allows.
    for pe_id in range(system.pe_count):
        if dyadic_offsets_per_pe[pe_id] <= system.matchable_offsets:
            continue
        # Anchor the warning at the first dyadic node on this PE.
        warning_node = next(
            (
                n
                for n in updated_nodes.values()
                if n.pe == pe_id and is_dyadic(n.opcode, n.const)
            ),
            None,
        )
        if warning_node is not None:
            errors.append(AssemblyError(
                loc=warning_node.loc,
                category=ErrorCategory.FRAME,
                message=f"PE {pe_id} uses {dyadic_offsets_per_pe[pe_id]} matchable offsets "
                        f"(limit: {system.matchable_offsets})",
                severity=ErrorSeverity.WARNING,
            ))

    return updated_nodes, errors


def _format_placement_overflow_error(
    node: IRNode,
    system: SystemConfig,
    iram_used: list[int],
    frames_used: list[int],
) -> AssemblyError:
    """Format a placement overflow error with per-PE utilization breakdown.

    Args:
        node: The node that couldn't be placed.
        system: SystemConfig.
        iram_used: List of IRAM slots used per PE.
        frames_used: List of frames used per PE.

    Returns:
        AssemblyError (category PLACEMENT) with a per-PE breakdown.
    """
    breakdown = "\n".join(
        f"  PE{pe_id}: {iram_used[pe_id]}/{system.iram_capacity} IRAM slots, "
        f"{frames_used[pe_id]}/{system.frame_count} frames"
        for pe_id in range(system.pe_count)
    )
    message = f"Cannot place node '{node.name}': all PEs are full.\n{breakdown}"

    return AssemblyError(
        loc=node.loc,
        category=ErrorCategory.PLACEMENT,
        message=message,
        suggestions=[],
    )


def place(graph: IRGraph) -> IRGraph:
    """Placement pass: validate explicit placements and auto-place unplaced nodes.

    Process:
    1. Use ``graph.system`` or infer a SystemConfig from node placements.
    2. Validate explicitly placed nodes: any PE ID outside
       ``0 <= pe < pe_count`` (including negative IDs) is reported as a
       placement error, and that node is excluded from the set passed on.
    3. If any nodes are unplaced, auto-place them using greedy bin-packing
       with a locality heuristic.
    4. Write the (possibly updated) nodes back into the graph.

    Args:
        graph: The IRGraph to place.

    Returns:
        New IRGraph with the resolved ``system`` and placement errors
        appended to any existing ``graph.errors``.
    """
    system = graph.system if graph.system is not None else _infer_system_config(graph)
    errors = list(graph.errors)

    all_nodes = collect_all_nodes(graph)

    # Reject explicit placements outside the valid PE range. Negative IDs
    # must be rejected too: they would otherwise index the per-PE usage
    # lists from the end and raise KeyError on the per-PE scope dict in
    # _auto_place_nodes.
    valid_nodes: dict[str, IRNode] = {}
    for node_name, node in all_nodes.items():
        if node.pe is not None and not 0 <= node.pe < system.pe_count:
            errors.append(AssemblyError(
                loc=node.loc,
                category=ErrorCategory.PLACEMENT,
                message=f"Node '{node_name}' placed on PE{node.pe} but system only has {system.pe_count} PEs (0-{system.pe_count - 1}).",
                suggestions=[],
            ))
        else:
            valid_nodes[node_name] = node

    all_nodes = valid_nodes

    if any(node.pe is None for node in all_nodes.values()):
        adjacency = _build_adjacency(graph, all_nodes)
        all_nodes, placement_errors = _auto_place_nodes(graph, system, all_nodes, adjacency)
        errors.extend(placement_errors)

    # Write placements back so nodes inside function scopes also receive
    # their updated PE assignments.
    result_graph = update_graph_nodes(graph, all_nodes)

    return replace(result_graph, system=system, errors=errors)