OR-1 dataflow CPU sketch
at pe-frame-redesign 584 lines 19 kB view raw
"""Tests for the Lower pass (CST → IRGraph transformation).

Tests verify:
- Instruction definition (inst_def) → IRNode with opcode, placement, named args
- Plain edges → IREdge with correct source, dest, ports
- Strong/weak inline edges → anonymous nodes + wiring
- Data definitions → IRDataDef with SM ID, cell address, packed values
- System pragma → SystemConfig with pe_count, sm_count, etc.
- Function scoping → qualified names, nested IRRegions
- Location directives → LOCATION IRRegions
- Error handling → reserved names, duplicate definitions
"""

from tests.pipeline import parse_and_lower

from asm.ir import RegionKind, SourceLoc
from asm.errors import ErrorCategory
from cm_inst import ArithOp, LogicOp, MemOp, Port, RoutingOp


# Name prefix the tests expect on nodes synthesised for inline edges.
_ANON_PREFIX = "&__anon_"


def _anon_node_names(graph):
    """Return the names in ``graph.nodes`` that carry the anonymous prefix.

    Several test classes need this same filter, so it lives in one place
    rather than being repeated in each test body.
    """
    return [name for name in graph.nodes if name.startswith(_ANON_PREFIX)]


class TestInstDef:
    """Tests for instruction definition (AC2.1, AC2.8, AC2.9)."""

    def test_basic_instruction(self, parser):
        """Parse simple instruction definition."""
        graph = parse_and_lower(parser, """\
            &my_add <| add
        """)

        assert "&my_add" in graph.nodes
        node = graph.nodes["&my_add"]
        assert node.opcode == ArithOp.ADD
        assert node.name == "&my_add"

    def test_instruction_with_const(self, parser):
        """Parse instruction with constant operand."""
        graph = parse_and_lower(parser, """\
            &my_const <| const, 42
        """)

        assert "&my_const" in graph.nodes
        node = graph.nodes["&my_const"]
        assert node.opcode == RoutingOp.CONST
        assert node.const == 42

    def test_instruction_with_hex_const(self, parser):
        """Parse instruction with hexadecimal constant."""
        graph = parse_and_lower(parser, """\
            &mask <| const, 0xFF
        """)

        assert "&mask" in graph.nodes
        node = graph.nodes["&mask"]
        assert node.const == 0xFF

    def test_instruction_with_pe_placement(self, parser):
        """Parse instruction with PE placement qualifier (AC2.8)."""
        graph = parse_and_lower(parser, """\
            &my_add|pe0 <| add
        """)

        assert "&my_add" in graph.nodes
        node = graph.nodes["&my_add"]
        assert node.pe == 0

    def test_instruction_with_pe_placement_nonzero(self, parser):
        """Parse instruction with non-zero PE placement."""
        graph = parse_and_lower(parser, """\
            &result|pe2 <| pass
        """)

        assert "&result" in graph.nodes
        node = graph.nodes["&result"]
        assert node.pe == 2

    def test_instruction_with_named_args(self, parser):
        """Instruction with a named argument and unknown mnemonic reports an error (AC2.9)."""
        graph = parse_and_lower(parser, """\
            &serial <| ior, dest=0x45
        """)

        # ior is not in MNEMONIC_TO_OP, so lowering is expected to record an
        # error; whether the node itself is still created is not checked here.
        assert len(graph.errors) > 0

    def test_shift_instruction(self, parser):
        """Parse shift instruction."""
        graph = parse_and_lower(parser, """\
            &shift_left <| shl
        """)

        assert "&shift_left" in graph.nodes
        node = graph.nodes["&shift_left"]
        assert node.opcode == ArithOp.SHL


class TestPlainEdge:
    """Tests for plain edges (AC2.2)."""

    def test_basic_plain_edge(self, parser):
        """Parse basic plain edge."""
        graph = parse_and_lower(parser, """\
            &a <| pass
            &b <| add
            &a |> &b:L
        """)

        assert len(graph.edges) == 1
        edge = graph.edges[0]
        assert edge.source == "&a"
        assert edge.dest == "&b"
        assert edge.port == Port.L

    def test_plain_edge_to_right_port(self, parser):
        """Parse plain edge to right port."""
        graph = parse_and_lower(parser, """\
            &a <| pass
            &b <| add
            &a |> &b:R
        """)

        assert len(graph.edges) == 1
        edge = graph.edges[0]
        assert edge.port == Port.R

    def test_plain_edge_fanout(self, parser):
        """Parse fanout (one source to multiple destinations)."""
        graph = parse_and_lower(parser, """\
            &a <| pass
            &b <| add
            &c <| sub
            &a |> &b:L, &c:R
        """)

        assert len(graph.edges) == 2
        first, second = graph.edges
        assert first.dest == "&b"
        assert first.port == Port.L
        assert second.dest == "&c"
        assert second.port == Port.R

    def test_plain_edge_with_source_port(self, parser):
        """Parse plain edge with source port specification."""
        graph = parse_and_lower(parser, """\
            &a:L <| pass
            &b <| add
            &a:L |> &b:L
        """)

        assert len(graph.edges) == 1
        edge = graph.edges[0]
        assert edge.source_port == Port.L


