OR-1 dataflow CPU sketch
Branch `main` · 269 lines · 8.5 kB
"""Tests for seed constants, inline constants, and non-commutative validation.

Tests verify:
- Seed syntax (const N |> &dest) produces seed tokens, not IRAM slots
- Inline const syntax (&foo <| add 7) is equivalent to (&foo <| add, 7)
- Non-commutative ops with IRAM const require explicit port on incoming edges
- Seed tokens match target instruction arity (DyadToken for dyadic, MonadToken for monadic)
- Triggerable constants (&c <| const, 7) still work as before
"""

import simpy

from asm import assemble, run_pipeline
from asm.codegen import generate_direct
from asm.errors import ErrorSeverity
from asm.ir import (
    IRGraph, IRNode, IREdge, SystemConfig, SourceLoc, ResolvedDest,
    collect_all_nodes,
)
from asm.opcodes import is_dyadic
from cm_inst import OutputStyle, ArithOp, LogicOp, RoutingOp, Port
from emu import build_topology
from tokens import DyadToken, MonadToken


def _noncommutative_warnings(graph):
    """Return the non-commutative-port warnings recorded on *graph*.

    Filters ``graph.errors`` down to entries with WARNING severity whose
    message contains ``'Non-commutative'``.
    """
    return [
        e for e in graph.errors
        if e.severity == ErrorSeverity.WARNING and 'Non-commutative' in e.message
    ]


class TestSeedSyntax:
    """Tests for the `const N |> &dest` seed syntax."""

    def test_seed_no_iram_slot(self):
        """Seed constant nodes don't consume IRAM slots."""
        source = """\
@system pe=1, sm=1
&target|pe0 <| pass
const 42 |> &target
"""
        graph = run_pipeline(source)
        errors = list(graph.errors)
        assert not errors, f"Pipeline errors: {errors}"

        # Find the seed node and the target node
        all_nodes = collect_all_nodes(graph)

        seed_nodes = [n for n in all_nodes.values() if n.seed]
        assert len(seed_nodes) == 1
        assert seed_nodes[0].iram_offset is None

        target_nodes = [n for n in all_nodes.values() if not n.seed and n.pe is not None]
        assert len(target_nodes) == 1
        assert target_nodes[0].iram_offset is not None

    def test_seed_generates_monad_for_monadic_target(self):
        """Seed into a monadic target produces MonadToken."""
        source = """\
@system pe=1, sm=1
&target|pe0 <| pass
const 42 |> &target
"""
        result = assemble(source)
        seeds = result.seed_tokens
        assert len(seeds) == 1
        assert isinstance(seeds[0], MonadToken)
        assert seeds[0].data == 42
        assert seeds[0].target == 0

    def test_seed_generates_dyad_for_dyadic_target(self):
        """Seed into a dyadic target produces DyadToken."""
        source = """\
@system pe=1, sm=1
&target|pe0 <| add
const 5 |> &target:L
const 3 |> &target:R
"""
        result = assemble(source)
        seeds = result.seed_tokens
        assert len(seeds) == 2

        dyad_seeds = [s for s in seeds if isinstance(s, DyadToken)]
        assert len(dyad_seeds) == 2

        # Check each value landed on the port it was written to. Comparing
        # ports and values as separate sets would miss a swapped L/R.
        port_data = {(s.port, s.data) for s in dyad_seeds}
        assert port_data == {(Port.L, 5), (Port.R, 3)}

    def test_seed_e2e_add_two_seeds(self):
        """End-to-end: two seed constants into ADD, routed to second PE."""
        source = """\
@system pe=2, sm=1
&sum|pe0 <| add
&sink|pe1 <| pass
&sum |> &sink
const 10 |> &sum:L
const 32 |> &sum:R
"""
        result = assemble(source)
        env = simpy.Environment()
        system = build_topology(env, result.pe_configs, result.sm_configs)
        for seed in result.seed_tokens:
            system.inject(seed)
        env.run(until=100)

        # ADD(10, 32) = 42, routed to PE1; PE0 output_log has the emitted token
        assert len(system.pes[0].output_log) >= 1
        out = system.pes[0].output_log[0]
        assert out.data == 42


