OR-1 dataflow CPU sketch
1"""Tests for encoding.py boundary functions.
2
3Verifies:
4- pe-frame-redesign.AC2.2: pack_instruction / unpack_instruction round-trip
5- Instruction encoding/decoding via hardware word format
6- Flit 1 packing/unpacking for FrameDest
7- Token packing/unpacking for T0 storage
8"""
9
import pytest
from hypothesis import assume, example, given, strategies as st

from cm_inst import (
    ArithOp, FrameDest, FrameOp, Instruction, LogicOp, MemOp,
    OutputStyle, Port, RoutingOp, TokenKind,
)
from encoding import (
    _decode_mode, _decode_opcode, _encode_mode, _encode_opcode,
    flit_count, pack_flit1, pack_instruction, pack_token, unpack_flit1,
    unpack_instruction, unpack_token,
)
from tokens import DyadToken, MonadToken, SMToken
23
24
25# ============================================================================
26# Mode Encoding/Decoding
27# ============================================================================
28
class TestModeEncoding:
    """_encode_mode maps (output style, has_const, dest_count) to a 3-bit mode."""

    def test_encode_mode_inherit_single_no_const(self):
        """INHERIT, one destination, no constant encodes as 0b000."""
        assert _encode_mode(OutputStyle.INHERIT, False, 1) == 0b000

    def test_encode_mode_inherit_single_with_const(self):
        """INHERIT, one destination, with constant encodes as 0b001."""
        assert _encode_mode(OutputStyle.INHERIT, True, 1) == 0b001

    def test_encode_mode_inherit_fanout_no_const(self):
        """INHERIT, two destinations, no constant encodes as 0b010."""
        assert _encode_mode(OutputStyle.INHERIT, False, 2) == 0b010

    def test_encode_mode_inherit_fanout_with_const(self):
        """INHERIT, two destinations, with constant encodes as 0b011."""
        assert _encode_mode(OutputStyle.INHERIT, True, 2) == 0b011

    def test_encode_mode_change_tag_no_const(self):
        """CHANGE_TAG without a constant encodes as 0b100."""
        assert _encode_mode(OutputStyle.CHANGE_TAG, False, 1) == 0b100

    def test_encode_mode_change_tag_with_const(self):
        """CHANGE_TAG with a constant encodes as 0b101."""
        assert _encode_mode(OutputStyle.CHANGE_TAG, True, 1) == 0b101

    def test_encode_mode_sink_no_const(self):
        """SINK without a constant encodes as 0b110."""
        assert _encode_mode(OutputStyle.SINK, False, 0) == 0b110

    def test_encode_mode_sink_with_const(self):
        """SINK with a constant encodes as 0b111."""
        assert _encode_mode(OutputStyle.SINK, True, 0) == 0b111

    def test_encode_mode_invalid_inherit_dest_count(self):
        """INHERIT with three destinations is rejected with ValueError."""
        with pytest.raises(ValueError):
            _encode_mode(OutputStyle.INHERIT, False, 3)

    def test_encode_mode_invalid_output_style(self):
        """A value that is not an OutputStyle is rejected with ValueError."""
        with pytest.raises(ValueError):
            _encode_mode(42, False, 1)  # type: ignore

    def test_encode_mode_change_tag_invalid_dest_count(self):
        """CHANGE_TAG with zero or two destinations is rejected."""
        expected_msg = "CHANGE_TAG requires dest_count == 1"
        for bad_count in (0, 2):
            with pytest.raises(ValueError, match=expected_msg):
                _encode_mode(OutputStyle.CHANGE_TAG, False, bad_count)

    def test_encode_mode_sink_invalid_dest_count(self):
        """SINK with one or two destinations is rejected."""
        expected_msg = "SINK requires dest_count == 0"
        for bad_count in (1, 2):
            with pytest.raises(ValueError, match=expected_msg):
                _encode_mode(OutputStyle.SINK, False, bad_count)
