OR-1 dataflow CPU sketch
at pe-frame-redesign 518 lines 16 kB view raw
"""Tests for variadic repetition expansion (Phase 6).

Tests verify:
- Variadic macros expand correctly (repetition block once per argument)
- ${_idx} produces correct iteration indices (0-based)
- Mixed params: non-variadic first, variadic captures remaining args
- Empty variadic invocation: no error, nothing expanded
- Single variadic invocation: one iteration
- Variadic parameter not last: error at lower pass
- Full pipeline: variadic macro assembles and runs in emulator
"""

from functools import lru_cache
from pathlib import Path

import simpy
from lark import Lark

from asm import assemble
from asm.errors import ErrorCategory
from asm.expand import expand
from asm.ir import (
    IREdge,
    IRGraph,
    IRMacroCall,
    IRNode,
    IRRepetitionBlock,
    MacroDef,
    MacroParam,
    ParamRef,
    SourceLoc,
)
from asm.lower import lower
from cm_inst import ArithOp
from emu import build_topology


@lru_cache(maxsize=1)
def _get_parser() -> Lark:
    """Return the dfasm parser, built once and shared by every test.

    Reading the grammar file and compiling the Earley parser is the
    expensive step; a Lark instance can parse any number of inputs, so the
    constructed parser is cached instead of being rebuilt per test.
    """
    grammar_path = Path(__file__).parent.parent / "dfasm.lark"
    return Lark(
        # Explicit encoding so decoding does not depend on the host locale.
        grammar_path.read_text(encoding="utf-8"),
        parser="earley",
        propagate_positions=True,
    )


def parse_and_lower(source: str) -> IRGraph:
    """Parse source and lower to IRGraph (before expansion)."""
    parser = _get_parser()
    tree = parser.parse(source)
    return lower(tree)


def parse_lower_expand(source: str) -> IRGraph:
    """Parse, lower, and expand."""
    graph = parse_and_lower(source)
    return expand(graph)


class TestVariadicSimpleExpansion:
    """Test basic variadic repetition expansion."""

    def test_simple_variadic_expands_three_iterations(self):
        """Simple variadic: #inject *gates creates 3 pass nodes for 3 args."""
        source = """
        @system pe=1, sm=1

        #inject *gates |> {
            $( &g <| pass ),*
        }

        #inject &a, &b, &c
        """
        graph = parse_lower_expand(source)

        # Should have 3 nodes: #inject_0_rep0.&g, #inject_0_rep1.&g, #inject_0_rep2.&g
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}: {nodes}"

        # Each should have rep0, rep1, rep2 suffix to distinguish iterations
        rep0 = [n for n in nodes if "rep0" in n]
        rep1 = [n for n in nodes if "rep1" in n]
        rep2 = [n for n in nodes if "rep2" in n]
        assert len(rep0) == 1, f"Expected rep0 node in {nodes}"
        assert len(rep1) == 1, f"Expected rep1 node in {nodes}"
        assert len(rep2) == 1, f"Expected rep2 node in {nodes}"

    def test_variadic_with_multiple_statements_per_iteration(self):
        """Repetition block with multiple statements per iteration."""
        source = """
        @system pe=1, sm=1

        #loop *items |> {
            $( &item <| pass
               &item |> &output:L ),*
        }

        #loop &x, &y
        """
        graph = parse_lower_expand(source)

        # One invocation with 2 variadic args -> 2 iterations. Each iteration
        # holds one node definition (&item <| pass) and one edge statement,
        # so at least one node per iteration is expected.
        nodes = list(graph.nodes.keys())
        assert len(nodes) >= 2, f"Expected at least 2 nodes, got {len(nodes)}: {nodes}"

        # Should have rep0 and rep1 in names
        assert any("rep0" in n for n in nodes), f"Expected rep0 in {nodes}"
        assert any("rep1" in n for n in nodes), f"Expected rep1 in {nodes}"


