# OR-1 dataflow CPU sketch
# at main 296 lines 12 kB view raw
1from __future__ import annotations 2 3import logging 4from dataclasses import replace 5from typing import TYPE_CHECKING, Optional 6 7import simpy 8 9from cm_inst import MemOp 10from emu.alu import UINT16_MASK 11from emu.events import ( 12 EventCallback, TokenReceived, CellWritten, DeferredRead as DeferredReadEvent, 13 DeferredSatisfied, ResultSent, 14) 15from emu.types import DeferredRead 16from encoding import flit_count, unpack_token 17from sm_mod import Presence, SMCell 18from tokens import CMToken, MonadToken, SMToken, Token 19 20if TYPE_CHECKING: 21 from emu.network import System 22 23logger = logging.getLogger(__name__) 24 25ATOMIC_CELL_LIMIT = 256 26 27 28class StructureMemory: 29 def __init__( 30 self, 31 env: simpy.Environment, 32 sm_id: int, 33 cell_count: int = 512, 34 fifo_capacity: int = 8, 35 tier_boundary: int = 256, 36 on_event: EventCallback | None = None, 37 ): 38 self.env = env 39 self.sm_id = sm_id 40 self.tier_boundary = tier_boundary 41 self.cells: list[SMCell] = [ 42 SMCell(Presence.EMPTY, None, None) for _ in range(cell_count) 43 ] 44 self.deferred_read: Optional[DeferredRead] = None 45 self._deferred_satisfied: Optional[simpy.Event] = None 46 self._deferred_cancelled: bool = False 47 self.input_store: simpy.Store = simpy.Store(env, capacity=fifo_capacity) 48 self.route_table: dict[int, simpy.Store] = {} 49 self.t0_store: list[int] = [] 50 self.system: Optional[System] = None 51 self._on_event: EventCallback = on_event or (lambda _: None) 52 self._component = f"sm:{sm_id}" 53 self.process = env.process(self._run()) 54 55 def _is_t0(self, addr: int) -> bool: 56 """Return True if address falls in the T0 (shared raw storage) region.""" 57 return addr >= self.tier_boundary 58 59 def _run(self): 60 while True: 61 token = yield self.input_store.get() 62 yield self.env.timeout(1) # dequeue cycle 63 self._on_event(TokenReceived(time=self.env.now, component=self._component, token=token)) 64 65 if not isinstance(token, SMToken): 66 logger.warning( 67 
"SM%d: unexpected token type: %s", self.sm_id, type(token) 68 ) 69 continue 70 71 addr = token.addr 72 op = token.op 73 74 if self._is_t0(addr): 75 match op: 76 case MemOp.READ: 77 yield from self._handle_t0_read(addr, token) 78 case MemOp.WRITE: 79 self._handle_t0_write(addr, token) 80 yield self.env.timeout(1) # write cycle 81 case MemOp.EXEC: 82 yield from self._handle_exec(addr) 83 case _: 84 logger.warning( 85 "SM%d: I-structure op %s on T0 address %d", 86 self.sm_id, op.name, addr, 87 ) 88 continue 89 90 match op: 91 case MemOp.READ: 92 yield from self._handle_read(addr, token) 93 case MemOp.WRITE: 94 yield from self._handle_write(addr, token) 95 case MemOp.CLEAR: 96 self._handle_clear(addr) 97 yield self.env.timeout(1) # process cycle 98 case MemOp.RD_INC: 99 yield from self._handle_atomic(addr, token, delta=1) 100 case MemOp.RD_DEC: 101 yield from self._handle_atomic(addr, token, delta=-1) 102 case MemOp.CMP_SW: 103 yield from self._handle_cas(addr, token) 104 case MemOp.ALLOC: 105 self._handle_alloc(addr) 106 yield self.env.timeout(1) # process cycle 107 case MemOp.FREE: 108 self._handle_clear(addr) 109 yield self.env.timeout(1) # process cycle 110 case MemOp.EXEC: 111 logger.warning( 112 "SM%d: EXEC on T1 address %d (must be T0)", 113 self.sm_id, addr, 114 ) 115 case MemOp.SET_PAGE | MemOp.WRITE_IMM | MemOp.RAW_READ | MemOp.EXT: 116 raise NotImplementedError( 117 f"SM{self.sm_id}: {op.name} not yet implemented" 118 ) 119 case _: 120 logger.warning("SM%d: unknown op %s", self.sm_id, op) 121 122 def _handle_read(self, addr: int, token: SMToken): 123 cell = self.cells[addr] 124 125 if cell.pres == Presence.FULL: 126 yield self.env.timeout(1) # process cycle 127 yield from self._send_result(token.ret, cell.data_l) 128 return 129 130 if self.deferred_read is not None: 131 self.env.process(self._wait_and_retry_read(addr, token)) 132 return 133 134 self.deferred_read = DeferredRead(cell_addr=addr, return_route=token.ret) 135 old_pres = cell.pres 136 cell.pres = 
Presence.WAITING 137 self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.WAITING)) 138 self._on_event(DeferredReadEvent(time=self.env.now, component=self._component, addr=addr)) 139 yield self.env.timeout(1) # process cycle (set WAITING) 140 141 def _wait_and_retry_read(self, addr: int, token: SMToken): 142 self._deferred_satisfied = self.env.event() 143 yield self._deferred_satisfied 144 self._deferred_satisfied = None 145 if self._deferred_cancelled: 146 self._deferred_cancelled = False 147 return 148 yield from self._handle_read(addr, token) 149 150 def _handle_write(self, addr: int, token: SMToken): 151 cell = self.cells[addr] 152 153 if ( 154 cell.pres == Presence.WAITING 155 and self.deferred_read is not None 156 and self.deferred_read.cell_addr == addr 157 ): 158 return_route = self.deferred_read.return_route 159 self.deferred_read = None 160 old_pres = cell.pres 161 cell.pres = Presence.FULL 162 cell.data_l = token.data 163 self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.FULL)) 164 self._on_event(DeferredSatisfied(time=self.env.now, component=self._component, addr=addr, data=token.data)) 165 if self._deferred_satisfied is not None: 166 self._deferred_satisfied.succeed() 167 yield self.env.timeout(1) # process cycle (write + satisfy) 168 yield from self._send_result(return_route, token.data) 169 return 170 171 old_pres = cell.pres 172 cell.pres = Presence.FULL 173 cell.data_l = token.data 174 self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.FULL)) 175 yield self.env.timeout(1) # write cycle 176 177 def _handle_clear(self, addr: int): 178 cell = self.cells[addr] 179 old_pres = cell.pres 180 cell.pres = Presence.EMPTY 181 cell.data_l = None 182 cell.data_r = None 183 self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, 
old_pres=old_pres, new_pres=Presence.EMPTY)) 184 185 if self.deferred_read is not None and self.deferred_read.cell_addr == addr: 186 self.deferred_read = None 187 self._deferred_cancelled = True 188 if self._deferred_satisfied is not None: 189 self._deferred_satisfied.succeed() 190 191 def _handle_alloc(self, addr: int): 192 cell = self.cells[addr] 193 if cell.pres == Presence.EMPTY: 194 old_pres = cell.pres 195 cell.pres = Presence.RESERVED 196 self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.RESERVED)) 197 198 def _handle_atomic(self, addr: int, token: SMToken, delta: int): 199 if addr >= ATOMIC_CELL_LIMIT: 200 logger.warning( 201 "SM%d: atomic op on cell %d >= %d", self.sm_id, addr, ATOMIC_CELL_LIMIT 202 ) 203 return 204 205 cell = self.cells[addr] 206 if cell.pres != Presence.FULL: 207 logger.warning("SM%d: atomic op on non-FULL cell %d", self.sm_id, addr) 208 return 209 210 old_value = cell.data_l if cell.data_l is not None else 0 211 cell.data_l = (old_value + delta) & UINT16_MASK 212 self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=Presence.FULL, new_pres=Presence.FULL)) 213 yield self.env.timeout(1) # read-modify-write cycle 214 yield from self._send_result(token.ret, old_value) 215 216 def _handle_cas(self, addr: int, token: SMToken): 217 if addr >= ATOMIC_CELL_LIMIT: 218 logger.warning( 219 "SM%d: CAS on cell %d >= %d", self.sm_id, addr, ATOMIC_CELL_LIMIT 220 ) 221 return 222 223 cell = self.cells[addr] 224 if cell.pres != Presence.FULL: 225 logger.warning("SM%d: CAS on non-FULL cell %d", self.sm_id, addr) 226 return 227 228 old_value = cell.data_l if cell.data_l is not None else 0 229 expected = token.flags if token.flags is not None else 0 230 if old_value == expected: 231 cell.data_l = token.data 232 yield self.env.timeout(1) # compare-and-swap cycle 233 yield from self._send_result(token.ret, old_value) 234 235 def _send_result(self, 
return_route: CMToken, data: int): 236 result = replace(return_route, data=data) 237 self._on_event(ResultSent(time=self.env.now, component=self._component, token=result)) 238 yield self.env.timeout(1) # response/delivery cycle (inline, blocks SM) 239 yield self.route_table[return_route.target].put(result) 240 241 def _handle_t0_read(self, addr: int, token: SMToken): 242 """T0 READ: return stored data immediately, no presence tracking or deferral.""" 243 if token.ret is None: 244 return 245 t0_idx = addr - self.tier_boundary 246 yield self.env.timeout(1) # process cycle 247 if t0_idx < len(self.t0_store): 248 data = self.t0_store[t0_idx] 249 yield from self._send_result(token.ret, data) 250 else: 251 yield from self._send_result(token.ret, 0) 252 253 def _handle_t0_write(self, addr: int, token: SMToken): 254 """T0 WRITE: store raw integer data without presence checking. 255 256 Stores token.data (an int) into t0_store at the T0-relative index. 257 For EXEC bootstrap, pre-load t0_store with packed flit sequences 258 (via pack_token()) before simulation starts. T0 is now list[int] with 259 token serialisation/deserialisation via pack_token()/unpack_token(). 260 """ 261 t0_idx = addr - self.tier_boundary 262 while len(self.t0_store) <= t0_idx: 263 self.t0_store.append(0) 264 self.t0_store[t0_idx] = token.data 265 266 def _handle_exec(self, addr: int): 267 """EXEC: read raw int sequences from T0, reconstitute tokens, inject into network. 268 269 Uses flit_count() to determine packet boundaries, unpack_token() to reconstitute 270 tokens from raw int sequences. Uses send() to properly trigger SimPy Store.put() 271 events, ensuring tokens wake up pending get() operations in target PEs/SMs. 
272 """ 273 if self.system is None: 274 logger.warning("SM%d: EXEC but no system reference", self.sm_id) 275 return 276 t0_idx = addr - self.tier_boundary 277 if t0_idx >= len(self.t0_store): 278 return 279 yield self.env.timeout(1) # process cycle 280 281 pos = t0_idx 282 while pos < len(self.t0_store): 283 header = self.t0_store[pos] 284 count = flit_count(header) 285 if pos + count > len(self.t0_store): 286 logger.warning("SM%d: truncated packet at T0[%d], stopping EXEC", self.sm_id, pos) 287 break 288 flits = self.t0_store[pos:pos + count] 289 try: 290 token = unpack_token(flits) 291 except (ValueError, KeyError): 292 logger.warning("SM%d: failed to unpack token at T0[%d], stopping EXEC", self.sm_id, pos) 293 break 294 yield from self.system.send(token) 295 yield self.env.timeout(1) # per-token injection cycle 296 pos += count