95
96
class TestModeDecoding:
    """_decode_mode splits a 3-bit mode into (output style, has_const, dest_count)."""

    @staticmethod
    def _expect(raw_mode, style, const_flag, n_dests):
        """Decode raw_mode and check each of the three returned fields."""
        got_style, got_const, got_dests = _decode_mode(raw_mode)
        assert got_style == style
        # Identity check: the flag must be a real bool, not merely truthy/falsy.
        assert got_const is const_flag
        assert got_dests == n_dests

    def test_decode_mode_0(self):
        """0b000 is INHERIT, no constant, one destination."""
        self._expect(0b000, OutputStyle.INHERIT, False, 1)

    def test_decode_mode_1(self):
        """0b001 is INHERIT, with constant, one destination."""
        self._expect(0b001, OutputStyle.INHERIT, True, 1)

    def test_decode_mode_2(self):
        """0b010 is INHERIT, no constant, two destinations."""
        self._expect(0b010, OutputStyle.INHERIT, False, 2)

    def test_decode_mode_3(self):
        """0b011 is INHERIT, with constant, two destinations."""
        self._expect(0b011, OutputStyle.INHERIT, True, 2)

    def test_decode_mode_4(self):
        """0b100 is CHANGE_TAG, no constant; dest_count is the nominal 1."""
        self._expect(0b100, OutputStyle.CHANGE_TAG, False, 1)

    def test_decode_mode_5(self):
        """0b101 is CHANGE_TAG, with constant; dest_count is the nominal 1."""
        self._expect(0b101, OutputStyle.CHANGE_TAG, True, 1)

    def test_decode_mode_6(self):
        """0b110 is SINK, no constant, zero destinations."""
        self._expect(0b110, OutputStyle.SINK, False, 0)

    def test_decode_mode_7(self):
        """0b111 is SINK, with constant, zero destinations."""
        self._expect(0b111, OutputStyle.SINK, True, 0)
155
156
class TestModeRoundTrip:
    """Decoding a mode and re-encoding the result yields the same mode."""

    def test_mode_roundtrip_all_combinations(self):
        """Every 3-bit mode value (0..7) survives decode followed by encode."""
        for raw_mode in range(8):
            style, const_flag, n_dests = _decode_mode(raw_mode)
            assert _encode_mode(style, const_flag, n_dests) == raw_mode
166
167
168# ============================================================================
169# Opcode Encoding/Decoding
170# ============================================================================
171
class TestOpcodeEncoding:
    """_encode_opcode returns a (type_bit, opcode) pair for an operation enum."""

    def test_encode_alu_opcode(self):
        """ArithOp.ADD encodes with type_bit 0 and opcode 0."""
        kind_bit, raw = _encode_opcode(ArithOp.ADD)
        assert (kind_bit, raw) == (0, 0)

    def test_encode_arith_ops(self):
        """The listed ArithOp members encode as type_bit 0 and their int value."""
        checked = (
            ArithOp.ADD, ArithOp.SUB, ArithOp.INC, ArithOp.DEC,
            ArithOp.SHL, ArithOp.SHR, ArithOp.ASR,
        )
        for arith in checked:
            kind_bit, raw = _encode_opcode(arith)
            assert (kind_bit, raw) == (0, int(arith))

    def test_encode_logic_ops(self):
        """Every LogicOp encodes as type_bit 0 and its int value."""
        for logic in LogicOp:
            kind_bit, raw = _encode_opcode(logic)
            assert (kind_bit, raw) == (0, int(logic))

    def test_encode_routing_ops(self):
        """Every RoutingOp encodes as type_bit 0 and its int value."""
        for routing in RoutingOp:
            kind_bit, raw = _encode_opcode(routing)
            assert (kind_bit, raw) == (0, int(routing))

    def test_encode_memop(self):
        """Every MemOp encodes as type_bit 1 and its int value."""
        for mem in MemOp:
            kind_bit, raw = _encode_opcode(mem)
            assert (kind_bit, raw) == (1, int(mem))
209
210
class TestOpcodeDecoding:
    """_decode_opcode(type_bit, raw_opcode) returns the matching operation enum."""

    def test_decode_alu_opcode(self):
        """type_bit 0, opcode 0 decodes to ArithOp.ADD."""
        assert _decode_opcode(0, 0) == ArithOp.ADD

    def test_decode_logic_opcode(self):
        """type_bit 0, opcode 11 decodes to LogicOp.EQ."""
        assert _decode_opcode(0, 11) == LogicOp.EQ

    def test_decode_routing_opcode(self):
        """type_bit 0, opcode 16 decodes to RoutingOp.BREQ."""
        assert _decode_opcode(0, 16) == RoutingOp.BREQ

    def test_decode_memop(self):
        """type_bit 1, opcode 0 decodes to MemOp.READ."""
        assert _decode_opcode(1, 0) == MemOp.READ