class TestStrongEdge:
    """Tests for strong inline edges (AC2.3)."""

    def test_basic_strong_edge(self, parser):
        """Parse basic strong inline edge."""
        graph = parse_and_lower(parser, """\
            &a <| pass
            &b <| pass
            &c <| pass
            &d <| pass
            add &a, &b |> &c, &d
        """)

        # Should create anonymous node
        anon_nodes = _anon_node_names(graph)
        assert len(anon_nodes) == 1
        anon_name = anon_nodes[0]

        anon_node = graph.nodes[anon_name]
        assert anon_node.opcode == ArithOp.ADD

        # Should create 4 edges: 2 inputs, 2 outputs
        assert len(graph.edges) == 4

        # Verify input edges
        input_edges = [e for e in graph.edges if e.dest == anon_name]
        assert len(input_edges) == 2
        left_input = next(e for e in input_edges if e.port == Port.L)
        right_input = next(e for e in input_edges if e.port == Port.R)
        assert left_input.source == "&a"
        assert right_input.source == "&b"

        # Verify output edges
        output_edges = [e for e in graph.edges if e.source == anon_name]
        assert len(output_edges) == 2

    def test_strong_edge_anonymous_name_format(self, parser):
        """Verify anonymous nodes have correct naming."""
        graph = parse_and_lower(parser, """\
            &a <| pass
            &b <| pass
            add &a |> &b
        """)

        anon_nodes = _anon_node_names(graph)
        assert len(anon_nodes) == 1
        # The prefix is already guaranteed by the filter above, so check the
        # part that is not: the trailing "_"-separated field must be a
        # numeric counter (the same field TestScalingAnonymousCounters parses).
        suffix = anon_nodes[0].split("_")[-1]
        assert suffix.isdigit()


class TestWeakEdge:
    """Tests for weak inline edges (AC2.4)."""

    def test_basic_weak_edge(self, parser):
        """Parse basic weak inline edge."""
        graph = parse_and_lower(parser, """\
            &a <| pass
            &b <| pass
            &c <| pass
            &d <| pass
            &c, &d sub <| &a, &b
        """)

        # Should create anonymous node
        anon_nodes = _anon_node_names(graph)
        assert len(anon_nodes) == 1
        anon_name = anon_nodes[0]

        anon_node = graph.nodes[anon_name]
        assert anon_node.opcode == ArithOp.SUB

    def test_weak_edge_equivalent_to_strong(self, parser):
        """Verify weak and strong forms yield one anonymous node with the same opcode."""
        # Parse weak edge version
        graph_weak = parse_and_lower(parser, """\
            &a <| pass
            &b <| pass
            &c <| pass
            &d <| pass
            &c, &d sub <| &a, &b
        """)

        # Parse strong edge version
        graph_strong = parse_and_lower(parser, """\
            &a <| pass
            &b <| pass
            &c <| pass
            &d <| pass
            sub &a, &b |> &c, &d
        """)

        # Both should have one anonymous node
        anon_weak = _anon_node_names(graph_weak)
        anon_strong = _anon_node_names(graph_strong)
        assert len(anon_weak) == 1
        assert len(anon_strong) == 1

        # Both should have the same opcodes for the anon nodes
        weak_opcode = graph_weak.nodes[anon_weak[0]].opcode
        strong_opcode = graph_strong.nodes[anon_strong[0]].opcode
        assert weak_opcode == strong_opcode