class TestVariadicIndexVariable:
    """Test ${_idx} substitution in repetition blocks."""

    def test_idx_variable_expands_to_iteration_index(self):
        """${_idx} becomes 0, 1, 2 in successive iterations via token pasting.

        Tests that ParamRef with _idx parameter substitutes correctly during
        variadic expansion. The _idx value is set to the iteration index (0-based)
        and is available for token pasting concatenation.
        """
        # Construct macro body with a ParamRef containing _idx
        # Node name will be: &node_${_idx} -> ParamRef(param="_idx", prefix="&node_", suffix="")
        param_ref = ParamRef(param="_idx", prefix="&node_", suffix="")
        body_node = IRNode(
            name=param_ref,
            opcode=ArithOp.ADD,
            loc=SourceLoc(0, 0),
        )

        # Create repetition block with the node
        rep_body = IRGraph(
            nodes={"node_placeholder": body_node},
            edges=[],
            macro_defs=[],
            macro_calls=[],
        )

        rep_block = IRRepetitionBlock(
            body=rep_body,
            variadic_param="vals",
            loc=SourceLoc(0, 0),
        )

        # Create macro definition with variadic parameter
        macro_def = MacroDef(
            name="maker",
            params=(MacroParam(name="vals", variadic=True),),
            body=IRGraph(
                nodes={},
                edges=[],
                macro_defs=[],
                macro_calls=[],
            ),
            repetition_blocks=[rep_block],
            loc=SourceLoc(0, 0),
        )

        # Create macro call with 3 arguments
        macro_call = IRMacroCall(
            name="maker",
            positional_args=(42, 100, 200),
            named_args=(),
            loc=SourceLoc(0, 0),
        )

        # Create graph with macro definition and call
        graph = IRGraph(
            nodes={},
            edges=[],
            regions=[],
            data_defs=[],
            macro_defs=[macro_def],
            macro_calls=[macro_call],
        )

        # Expand the graph
        expanded = expand(graph)

        # After expansion, should have 3 nodes with names:
        # #maker_0_rep0.&node_0, #maker_0_rep1.&node_1, #maker_0_rep2.&node_2
        node_names = list(expanded.nodes.keys())
        assert len(node_names) == 3, (
            f"Expected 3 nodes, got {len(node_names)}: {node_names}"
        )

        # Verify that _idx was substituted correctly in node names
        # Each iteration should have node_0, node_1, node_2 respectively
        assert any("&node_0" in name for name in node_names), (
            f"Expected node with &node_0 (iteration 0), got {node_names}"
        )
        assert any("&node_1" in name for name in node_names), (
            f"Expected node with &node_1 (iteration 1), got {node_names}"
        )
        assert any("&node_2" in name for name in node_names), (
            f"Expected node with &node_2 (iteration 2), got {node_names}"
        )


class TestVariadicMixedParams:
    """Test variadic with non-variadic parameters."""

    def test_mixed_params_non_variadic_first(self):
        """Macro with dest, *sources: first param is non-variadic."""
        source = """
        @system pe=1, sm=1

        #route dest, *sources |> {
            $( &src <| pass ),*
        }

        #route &output, &in1, &in2
        """
        graph = parse_lower_expand(source)

        # Should have 2 nodes (one per source)
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 2, f"Expected 2 nodes, got {len(nodes)}: {nodes}"

    def test_mixed_params_three_args_two_non_variadic(self):
        """Macro with a, b, *rest: args 3+ go to rest."""
        source = """
        @system pe=1, sm=1

        #process a, b, *rest |> {
            $( &r <| pass ),*
        }

        #process &x, &y, &z1, &z2, &z3
        """
        graph = parse_lower_expand(source)

        # 3 args go to *rest -> 3 iterations
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}: {nodes}"


class TestVariadicEdgeCases:
    """Test edge cases: empty variadic, single arg, etc."""

    def test_empty_variadic_no_error(self):
        """Invoke variadic macro with zero args: no error, nothing expanded."""
        source = """
        @system pe=1, sm=1

        #optional *args |> {
            $( &x <| pass ),*
        }

        #optional
        """
        graph = parse_lower_expand(source)

        # No nodes should be created
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 0, (
            f"Expected 0 nodes for empty variadic, got {len(nodes)}: {nodes}"
        )

    def test_single_variadic_one_iteration(self):
        """Invoke with one variadic arg: one iteration."""
        source = """
        @system pe=1, sm=1

        #single *args |> {
            $( &item <| pass ),*
        }

        #single &only
        """
        graph = parse_lower_expand(source)

        # One iteration -> one node with rep0
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 1, f"Expected 1 node, got {len(nodes)}: {nodes}"
        assert any("rep0" in n for n in nodes), f"Expected rep0 in {nodes}"

        # Should NOT have rep1
        assert not any("rep1" in n for n in nodes), f"Should not have rep1 in {nodes}"