233
234
class TestOpcodeRoundTrip:
    """Encoding an operation and decoding the result yields the same operation."""

    def test_all_aluops_roundtrip(self):
        """Every ArithOp, LogicOp and RoutingOp survives encode then decode."""
        for alu_op in [*ArithOp, *LogicOp, *RoutingOp]:
            kind_bit, raw = _encode_opcode(alu_op)
            assert _decode_opcode(kind_bit, raw) == alu_op

    def test_all_memops_roundtrip(self):
        """Every MemOp survives encode then decode."""
        for mem_op in MemOp:
            kind_bit, raw = _encode_opcode(mem_op)
            assert _decode_opcode(kind_bit, raw) == mem_op
251
252
253# ============================================================================
254# Instruction Pack/Unpack
255# ============================================================================
256
class TestInstructionPacking:
    """pack_instruction lays an Instruction out as a 16-bit word.

    Field positions exercised below: type_bit at bit 15, opcode from bit 10,
    mode in bits 9-7, wide at bit 6, fref in bits 5-0.
    """

    def test_pack_simple_add(self):
        """ADD, INHERIT fan-out, no constant, fref 0 packs to 0x0100."""
        add_fanout = Instruction(
            opcode=ArithOp.ADD,
            output=OutputStyle.INHERIT,
            has_const=False,
            dest_count=2,
            wide=False,
            fref=0,
        )
        # Only the mode field is non-zero: mode 2 at bit 7 -> 0x0100.
        assert pack_instruction(add_fanout) == 0x0100

    def test_pack_read_with_const(self):
        """READ with constant and fref 10 sets type_bit, mode 1 and fref."""
        read_const = Instruction(
            opcode=MemOp.READ,
            output=OutputStyle.INHERIT,
            has_const=True,
            dest_count=1,
            wide=False,
            fref=10,
        )
        type_field = 1 << 15
        opcode_field = 0 << 10
        mode_field = 1 << 7
        wide_field = 0 << 6
        want = type_field | opcode_field | mode_field | wide_field | 10
        assert pack_instruction(read_const) == want

    def test_pack_wide_instruction(self):
        """wide=True sets bit 6 of the packed word."""
        wide_add = Instruction(
            opcode=ArithOp.ADD,
            output=OutputStyle.INHERIT,
            has_const=False,
            dest_count=2,
            wide=True,
            fref=0,
        )
        packed = pack_instruction(wide_add)
        assert (packed >> 6) & 0b1 == 1

    def test_pack_max_fref(self):
        """fref 63 fills the low six bits of the packed word."""
        top_fref = Instruction(
            opcode=ArithOp.ADD,
            output=OutputStyle.INHERIT,
            has_const=False,
            dest_count=1,
            wide=False,
            fref=63,
        )
        packed = pack_instruction(top_fref)
        assert (packed & 0x3F) == 63

    def test_pack_sink_output(self):
        """SINK without a constant produces mode 0b110 in bits 9-7."""
        sink_add = Instruction(
            opcode=ArithOp.ADD,
            output=OutputStyle.SINK,
            has_const=False,
            dest_count=0,
            wide=False,
            fref=0,
        )
        mode_bits = (pack_instruction(sink_add) >> 7) & 0x7
        assert mode_bits == 0b110
331
332
class TestInstructionUnpacking:
    """unpack_instruction turns a 16-bit word back into an Instruction."""

    def test_unpack_simple_add(self):
        """0x0100 unpacks to ADD, INHERIT fan-out, no constant, fref 0."""
        # Only the mode field is set: mode 2 = INHERIT, no constant, two destinations.
        decoded = unpack_instruction(0x0100)
        assert decoded.opcode == ArithOp.ADD
        assert decoded.output == OutputStyle.INHERIT
        assert decoded.has_const is False
        assert decoded.dest_count == 2
        assert decoded.wide is False
        assert decoded.fref == 0

    def test_unpack_read_with_const(self):
        """A word with type_bit set, mode 1 and fref 10 unpacks to READ with constant."""
        type_field = 1 << 15
        opcode_field = 0 << 10
        mode_field = 1 << 7
        wide_field = 0 << 6
        decoded = unpack_instruction(
            type_field | opcode_field | mode_field | wide_field | 10
        )
        assert isinstance(decoded.opcode, MemOp)
        assert decoded.opcode == MemOp.READ
        assert decoded.has_const is True
        assert decoded.dest_count == 1
        assert decoded.fref == 10

    def test_unpack_wide_bit(self):
        """Bit 6 set in the word unpacks to wide=True."""
        wide_mask = 1 << 6
        decoded = unpack_instruction(0x0080 | wide_mask)
        assert decoded.wide is True

    def test_unpack_change_tag_output(self):
        """Opcode 16 (BREQ) with mode 4 unpacks to the CHANGE_TAG output style."""
        breq_change_tag = (0 << 15) | (16 << 10) | (4 << 7)
        decoded = unpack_instruction(breq_change_tag)
        assert decoded.output == OutputStyle.CHANGE_TAG