class TestDataDef:
    """Tests for data definitions (AC2.5, AC2.6)."""

    def test_basic_data_def(self, parser):
        """Parse basic data definition."""
        graph = parse_and_lower(parser, """\
            @hello|sm0:0 = 0x05
        """)

        assert len(graph.data_defs) == 1
        data_def = graph.data_defs[0]
        assert data_def.name == "@hello"
        assert data_def.sm_id == 0
        assert data_def.cell_addr == 0
        assert data_def.value == 0x05

    def test_data_def_with_different_sm(self, parser):
        """Parse data definition with different SM."""
        graph = parse_and_lower(parser, """\
            @data|sm1:2 = 0x42
        """)

        data_def = graph.data_defs[0]
        assert data_def.sm_id == 1
        assert data_def.cell_addr == 2

    def test_data_def_char_pair_big_endian(self, parser):
        """Parse data definition with char pair (big-endian packing) (AC2.6)."""
        graph = parse_and_lower(parser, """\
            @hello|sm0:0 = 'h', 'e'
        """)

        data_def = graph.data_defs[0]
        # 'h' = 0x68, 'e' = 0x65
        # Big-endian: (0x68 << 8) | 0x65 = 0x6865
        expected = (ord('h') << 8) | ord('e')
        assert data_def.value == expected

    def test_data_def_char_pair_he_le(self, parser):
        """Verify big-endian packing of char pair."""
        graph = parse_and_lower(parser, """\
            @data|sm0:1 = 'l', 'l'
        """)

        data_def = graph.data_defs[0]
        expected = (ord('l') << 8) | ord('l')
        assert data_def.value == expected


class TestSystemConfig:
    """Tests for system pragma (AC2.7)."""

    def test_system_pragma_minimal(self, parser):
        """Parse minimal system pragma."""
        graph = parse_and_lower(parser, """\
            @system pe=4, sm=1
        """)

        assert graph.system is not None
        assert graph.system.pe_count == 4
        assert graph.system.sm_count == 1
        assert graph.system.iram_capacity == 256  # default
        assert graph.system.frame_count == 8  # default

    def test_system_pragma_full(self, parser):
        """Parse full system pragma."""
        graph = parse_and_lower(parser, """\
            @system pe=2, sm=1, iram=128, frames=2
        """)

        assert graph.system.pe_count == 2
        assert graph.system.sm_count == 1
        assert graph.system.iram_capacity == 128
        assert graph.system.frame_count == 2

    def test_system_pragma_hex_values(self, parser):
        """Parse system pragma with hexadecimal values."""
        graph = parse_and_lower(parser, """\
            @system pe=0x04, sm=0x01
        """)

        assert graph.system.pe_count == 4
        assert graph.system.sm_count == 1


class TestFunctionScoping:
    """Tests for function scoping (AC3.1, AC3.2, AC3.3, AC3.4)."""

    def test_label_inside_function_qualified(self, parser):
        """Verify labels inside functions are qualified (AC3.1)."""
        graph = parse_and_lower(parser, """\
            $main |> {
                &add <| add
            }
        """)

        # Label should be qualified in the function region
        assert len(graph.regions) == 1
        region = graph.regions[0]
        assert "$main.&add" in region.body.nodes
        node = region.body.nodes["$main.&add"]
        assert node.opcode == ArithOp.ADD

    def test_global_node_not_qualified(self, parser):
        """Verify @nodes are never qualified (AC3.2)."""
        graph = parse_and_lower(parser, """\
            @global <| pass
        """)

        # Should not be qualified
        assert "@global" in graph.nodes
        assert "$main.@global" not in graph.nodes

    def test_top_level_label_not_qualified(self, parser):
        """Verify top-level labels are not qualified (AC3.3)."""
        graph = parse_and_lower(parser, """\
            &top <| pass
        """)

        # Should not be qualified
        assert "&top" in graph.nodes
        assert "$main.&top" not in graph.nodes

    def test_same_label_in_different_functions(self, parser):
        """Verify functions can each define &add without collision (AC3.4)."""
        graph = parse_and_lower(parser, """\
            $foo |> {
                &add <| add
            }
            $bar |> {
                &add <| sub
            }
        """)

        # Both should exist with different names in their respective regions
        assert len(graph.regions) == 2
        foo_region = next(r for r in graph.regions if r.tag == "$foo")
        bar_region = next(r for r in graph.regions if r.tag == "$bar")
        assert "$foo.&add" in foo_region.body.nodes
        assert "$bar.&add" in bar_region.body.nodes
        assert foo_region.body.nodes["$foo.&add"].opcode == ArithOp.ADD
        assert bar_region.body.nodes["$bar.&add"].opcode == ArithOp.SUB


class TestRegions:
    """Tests for regions (AC3.7, AC3.8)."""

    def test_function_region_creation(self, parser):
        """Verify function creates FUNCTION region (AC3.7)."""
        graph = parse_and_lower(parser, """\
            $func |> {
                &a <| add
            }
        """)

        assert len(graph.regions) == 1
        region = graph.regions[0]
        assert region.tag == "$func"
        assert region.kind == RegionKind.FUNCTION
        assert "$func.&a" in region.body.nodes

    def test_location_directive_creates_region(self, parser):
        """Verify location directive creates LOCATION region (AC3.8)."""
        graph = parse_and_lower(parser, """\
            @data_section|sm0:
        """)

        assert len(graph.regions) == 1
        region = graph.regions[0]
        assert region.tag == "@data_section"
        assert region.kind == RegionKind.LOCATION

    def test_location_directive_with_label_and_colon(self, parser):
        """AC6.1: Location directive with label and trailing colon."""
        graph = parse_and_lower(parser, """\
            &section:
        """)
        # A bare label with colon becomes a location_dir
        assert len(graph.regions) == 1
        region = graph.regions[0]
        assert region.kind == RegionKind.LOCATION


