from __future__ import annotations import logging from dataclasses import replace from typing import TYPE_CHECKING, Optional import simpy from cm_inst import MemOp from emu.alu import UINT16_MASK from emu.events import ( EventCallback, TokenReceived, CellWritten, DeferredRead as DeferredReadEvent, DeferredSatisfied, ResultSent, ) from emu.types import DeferredRead from encoding import flit_count, unpack_token from sm_mod import Presence, SMCell from tokens import CMToken, MonadToken, SMToken, Token if TYPE_CHECKING: from emu.network import System logger = logging.getLogger(__name__) ATOMIC_CELL_LIMIT = 256 class StructureMemory: def __init__( self, env: simpy.Environment, sm_id: int, cell_count: int = 512, fifo_capacity: int = 8, tier_boundary: int = 256, on_event: EventCallback | None = None, ): self.env = env self.sm_id = sm_id self.tier_boundary = tier_boundary self.cells: list[SMCell] = [ SMCell(Presence.EMPTY, None, None) for _ in range(cell_count) ] self.deferred_read: Optional[DeferredRead] = None self._deferred_satisfied: Optional[simpy.Event] = None self._deferred_cancelled: bool = False self.input_store: simpy.Store = simpy.Store(env, capacity=fifo_capacity) self.route_table: dict[int, simpy.Store] = {} self.t0_store: list[int] = [] self.system: Optional[System] = None self._on_event: EventCallback = on_event or (lambda _: None) self._component = f"sm:{sm_id}" self.process = env.process(self._run()) def _is_t0(self, addr: int) -> bool: """Return True if address falls in the T0 (shared raw storage) region.""" return addr >= self.tier_boundary def _run(self): while True: token = yield self.input_store.get() yield self.env.timeout(1) # dequeue cycle self._on_event(TokenReceived(time=self.env.now, component=self._component, token=token)) if not isinstance(token, SMToken): logger.warning( "SM%d: unexpected token type: %s", self.sm_id, type(token) ) continue addr = token.addr op = token.op if self._is_t0(addr): match op: case MemOp.READ: yield from self._handle_t0_read(addr, token) case MemOp.WRITE: self._handle_t0_write(addr, token) yield self.env.timeout(1) # write cycle case MemOp.EXEC: yield from self._handle_exec(addr) case _: logger.warning( "SM%d: I-structure op %s on T0 address %d", self.sm_id, op.name, addr, ) continue match op: case MemOp.READ: yield from self._handle_read(addr, token) case MemOp.WRITE: yield from self._handle_write(addr, token) case MemOp.CLEAR: self._handle_clear(addr) yield self.env.timeout(1) # process cycle case MemOp.RD_INC: yield from self._handle_atomic(addr, token, delta=1) case MemOp.RD_DEC: yield from self._handle_atomic(addr, token, delta=-1) case MemOp.CMP_SW: yield from self._handle_cas(addr, token) case MemOp.ALLOC: self._handle_alloc(addr) yield self.env.timeout(1) # process cycle case MemOp.FREE: self._handle_clear(addr) yield self.env.timeout(1) # process cycle case MemOp.EXEC: logger.warning( "SM%d: EXEC on T1 address %d (must be T0)", self.sm_id, addr, ) case MemOp.SET_PAGE | MemOp.WRITE_IMM | MemOp.RAW_READ | MemOp.EXT: raise NotImplementedError( f"SM{self.sm_id}: {op.name} not yet implemented" ) case _: logger.warning("SM%d: unknown op %s", self.sm_id, op) def _handle_read(self, addr: int, token: SMToken): cell = self.cells[addr] if cell.pres == Presence.FULL: yield self.env.timeout(1) # process cycle yield from self._send_result(token.ret, cell.data_l) return if self.deferred_read is not None: self.env.process(self._wait_and_retry_read(addr, token)) return self.deferred_read = DeferredRead(cell_addr=addr, return_route=token.ret) old_pres = cell.pres cell.pres = Presence.WAITING self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.WAITING)) self._on_event(DeferredReadEvent(time=self.env.now, component=self._component, addr=addr)) yield self.env.timeout(1) # process cycle (set WAITING) def _wait_and_retry_read(self, addr: int, token: SMToken): self._deferred_satisfied = self.env.event() yield self._deferred_satisfied self._deferred_satisfied = None if self._deferred_cancelled: self._deferred_cancelled = False return yield from self._handle_read(addr, token) def _handle_write(self, addr: int, token: SMToken): cell = self.cells[addr] if ( cell.pres == Presence.WAITING and self.deferred_read is not None and self.deferred_read.cell_addr == addr ): return_route = self.deferred_read.return_route self.deferred_read = None old_pres = cell.pres cell.pres = Presence.FULL cell.data_l = token.data self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.FULL)) self._on_event(DeferredSatisfied(time=self.env.now, component=self._component, addr=addr, data=token.data)) if self._deferred_satisfied is not None: self._deferred_satisfied.succeed() yield self.env.timeout(1) # process cycle (write + satisfy) yield from self._send_result(return_route, token.data) return old_pres = cell.pres cell.pres = Presence.FULL cell.data_l = token.data self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.FULL)) yield self.env.timeout(1) # write cycle def _handle_clear(self, addr: int): cell = self.cells[addr] old_pres = cell.pres cell.pres = Presence.EMPTY cell.data_l = None cell.data_r = None self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.EMPTY)) if self.deferred_read is not None and self.deferred_read.cell_addr == addr: self.deferred_read = None self._deferred_cancelled = True if self._deferred_satisfied is not None: self._deferred_satisfied.succeed() def _handle_alloc(self, addr: int): cell = self.cells[addr] if cell.pres == Presence.EMPTY: old_pres = cell.pres cell.pres = Presence.RESERVED self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.RESERVED)) def _handle_atomic(self, addr: int, token: SMToken, delta: int): if addr >= ATOMIC_CELL_LIMIT: logger.warning( "SM%d: atomic op on cell %d >= %d", self.sm_id, addr, ATOMIC_CELL_LIMIT ) return cell = self.cells[addr] if cell.pres != Presence.FULL: logger.warning("SM%d: atomic op on non-FULL cell %d", self.sm_id, addr) return old_value = cell.data_l if cell.data_l is not None else 0 cell.data_l = (old_value + delta) & UINT16_MASK self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=Presence.FULL, new_pres=Presence.FULL)) yield self.env.timeout(1) # read-modify-write cycle yield from self._send_result(token.ret, old_value) def _handle_cas(self, addr: int, token: SMToken): if addr >= ATOMIC_CELL_LIMIT: logger.warning( "SM%d: CAS on cell %d >= %d", self.sm_id, addr, ATOMIC_CELL_LIMIT ) return cell = self.cells[addr] if cell.pres != Presence.FULL: logger.warning("SM%d: CAS on non-FULL cell %d", self.sm_id, addr) return old_value = cell.data_l if cell.data_l is not None else 0 expected = token.flags if token.flags is not None else 0 if old_value == expected: cell.data_l = token.data yield self.env.timeout(1) # compare-and-swap cycle yield from self._send_result(token.ret, old_value) def _send_result(self, return_route: CMToken, data: int): result = replace(return_route, data=data) self._on_event(ResultSent(time=self.env.now, component=self._component, token=result)) yield self.env.timeout(1) # response/delivery cycle (inline, blocks SM) yield self.route_table[return_route.target].put(result) def _handle_t0_read(self, addr: int, token: SMToken): """T0 READ: return stored data immediately, no presence tracking or deferral.""" if token.ret is None: return t0_idx = addr - self.tier_boundary yield self.env.timeout(1) # process cycle if t0_idx < len(self.t0_store): data = self.t0_store[t0_idx] yield from self._send_result(token.ret, data) else: yield from self._send_result(token.ret, 0) def _handle_t0_write(self, addr: int, token: SMToken): """T0 WRITE: store raw integer data without presence checking. Stores token.data (an int) into t0_store at the T0-relative index. For EXEC bootstrap, pre-load t0_store with packed flit sequences (via pack_token()) before simulation starts. T0 is now list[int] with token serialisation/deserialisation via pack_token()/unpack_token(). """ t0_idx = addr - self.tier_boundary while len(self.t0_store) <= t0_idx: self.t0_store.append(0) self.t0_store[t0_idx] = token.data def _handle_exec(self, addr: int): """EXEC: read raw int sequences from T0, reconstitute tokens, inject into network. Uses flit_count() to determine packet boundaries, unpack_token() to reconstitute tokens from raw int sequences. Uses send() to properly trigger SimPy Store.put() events, ensuring tokens wake up pending get() operations in target PEs/SMs. """ if self.system is None: logger.warning("SM%d: EXEC but no system reference", self.sm_id) return t0_idx = addr - self.tier_boundary if t0_idx >= len(self.t0_store): return yield self.env.timeout(1) # process cycle pos = t0_idx while pos < len(self.t0_store): header = self.t0_store[pos] count = flit_count(header) if pos + count > len(self.t0_store): logger.warning("SM%d: truncated packet at T0[%d], stopping EXEC", self.sm_id, pos) break flits = self.t0_store[pos:pos + count] try: token = unpack_token(flits) except (ValueError, KeyError): logger.warning("SM%d: failed to unpack token at T0[%d], stopping EXEC", self.sm_id, pos) break yield from self.system.send(token) yield self.env.timeout(1) # per-token injection cycle pos += count