368
369
class TestInstructionRoundTrip:
    """pack_instruction followed by unpack_instruction returns the original fields."""

    @given(
        opcode=st.sampled_from(
            list(ArithOp) + list(LogicOp) + list(RoutingOp) + list(MemOp)
        ),
        output=st.sampled_from(list(OutputStyle)),
        has_const=st.booleans(),
        dest_count=st.integers(min_value=0, max_value=2),
        wide=st.booleans(),
        fref=st.integers(min_value=0, max_value=63),
    )
    # Boundary examples for instruction encoding
    @example(opcode=ArithOp.ADD, output=OutputStyle.INHERIT, has_const=False, dest_count=1, wide=False, fref=0)
    @example(opcode=ArithOp.ADD, output=OutputStyle.INHERIT, has_const=False, dest_count=2, wide=False, fref=63)
    @example(opcode=MemOp.READ, output=OutputStyle.INHERIT, has_const=True, dest_count=1, wide=True, fref=0)
    @example(opcode=RoutingOp.EXTRACT_TAG, output=OutputStyle.CHANGE_TAG, has_const=False, dest_count=1, wide=False, fref=31)
    @example(opcode=ArithOp.ADD, output=OutputStyle.SINK, has_const=False, dest_count=0, wide=False, fref=0)
    @example(opcode=ArithOp.ADD, output=OutputStyle.SINK, has_const=True, dest_count=0, wide=False, fref=63)
    def test_roundtrip_all_valid_combinations(self, opcode, output, has_const, dest_count, wide, fref):
        """Every valid field combination survives pack then unpack.

        Draws that pair an output style with a dest_count it cannot have are
        discarded with assume(), so Hypothesis replaces them with fresh draws
        instead of counting an early return as a passing example.
        """
        # INHERIT needs one or two destinations; SINK needs none.
        assume(output != OutputStyle.INHERIT or dest_count in (1, 2))
        assume(output != OutputStyle.SINK or dest_count == 0)
        if output == OutputStyle.CHANGE_TAG:
            dest_count = 1  # CHANGE_TAG always decodes to dest_count=1

        inst = Instruction(
            opcode=opcode,
            output=output,
            has_const=has_const,
            dest_count=dest_count,
            wide=wide,
            fref=fref,
        )
        unpacked = unpack_instruction(pack_instruction(inst))

        assert unpacked.opcode == inst.opcode
        assert unpacked.output == inst.output
        assert unpacked.has_const == inst.has_const
        assert unpacked.dest_count == inst.dest_count
        assert unpacked.wide == inst.wide
        assert unpacked.fref == inst.fref
417
418
419# ============================================================================
420# Flit 1 Packing/Unpacking
421# ============================================================================
422
class TestFlit1Packing:
    """pack_flit1 encodes a FrameDest as the 16-bit first flit."""

    def test_pack_dyadic_dest(self):
        """DYADIC layout: [00][port:1][PE:2][offset:8][act_id:3]."""
        packed = pack_flit1(
            FrameDest(
                target_pe=1,
                offset=32,
                act_id=3,
                port=Port.L,
                token_kind=TokenKind.DYADIC,
            )
        )
        assert (packed >> 14) == 0b00        # kind prefix
        assert (packed >> 13) & 1 == 0       # Port.L is encoded as 0
        assert (packed >> 11) & 0x3 == 1     # target PE
        assert (packed >> 3) & 0xFF == 32    # offset
        assert packed & 0x7 == 3             # act_id

    def test_pack_monadic_dest(self):
        """MONADIC layout: [010][PE:2][offset:8][act_id:3]."""
        packed = pack_flit1(
            FrameDest(
                target_pe=2,
                offset=16,
                act_id=5,
                port=Port.R,  # no port bit in the MONADIC layout
                token_kind=TokenKind.MONADIC,
            )
        )
        assert (packed >> 13) == 0b010       # kind prefix
        assert (packed >> 11) & 0x3 == 2     # target PE
        assert (packed >> 3) & 0xFF == 16    # offset
        assert packed & 0x7 == 5             # act_id

    def test_pack_inline_dest(self):
        """INLINE layout: [011][PE:2][10][offset:7][spare:2]."""
        packed = pack_flit1(
            FrameDest(
                target_pe=3,
                offset=64,
                act_id=0,  # no act_id field in the INLINE layout
                port=Port.L,
                token_kind=TokenKind.INLINE,
            )
        )
        assert (packed >> 13) == 0b011       # kind prefix
        assert (packed >> 11) & 0x3 == 3     # target PE
        assert (packed >> 9) & 0x3 == 0b10   # INLINE sub-kind bits
        assert (packed >> 2) & 0x7F == 64    # 7-bit offset