class TestVariadicGrammarValidation:
    """Test grammar validation: variadic must be last, etc."""

    def test_variadic_not_last_is_error(self):
        """Variadic parameter not last: parser/lower should reject."""
        source = """
        @system pe=1, sm=1

        #bad *args, b |> {
            $( &x <| pass ),*
        }
        """
        graph = parse_and_lower(source)

        # Lower pass should catch this error
        assert any(e.category == ErrorCategory.NAME for e in graph.errors), (
            f"Expected NAME error for variadic not last, got: {graph.errors}"
        )

    def test_multiple_variadic_is_error(self):
        """Multiple variadic parameters: parser/lower should reject."""
        source = """
        @system pe=1, sm=1

        #bad *a, *b |> {
            $( &x <| pass ),*
        }
        """
        graph = parse_and_lower(source)

        # Lower pass should catch this error
        assert any(e.category == ErrorCategory.NAME for e in graph.errors), (
            f"Expected NAME error for multiple variadic, got: {graph.errors}"
        )


class TestVariadicIntegration:
    """Integration tests with full pipeline."""

    def test_variadic_with_edges_between_iterations(self):
        """Macro body mixing a repetition block with a trailing edge statement.

        The `&item |> &output:L` edge sits outside the `$( ... ),*` group, so
        it is part of the macro body proper rather than repeated per argument.
        """
        source = """
        @system pe=1, sm=1

        #chain *items |> {
            $( &item <| pass ),*
            &item |> &output:L
        }

        #chain &a, &b, &c
        """
        graph = parse_lower_expand(source)

        # 3 nodes from repetition, plus potential edges
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}: {nodes}"

        # The non-repeated edge statement must survive expansion
        edges = graph.edges
        assert len(edges) > 0, f"Expected at least one edge, got {edges}"

    def test_variadic_can_be_invoked_multiple_times(self):
        """Same variadic macro invoked twice with different args."""
        source = """
        @system pe=1, sm=1

        #expand *items |> {
            $( &item <| pass ),*
        }

        #expand &a, &b
        #expand &x, &y, &z
        """
        graph = parse_lower_expand(source)

        # First invocation: 2 nodes
        # Second invocation: 3 nodes
        # Total: 5 nodes
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 5, f"Expected 5 nodes, got {len(nodes)}: {nodes}"

        # First invocation should have #expand_0_rep0, #expand_0_rep1
        # Second invocation should have #expand_1_rep0, #expand_1_rep1, #expand_1_rep2
        expand_0 = [n for n in nodes if "#expand_0" in n]
        expand_1 = [n for n in nodes if "#expand_1" in n]
        assert len(expand_0) == 2, (
            f"Expected 2 nodes from first invocation, got {len(expand_0)}"
        )
        assert len(expand_1) == 3, (
            f"Expected 3 nodes from second invocation, got {len(expand_1)}"
        )

    def test_variadic_nested_with_other_macros(self):
        """Variadic macro combined with non-variadic macros."""
        source = """
        @system pe=1, sm=1

        #simple |> {
            &fixed <| pass
        }

        #expand *items |> {
            $( &item <| pass ),*
        }

        #simple
        #expand &a, &b
        """
        graph = parse_lower_expand(source)

        # 1 from #simple + 2 from #expand = 3 nodes
        nodes = list(graph.nodes.keys())
        assert len(nodes) == 3, f"Expected 3 nodes, got {len(nodes)}: {nodes}"

        # Should have both macro invocations in names
        simple_nodes = [n for n in nodes if "#simple" in n]
        expand_nodes = [n for n in nodes if "#expand" in n]
        assert len(simple_nodes) == 1, (
            f"Expected 1 #simple node, got {len(simple_nodes)}: {nodes}"
        )
        assert len(expand_nodes) == 2, (
            f"Expected 2 #expand nodes, got {len(expand_nodes)}: {nodes}"
        )


