OR-1 dataflow CPU sketch
1from __future__ import annotations
2
3import logging
4from dataclasses import replace
5from typing import TYPE_CHECKING, Optional
6
7import simpy
8
9from cm_inst import MemOp
10from emu.alu import UINT16_MASK
11from emu.events import (
12 EventCallback, TokenReceived, CellWritten, DeferredRead as DeferredReadEvent,
13 DeferredSatisfied, ResultSent,
14)
15from emu.types import DeferredRead
16from encoding import flit_count, unpack_token
17from sm_mod import Presence, SMCell
18from tokens import CMToken, MonadToken, SMToken, Token
19
20if TYPE_CHECKING:
21 from emu.network import System
22
logger = logging.getLogger(name=__name__)

# RD_INC / RD_DEC / CMP_SW are only honoured on T1 cells below this address.
ATOMIC_CELL_LIMIT = 0x100
26
27
28class StructureMemory:
29 def __init__(
30 self,
31 env: simpy.Environment,
32 sm_id: int,
33 cell_count: int = 512,
34 fifo_capacity: int = 8,
35 tier_boundary: int = 256,
36 on_event: EventCallback | None = None,
37 ):
38 self.env = env
39 self.sm_id = sm_id
40 self.tier_boundary = tier_boundary
41 self.cells: list[SMCell] = [
42 SMCell(Presence.EMPTY, None, None) for _ in range(cell_count)
43 ]
44 self.deferred_read: Optional[DeferredRead] = None
45 self._deferred_satisfied: Optional[simpy.Event] = None
46 self._deferred_cancelled: bool = False
47 self.input_store: simpy.Store = simpy.Store(env, capacity=fifo_capacity)
48 self.route_table: dict[int, simpy.Store] = {}
49 self.t0_store: list[int] = []
50 self.system: Optional[System] = None
51 self._on_event: EventCallback = on_event or (lambda _: None)
52 self._component = f"sm:{sm_id}"
53 self.process = env.process(self._run())
54
55 def _is_t0(self, addr: int) -> bool:
56 """Return True if address falls in the T0 (shared raw storage) region."""
57 return addr >= self.tier_boundary
58
59 def _run(self):
60 while True:
61 token = yield self.input_store.get()
62 yield self.env.timeout(1) # dequeue cycle
63 self._on_event(TokenReceived(time=self.env.now, component=self._component, token=token))
64
65 if not isinstance(token, SMToken):
66 logger.warning(
67 "SM%d: unexpected token type: %s", self.sm_id, type(token)
68 )
69 continue
70
71 addr = token.addr
72 op = token.op
73
74 if self._is_t0(addr):
75 match op:
76 case MemOp.READ:
77 yield from self._handle_t0_read(addr, token)
78 case MemOp.WRITE:
79 self._handle_t0_write(addr, token)
80 yield self.env.timeout(1) # write cycle
81 case MemOp.EXEC:
82 yield from self._handle_exec(addr)
83 case _:
84 logger.warning(
85 "SM%d: I-structure op %s on T0 address %d",
86 self.sm_id, op.name, addr,
87 )
88 continue
89
90 match op:
91 case MemOp.READ:
92 yield from self._handle_read(addr, token)
93 case MemOp.WRITE:
94 yield from self._handle_write(addr, token)
95 case MemOp.CLEAR:
96 self._handle_clear(addr)
97 yield self.env.timeout(1) # process cycle
98 case MemOp.RD_INC:
99 yield from self._handle_atomic(addr, token, delta=1)
100 case MemOp.RD_DEC:
101 yield from self._handle_atomic(addr, token, delta=-1)
102 case MemOp.CMP_SW:
103 yield from self._handle_cas(addr, token)
104 case MemOp.ALLOC:
105 self._handle_alloc(addr)
106 yield self.env.timeout(1) # process cycle
107 case MemOp.FREE:
108 self._handle_clear(addr)
109 yield self.env.timeout(1) # process cycle
110 case MemOp.EXEC:
111 logger.warning(
112 "SM%d: EXEC on T1 address %d (must be T0)",
113 self.sm_id, addr,
114 )
115 case MemOp.SET_PAGE | MemOp.WRITE_IMM | MemOp.RAW_READ | MemOp.EXT:
116 raise NotImplementedError(
117 f"SM{self.sm_id}: {op.name} not yet implemented"
118 )
119 case _:
120 logger.warning("SM%d: unknown op %s", self.sm_id, op)
121
122 def _handle_read(self, addr: int, token: SMToken):
123 cell = self.cells[addr]
124
125 if cell.pres == Presence.FULL:
126 yield self.env.timeout(1) # process cycle
127 yield from self._send_result(token.ret, cell.data_l)
128 return
129
130 if self.deferred_read is not None:
131 self.env.process(self._wait_and_retry_read(addr, token))
132 return
133
134 self.deferred_read = DeferredRead(cell_addr=addr, return_route=token.ret)
135 old_pres = cell.pres
136 cell.pres = Presence.WAITING
137 self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.WAITING))
138 self._on_event(DeferredReadEvent(time=self.env.now, component=self._component, addr=addr))
139 yield self.env.timeout(1) # process cycle (set WAITING)
140
141 def _wait_and_retry_read(self, addr: int, token: SMToken):
142 self._deferred_satisfied = self.env.event()
143 yield self._deferred_satisfied
144 self._deferred_satisfied = None
145 if self._deferred_cancelled:
146 self._deferred_cancelled = False
147 return
148 yield from self._handle_read(addr, token)
149
150 def _handle_write(self, addr: int, token: SMToken):
151 cell = self.cells[addr]
152
153 if (
154 cell.pres == Presence.WAITING
155 and self.deferred_read is not None
156 and self.deferred_read.cell_addr == addr
157 ):
158 return_route = self.deferred_read.return_route
159 self.deferred_read = None
160 old_pres = cell.pres
161 cell.pres = Presence.FULL
162 cell.data_l = token.data
163 self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.FULL))
164 self._on_event(DeferredSatisfied(time=self.env.now, component=self._component, addr=addr, data=token.data))
165 if self._deferred_satisfied is not None:
166 self._deferred_satisfied.succeed()
167 yield self.env.timeout(1) # process cycle (write + satisfy)
168 yield from self._send_result(return_route, token.data)
169 return
170
171 old_pres = cell.pres
172 cell.pres = Presence.FULL
173 cell.data_l = token.data
174 self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.FULL))
175 yield self.env.timeout(1) # write cycle
176
177 def _handle_clear(self, addr: int):
178 cell = self.cells[addr]
179 old_pres = cell.pres
180 cell.pres = Presence.EMPTY
181 cell.data_l = None
182 cell.data_r = None
183 self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.EMPTY))
184
185 if self.deferred_read is not None and self.deferred_read.cell_addr == addr:
186 self.deferred_read = None
187 self._deferred_cancelled = True
188 if self._deferred_satisfied is not None:
189 self._deferred_satisfied.succeed()
190
191 def _handle_alloc(self, addr: int):
192 cell = self.cells[addr]
193 if cell.pres == Presence.EMPTY:
194 old_pres = cell.pres
195 cell.pres = Presence.RESERVED
196 self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=old_pres, new_pres=Presence.RESERVED))
197
198 def _handle_atomic(self, addr: int, token: SMToken, delta: int):
199 if addr >= ATOMIC_CELL_LIMIT:
200 logger.warning(
201 "SM%d: atomic op on cell %d >= %d", self.sm_id, addr, ATOMIC_CELL_LIMIT
202 )
203 return
204
205 cell = self.cells[addr]
206 if cell.pres != Presence.FULL:
207 logger.warning("SM%d: atomic op on non-FULL cell %d", self.sm_id, addr)
208 return
209
210 old_value = cell.data_l if cell.data_l is not None else 0
211 cell.data_l = (old_value + delta) & UINT16_MASK
212 self._on_event(CellWritten(time=self.env.now, component=self._component, addr=addr, old_pres=Presence.FULL, new_pres=Presence.FULL))
213 yield self.env.timeout(1) # read-modify-write cycle
214 yield from self._send_result(token.ret, old_value)
215
216 def _handle_cas(self, addr: int, token: SMToken):
217 if addr >= ATOMIC_CELL_LIMIT:
218 logger.warning(
219 "SM%d: CAS on cell %d >= %d", self.sm_id, addr, ATOMIC_CELL_LIMIT
220 )
221 return
222
223 cell = self.cells[addr]
224 if cell.pres != Presence.FULL:
225 logger.warning("SM%d: CAS on non-FULL cell %d", self.sm_id, addr)
226 return
227
228 old_value = cell.data_l if cell.data_l is not None else 0
229 expected = token.flags if token.flags is not None else 0
230 if old_value == expected:
231 cell.data_l = token.data
232 yield self.env.timeout(1) # compare-and-swap cycle
233 yield from self._send_result(token.ret, old_value)
234
235 def _send_result(self, return_route: CMToken, data: int):
236 result = replace(return_route, data=data)
237 self._on_event(ResultSent(time=self.env.now, component=self._component, token=result))
238 yield self.env.timeout(1) # response/delivery cycle (inline, blocks SM)
239 yield self.route_table[return_route.target].put(result)
240
241 def _handle_t0_read(self, addr: int, token: SMToken):
242 """T0 READ: return stored data immediately, no presence tracking or deferral."""
243 if token.ret is None:
244 return
245 t0_idx = addr - self.tier_boundary
246 yield self.env.timeout(1) # process cycle
247 if t0_idx < len(self.t0_store):
248 data = self.t0_store[t0_idx]
249 yield from self._send_result(token.ret, data)
250 else:
251 yield from self._send_result(token.ret, 0)
252
253 def _handle_t0_write(self, addr: int, token: SMToken):
254 """T0 WRITE: store raw integer data without presence checking.
255
256 Stores token.data (an int) into t0_store at the T0-relative index.
257 For EXEC bootstrap, pre-load t0_store with packed flit sequences
258 (via pack_token()) before simulation starts. T0 is now list[int] with
259 token serialisation/deserialisation via pack_token()/unpack_token().
260 """
261 t0_idx = addr - self.tier_boundary
262 while len(self.t0_store) <= t0_idx:
263 self.t0_store.append(0)
264 self.t0_store[t0_idx] = token.data
265
266 def _handle_exec(self, addr: int):
267 """EXEC: read raw int sequences from T0, reconstitute tokens, inject into network.
268
269 Uses flit_count() to determine packet boundaries, unpack_token() to reconstitute
270 tokens from raw int sequences. Uses send() to properly trigger SimPy Store.put()
271 events, ensuring tokens wake up pending get() operations in target PEs/SMs.
272 """
273 if self.system is None:
274 logger.warning("SM%d: EXEC but no system reference", self.sm_id)
275 return
276 t0_idx = addr - self.tier_boundary
277 if t0_idx >= len(self.t0_store):
278 return
279 yield self.env.timeout(1) # process cycle
280
281 pos = t0_idx
282 while pos < len(self.t0_store):
283 header = self.t0_store[pos]
284 count = flit_count(header)
285 if pos + count > len(self.t0_store):
286 logger.warning("SM%d: truncated packet at T0[%d], stopping EXEC", self.sm_id, pos)
287 break
288 flits = self.t0_store[pos:pos + count]
289 try:
290 token = unpack_token(flits)
291 except (ValueError, KeyError):
292 logger.warning("SM%d: failed to unpack token at T0[%d], stopping EXEC", self.sm_id, pos)
293 break
294 yield from self.system.send(token)
295 yield self.env.timeout(1) # per-token injection cycle
296 pos += count