475
476
class TestFlit1Unpacking:
    """unpack_flit1 decodes the 16-bit first flit into a FrameDest."""

    def test_unpack_dyadic(self):
        """A 00-prefixed flit decodes as DYADIC with port, PE, offset and act_id."""
        raw = (0b00 << 14) | (1 << 13) | (2 << 11) | (48 << 3) | 4
        decoded = unpack_flit1(raw)
        assert decoded.token_kind == TokenKind.DYADIC
        assert decoded.port == Port.R
        assert decoded.target_pe == 2
        assert decoded.offset == 48
        assert decoded.act_id == 4

    def test_unpack_monadic(self):
        """A 010-prefixed flit decodes as MONADIC; port comes back as Port.L."""
        raw = (0b010 << 13) | (1 << 11) | (20 << 3) | 2
        decoded = unpack_flit1(raw)
        assert decoded.token_kind == TokenKind.MONADIC
        assert decoded.port == Port.L
        assert decoded.target_pe == 1
        assert decoded.offset == 20
        assert decoded.act_id == 2

    def test_unpack_inline(self):
        """A 011-prefixed flit with sub-kind 10 decodes as INLINE with act_id 0."""
        raw = (0b011 << 13) | (3 << 11) | (0b10 << 9) | (50 << 2)
        decoded = unpack_flit1(raw)
        assert decoded.token_kind == TokenKind.INLINE
        assert decoded.port == Port.L
        assert decoded.target_pe == 3
        assert decoded.offset == 50
        assert decoded.act_id == 0
509
510
class TestFlit1RoundTrip:
    """pack_flit1 followed by unpack_flit1 returns the encodable fields."""

    @given(
        target_pe=st.integers(min_value=0, max_value=3),
        offset_dyadic=st.integers(min_value=0, max_value=255),
        offset_monadic=st.integers(min_value=0, max_value=255),
        offset_inline=st.integers(min_value=0, max_value=127),
        act_id=st.integers(min_value=0, max_value=7),
        port=st.sampled_from(list(Port)),
    )
    # Boundary examples for flit1 encoding
    @example(target_pe=0, offset_dyadic=0, offset_monadic=0, offset_inline=0, act_id=0, port=Port.L)
    @example(target_pe=3, offset_dyadic=255, offset_monadic=255, offset_inline=127, act_id=7, port=Port.R)
    @example(target_pe=1, offset_dyadic=128, offset_monadic=128, offset_inline=64, act_id=3, port=Port.L)
    @example(target_pe=2, offset_dyadic=1, offset_monadic=1, offset_inline=1, act_id=1, port=Port.R)
    def test_roundtrip_all_token_kinds(self, target_pe, offset_dyadic, offset_monadic, offset_inline, act_id, port):
        """Each token kind round-trips, using an offset drawn for that kind's width."""
        offset_for_kind = (
            (TokenKind.DYADIC, offset_dyadic),
            (TokenKind.MONADIC, offset_monadic),
            (TokenKind.INLINE, offset_inline),
        )
        for kind, kind_offset in offset_for_kind:
            original = FrameDest(
                target_pe=target_pe,
                offset=kind_offset,
                act_id=act_id,
                port=port,
                token_kind=kind,
            )
            restored = unpack_flit1(pack_flit1(original))

            assert restored.target_pe == original.target_pe
            assert restored.offset == original.offset
            assert restored.token_kind == original.token_kind
            # act_id is only compared for DYADIC and MONADIC; INLINE is skipped.
            if kind != TokenKind.INLINE:
                assert restored.act_id == original.act_id
            # Only DYADIC is expected to preserve the port; other kinds come back as Port.L.
            expected_port = original.port if kind == TokenKind.DYADIC else Port.L
            assert restored.port == expected_port