class TestVariadicPositionalRet:
    """Test positional @ret wiring in variadic repetition blocks."""

    def test_bare_ret_maps_to_positional_outputs_by_iteration(self):
        """@ret in iteration N maps to Nth positional output at call site."""
        source = """
        @system pe=1, sm=1

        #fan *vals |> {
            $( &v <| pass
               &v |> @ret ),*
        }

        &a <| pass
        &b <| pass
        &c <| pass
        #fan &a, &b, &c |> &x, &y, &z
        &x <| pass
        &y <| pass
        &z <| pass
        """
        graph = parse_lower_expand(source)
        assert not graph.errors, f"Unexpected errors: {graph.errors}"

        # Check edges from expanded nodes to positional outputs
        ret_edges = [e for e in graph.edges if e.dest in ("&x", "&y", "&z")]
        dests = [e.dest for e in ret_edges]
        assert "&x" in dests, f"Expected &x in ret edge dests: {dests}"
        assert "&y" in dests, f"Expected &y in ret edge dests: {dests}"
        assert "&z" in dests, f"Expected &z in ret edge dests: {dests}"

    def test_positional_ret_with_two_outputs(self):
        """Two variadic iterations map to two positional outputs."""
        source = """
        @system pe=1, sm=1

        #pair *vals |> {
            $( &v <| pass
               &v |> @ret ),*
        }

        &a <| pass
        &b <| pass
        &left <| pass
        &right <| pass
        #pair &a, &b |> &left, &right
        """
        graph = parse_lower_expand(source)
        assert not graph.errors, f"Unexpected errors: {graph.errors}"

        ret_edges = [e for e in graph.edges if e.dest in ("&left", "&right")]
        assert len(ret_edges) == 2, (
            f"Expected 2 ret edges, got {len(ret_edges)}: {ret_edges}"
        )

    def test_fewer_outputs_than_iterations_errors(self):
        """More @ret iterations than positional outputs produces errors."""
        source = """
        @system pe=1, sm=1

        #too_many *vals |> {
            $( &v <| pass
               &v |> @ret ),*
        }

        &a <| pass
        &b <| pass
        &c <| pass
        &only_one <| pass
        #too_many &a, &b, &c |> &only_one
        """
        graph = parse_lower_expand(source)
        macro_errors = [e for e in graph.errors if e.category == ErrorCategory.MACRO]
        assert len(macro_errors) >= 1, (
            f"Expected error for unmatched @ret, got: {graph.errors}"
        )


class TestVariadicFullPipeline:
    """Full pipeline test: variadic macro through assemble and emulator."""

    def test_variadic_macro_assembles_and_runs(self):
        """Variadic macro with positional @ret wiring through full pipeline.

        Each iteration's @ret maps to the next positional output at the
        call site, so `#multi_const 3, 4 |> &sum:L, &sum:R` wires
        iteration 0 → &sum:L and iteration 1 → &sum:R.
        """
        source = """
        @system pe=1, sm=0

        #multi_const *vals |> {
            $( &c <| const, ${vals}
               &c |> @ret ),*
        }

        &sum <| add
        &out <| pass
        #multi_const 3, 4 |> &sum:L, &sum:R
        &sum |> &out:L
        """
        result = assemble(source)
        assert result is not None
        assert len(result.pe_configs) > 0

        env = simpy.Environment()
        # Named `system` (not `sys`) to avoid shadowing the stdlib module name.
        system = build_topology(env, result.pe_configs, result.sm_configs)
        for setup in result.setup_tokens:
            system.inject(setup)
        for seed in result.seed_tokens:
            system.inject(seed)
        env.run(until=500)

        all_values = []
        for pe in system.pes.values():
            all_values.extend(t.data for t in pe.output_log if hasattr(t, "data"))
        assert 7 in all_values, f"Expected 3+4=7 in outputs, got {all_values}"