class TestErrorCases:
    """Tests for error handling (AC3.5, AC3.6)."""

    def test_reserved_name_system_error(self, parser):
        """Verify reserved name @system produces error (AC3.5)."""
        graph = parse_and_lower(parser, """\
            @system <| add
        """)

        # Should have an error (note: @system is a keyword, might parse as pragma)
        # If it parses as inst_def, check for NAME error
        assert len(graph.errors) > 0
        assert any(e.category == ErrorCategory.NAME for e in graph.errors)

    def test_duplicate_label_in_function_error(self, parser):
        """Verify duplicate labels in same function produce error (AC3.6)."""
        graph = parse_and_lower(parser, """\
            $main |> {
                &add <| add
                &add <| sub
            }
        """)

        # Should have a SCOPE error
        assert len(graph.errors) > 0
        assert any(e.category == ErrorCategory.SCOPE for e in graph.errors)

    def test_duplicate_label_at_top_level_error(self, parser):
        """Verify duplicate labels at top level produce error."""
        graph = parse_and_lower(parser, """\
            &label <| add
            &label <| sub
        """)

        # Should have a SCOPE error
        assert len(graph.errors) > 0
        assert any(e.category == ErrorCategory.SCOPE for e in graph.errors)


class TestMemOps:
    """Tests for memory operations."""

    def test_read_op(self, parser):
        """Parse READ operation."""
        graph = parse_and_lower(parser, """\
            &cell <| read
        """)

        node = graph.nodes["&cell"]
        assert node.opcode == MemOp.READ

    def test_write_op(self, parser):
        """Parse WRITE operation."""
        graph = parse_and_lower(parser, """\
            &cell <| write
        """)

        node = graph.nodes["&cell"]
        assert node.opcode == MemOp.WRITE

    def test_rd_inc_op(self, parser):
        """Parse RD_INC operation."""
        graph = parse_and_lower(parser, """\
            &cell <| rd_inc
        """)

        node = graph.nodes["&cell"]
        assert node.opcode == MemOp.RD_INC


class TestEdgeCases:
    """Tests for edge cases and integration."""

    def test_empty_program(self, parser):
        """Parse empty program."""
        graph = parse_and_lower(parser, "")

        assert len(graph.nodes) == 0
        assert len(graph.edges) == 0

    def test_program_with_comments(self, parser):
        """Parse program with comments."""
        graph = parse_and_lower(parser, """\
            &my_add <| add ; this is a comment
            &a |> &my_add:L ; wire a to add left
        """)

        assert "&my_add" in graph.nodes
        assert len(graph.edges) == 1

    def test_multiple_instructions(self, parser):
        """Parse multiple instructions."""
        graph = parse_and_lower(parser, """\
            &a <| add
            &b <| sub
            &c <| pass
        """)

        assert len(graph.nodes) == 3
        assert "&a" in graph.nodes
        assert "&b" in graph.nodes
        assert "&c" in graph.nodes

    def test_complex_graph(self, parser):
        """Parse a more complex program."""
        graph = parse_and_lower(parser, """\
            @system pe=4, sm=1

            &init <| const, 0
            &loop_add <| add
            &cmp <| lte
            &branch <| breq

            &init |> &loop_add:L
            &loop_add |> &cmp:L
            &cmp |> &branch:L
        """)

        assert graph.system.pe_count == 4
        assert len(graph.nodes) == 4
        assert len(graph.edges) == 3


class TestScalingAnonymousCounters:
    """Tests that anonymous counter properly increments."""

    def test_multiple_strong_edges_increment_counter(self, parser):
        """Verify each strong edge gets unique anonymous name."""
        graph = parse_and_lower(parser, """\
            &a <| pass
            &b <| pass
            &c <| pass
            &d <| pass
            &e <| pass
            &f <| pass
            add &a |> &b
            sub &c |> &d
            inc &e |> &f
        """)

        anon_nodes = _anon_node_names(graph)
        assert len(anon_nodes) == 3
        # Verify they have different counter values: the counter is the last
        # "_"-separated field of each anonymous name.
        counters = {int(name.split("_")[-1]) for name in anon_nodes}
        assert len(counters) == 3