554
555
556# ============================================================================
557# Token Packing/Unpacking
558# ============================================================================
559
class TestTokenPacking:
    """pack_token turns a token object into a sequence of flits."""

    def test_pack_dyadic_token(self):
        """A DyadToken yields two flits, the second being its data word."""
        packed = pack_token(
            DyadToken(
                target=1,
                offset=32,
                act_id=3,
                data=0x1234,
                port=Port.L,
            )
        )
        assert len(packed) == 2
        assert packed[1] == 0x1234

    def test_pack_monadic_token_normal(self):
        """A non-inline MonadToken yields two flits, the second being its data word."""
        packed = pack_token(
            MonadToken(
                target=2,
                offset=16,
                act_id=5,
                data=0x5678,
                inline=False,
            )
        )
        assert len(packed) == 2
        assert packed[1] == 0x5678

    def test_pack_monadic_token_inline(self):
        """An inline MonadToken yields a single flit (no data flit)."""
        packed = pack_token(
            MonadToken(
                target=1,
                offset=64,
                act_id=0,
                data=0,  # no data flit is emitted for inline tokens
                inline=True,
            )
        )
        assert len(packed) == 1

    def test_pack_smtoken(self):
        """An SMToken yields two flits with bit 15 of the first flit set."""
        packed = pack_token(
            SMToken(
                target=3,
                addr=100,
                op=MemOp.READ,
                flags=None,
                data=0xABCD,
                ret=None,
            )
        )
        assert len(packed) == 2
        assert (packed[0] >> 15) & 1 == 1  # SM token marker
        assert packed[1] == 0xABCD

    def test_pack_smtoken_tier2_memop_raises(self):
        """Packing an SMToken whose op is MemOp.RD_DEC raises ValueError."""
        # Tier 2 MemOps: RD_DEC=8, CMP_SW=9, RAW_READ=10, SET_PAGE=11, WRITE_IMM=12.
        # A value above 7 does not fit the 3-bit op field.
        oversized = SMToken(
            target=0,
            addr=50,
            op=MemOp.RD_DEC,
            flags=None,
            data=100,
            ret=None,
        )
        with pytest.raises(ValueError, match="exceeds 3-bit encoding limit"):
            pack_token(oversized)

    def test_pack_smtoken_all_tier2_memops_raise(self):
        """Each listed Tier 2 MemOp with a value above 7 raises ValueError on pack."""
        # NOTE(review): this matches "not yet supported" while the single-op test
        # matches "exceeds 3-bit encoding limit" -- presumably the message
        # contains both phrases; confirm against pack_token.
        candidates = (
            MemOp.RD_DEC, MemOp.CMP_SW, MemOp.RAW_READ,
            MemOp.SET_PAGE, MemOp.WRITE_IMM,
        )
        for tier2_op in candidates:
            if int(tier2_op) <= 7:
                # Fits in 3 bits, so it is not checked here.
                continue
            oversized = SMToken(
                target=0,
                addr=0,
                op=tier2_op,
                flags=None,
                data=0,
                ret=None,
            )
            with pytest.raises(ValueError, match="not yet supported"):
                pack_token(oversized)
646
647
class TestTokenUnpacking:
    """unpack_token rebuilds a token object from the flits pack_token produced."""

    def test_unpack_dyadic_token(self):
        """Two flits from a DyadToken decode back to a DyadToken with the same fields."""
        source = DyadToken(
            target=1,
            offset=32,
            act_id=3,
            data=0x1234,
            port=Port.L,
        )
        decoded = unpack_token(pack_token(source))

        assert isinstance(decoded, DyadToken)
        assert decoded.target == 1
        assert decoded.offset == 32
        assert decoded.act_id == 3
        assert decoded.data == 0x1234
        assert decoded.port == Port.L

    def test_unpack_monadic_normal(self):
        """Two flits from a non-inline MonadToken decode back with inline False."""
        source = MonadToken(
            target=2,
            offset=16,
            act_id=5,
            data=0x5678,
            inline=False,
        )
        decoded = unpack_token(pack_token(source))

        assert isinstance(decoded, MonadToken)
        assert decoded.target == 2
        assert decoded.offset == 16
        assert decoded.act_id == 5
        assert decoded.data == 0x5678
        assert decoded.inline is False

    def test_unpack_monadic_inline(self):
        """The single flit of an inline MonadToken decodes with inline True."""
        source = MonadToken(
            target=1,
            offset=64,
            act_id=0,
            data=0,
            inline=True,
        )
        decoded = unpack_token(pack_token(source))

        assert isinstance(decoded, MonadToken)
        assert decoded.inline is True
        assert decoded.target == 1

    def test_unpack_smtoken(self):
        """Two flits from an SMToken decode back to target, addr, op and data."""
        source = SMToken(
            target=3,
            addr=100,
            op=MemOp.READ,
            flags=None,
            data=0xABCD,
            ret=None,
        )
        decoded = unpack_token(pack_token(source))

        assert isinstance(decoded, SMToken)
        assert decoded.target == 3
        assert decoded.addr == 100
        assert decoded.op == MemOp.READ
        assert decoded.data == 0xABCD
        # The source token had ret=None; a non-None ret is not exercised here.
        assert decoded.ret is None