class TestInlineConst:
    """Tests for the `&foo <| add 7` inline constant shorthand."""

    def test_inline_const_parsing(self):
        """Inline const shorthand parses correctly."""
        source = """\
@system pe=1, sm=1
&foo|pe0 <| add 7
const 3 |> &foo:L
"""
        result = assemble(source)
        assert len(result.seed_tokens) >= 1

    def test_inline_const_equivalent_to_comma(self):
        """Inline const `add 7` is equivalent to `add, 7`."""
        source_inline = """\
@system pe=1, sm=1
&foo|pe0 <| add 7
const 3 |> &foo:L
"""
        source_comma = """\
@system pe=1, sm=1
&foo|pe0 <| add, 7
const 3 |> &foo:L
"""
        result_inline = assemble(source_inline)
        result_comma = assemble(source_comma)

        # Both should produce same IRAM contents
        assert len(result_inline.pe_configs) == len(result_comma.pe_configs)
        assert len(result_inline.seed_tokens) == len(result_comma.seed_tokens)
        pe_inline = result_inline.pe_configs[0]
        pe_comma = result_comma.pe_configs[0]

        # Identical slot layout. Without this check, a missing slot raises
        # KeyError below and an extra slot in the comma version is never seen.
        assert set(pe_inline.iram.keys()) == set(pe_comma.iram.keys())

        # Same IRAM instruction at every offset
        for offset, inst_inline in pe_inline.iram.items():
            inst_comma = pe_comma.iram[offset]
            assert inst_inline.opcode == inst_comma.opcode
            assert inst_inline.has_const == inst_comma.has_const

    def test_inline_const_hex(self):
        """Inline const supports hex values."""
        source = """\
@system pe=1, sm=1
&foo|pe0 <| add 0xFF
&out|pe0 <| pass
&foo |> &out
const 1 |> &foo:L
"""
        result = assemble(source)
        pe = result.pe_configs[0]
        # Locate the ADD by opcode, then check that it carries the constant.
        # Filtering on has_const first would make the has_const check a tautology.
        add_insts = [inst for inst in pe.iram.values() if inst.opcode == ArithOp.ADD]
        assert len(add_insts) == 1
        assert add_insts[0].has_const
        # NOTE(review): the constant's value (0xFF == 255) is not checked here;
        # the instruction's const-value field is not referenced anywhere in this file.


class TestNonCommutativeValidation:
    """Tests for non-commutative op + IRAM const port validation."""

    def test_noncommutative_no_port_warning(self):
        """Non-commutative op with const and no explicit port emits warning."""
        source = """\
@system pe=1, sm=1
&diff|pe0 <| sub 3
&src|pe0 <| pass
&src |> &diff
const 1 |> &src
"""
        # Should still assemble (warnings don't block)
        result = assemble(source)
        assert len(result.pe_configs) == 1

        # But run_pipeline should show the warning in graph.errors
        warnings = _noncommutative_warnings(run_pipeline(source))
        assert len(warnings) >= 1
        assert 'SUB' in warnings[0].message

    def test_noncommutative_explicit_port_ok(self):
        """Non-commutative op with const and explicit port is fine."""
        source = """\
@system pe=1, sm=1
&diff|pe0 <| sub 3
&src|pe0 <| pass
&src |> &diff:L
const 1 |> &src
"""
        result = assemble(source)
        # Should assemble without error
        assert len(result.pe_configs) == 1
        # Warnings never block assembly, so "fine" also has to mean no
        # non-commutative warning was raised.
        assert not _noncommutative_warnings(run_pipeline(source))

    def test_commutative_no_port_ok(self):
        """Commutative op with const and no explicit port is fine."""
        source = """\
@system pe=1, sm=1
&total|pe0 <| add 10
&src|pe0 <| pass
&src |> &total
const 1 |> &src
"""
        result = assemble(source)
        assert len(result.pe_configs) == 1
        # A commutative op must not trigger the non-commutative warning.
        assert not _noncommutative_warnings(run_pipeline(source))

    def test_comparison_ops_noncommutative(self):
        """LT, GT, LTE, GTE are non-commutative and emit warnings."""
        for op_name in ['lt', 'gt', 'lte', 'gte']:
            source = f"""\
@system pe=1, sm=1
&cmp|pe0 <| {op_name} 5
&src|pe0 <| pass
&src |> &cmp
const 1 |> &src
"""
            warnings = _noncommutative_warnings(run_pipeline(source))
            assert len(warnings) >= 1, f"{op_name} should emit non-commutative warning"


class TestTriggerableConstant:
    """Verify that triggerable constants (&c <| const, 7) still work."""

    def test_triggerable_const_in_iram(self):
        """Triggerable constant occupies IRAM slot and fires on input."""
        source = """\
@system pe=1, sm=1
&c|pe0 <| const, 7
"""
        result = assemble(source)
        pe = result.pe_configs[0]
        assert len(pe.iram) == 1

        # Should also produce a seed token (no incoming edges)
        assert len(result.seed_tokens) == 1
        assert isinstance(result.seed_tokens[0], MonadToken)
        assert result.seed_tokens[0].data == 7

    def test_triggerable_const_with_incoming_edge_not_seed(self):
        """Triggerable const with incoming edge is NOT a seed."""
        source = """\
@system pe=1, sm=1
&src|pe0 <| pass
&c|pe0 <| const, 42
&src |> &c
const 1 |> &src
"""
        result = assemble(source)
        seed_data = [s.data for s in result.seed_tokens]
        # &c is fed by &src, so it must not be seeded with its value 42.
        # The `const 1 |> &src` seed node must still produce its token.
        assert 42 not in seed_data
        assert seed_data.count(1) == 1