724
725
class TestTokenRoundTrip:
    """pack_token followed by unpack_token returns a token equal field-by-field."""

    def test_dyadic_roundtrip(self):
        """A DyadToken on Port.R keeps target, offset, act_id, data and port."""
        original = DyadToken(
            target=1,
            offset=32,
            act_id=3,
            data=0x1234,
            port=Port.R,
        )
        restored = unpack_token(pack_token(original))

        assert isinstance(restored, DyadToken)
        assert restored.target == original.target
        assert restored.offset == original.offset
        assert restored.act_id == original.act_id
        assert restored.data == original.data
        assert restored.port == original.port

    def test_monadic_roundtrip(self):
        """A non-inline MonadToken keeps target, offset, act_id, data and inline."""
        original = MonadToken(
            target=2,
            offset=16,
            act_id=5,
            data=0x5678,
            inline=False,
        )
        restored = unpack_token(pack_token(original))

        assert isinstance(restored, MonadToken)
        assert restored.target == original.target
        assert restored.offset == original.offset
        assert restored.act_id == original.act_id
        assert restored.data == original.data
        assert restored.inline == original.inline

    def test_smtoken_roundtrip_except_ret(self):
        """An SMToken keeps target, addr, op and data; ret comes back as None."""
        original = SMToken(
            target=3,
            addr=100,
            op=MemOp.READ,
            flags=None,
            data=0xABCD,
            ret=None,  # only the None case is covered by this test
        )
        restored = unpack_token(pack_token(original))

        assert isinstance(restored, SMToken)
        assert restored.target == original.target
        assert restored.addr == original.addr
        assert restored.op == original.op
        assert restored.data == original.data
        assert restored.ret is None
786
787
788# ============================================================================
789# Flit Count
790# ============================================================================
791
class TestFlitCount:
    """flit_count reports the packet length implied by the first flit."""

    def test_flit_count_dyadic(self):
        """A DYADIC first flit announces a 2-flit packet."""
        first_flit = pack_flit1(
            FrameDest(
                target_pe=1,
                offset=32,
                act_id=3,
                port=Port.L,
                token_kind=TokenKind.DYADIC,
            )
        )
        assert flit_count(first_flit) == 2

    def test_flit_count_monadic_normal(self):
        """A MONADIC first flit announces a 2-flit packet."""
        first_flit = pack_flit1(
            FrameDest(
                target_pe=1,
                offset=32,
                act_id=3,
                port=Port.L,
                token_kind=TokenKind.MONADIC,
            )
        )
        assert flit_count(first_flit) == 2

    def test_flit_count_monadic_inline(self):
        """An INLINE first flit announces a 1-flit packet."""
        first_flit = pack_flit1(
            FrameDest(
                target_pe=1,
                offset=32,
                act_id=0,
                port=Port.L,
                token_kind=TokenKind.INLINE,
            )
        )
        assert flit_count(first_flit) == 1

    def test_flit_count_smtoken(self):
        """A first flit with bit 15 set (SM token) announces a 2-flit packet."""
        # SM token layout: [1][SM_id:2][op:3][addr:10]
        sm_marker = 1 << 15
        sm_id_field = 0 << 13
        op_field = 0 << 10
        addr_field = 100
        first_flit = sm_marker | sm_id_field | op_field | addr_field
        assert flit_count(first_flit) == 2
835 assert flit_count(flit1) == 2