# OR-1 dataflow CPU sketch
# (repository viewer header: at main, 462 lines, 12 kB)
"""Tests for function call wiring in the expand pass.

Verifies AC4.1 through AC4.10: call syntax parsing, lowering, and wiring.
"""

import pytest

from asm import _get_parser
from asm.expand import expand
from asm.ir import IRRegion, RegionKind, CallSite
from asm.errors import ErrorCategory
from cm_inst import Port, RoutingOp
from tests.pipeline import parse_and_lower


def parse_lower_expand(source: str):
    """Parse, lower, and expand dfasm source.

    Args:
        source: dfasm program text.

    Returns:
        Whatever ``expand`` returns for the lowered graph; the tests below
        read its ``nodes``, ``edges``, ``regions``, ``call_sites`` and
        ``errors`` attributes.
    """
    parser = _get_parser()
    graph = parse_and_lower(parser, source)
    return expand(graph)


def _find_function_region(graph, tag: str):
    """Return the first FUNCTION region of *graph* tagged *tag*, or None."""
    return next(
        (
            region
            for region in graph.regions
            if region.kind == RegionKind.FUNCTION and region.tag == tag
        ),
        None,
    )


def _first_node_with_prefix(graph, prefix: str):
    """Return the first node name in *graph* starting with *prefix*, or None."""
    return next((name for name in graph.nodes if name.startswith(prefix)), None)


# AC4.1: Call syntax with ctx_override edges
def test_call_syntax_basic():
    """AC4.1: Parse and lower function call with named argument."""
    parser = _get_parser()
    source = """
    $add |> {
        &a <| pass
        &sum <| pass
    }

    &x <| const, 5
    $add a=&x |> @result
    """

    graph = parse_and_lower(parser, source)

    # Lowering (before expand) records the call in raw_call_sites.
    assert len(graph.raw_call_sites) == 1
    call_site = graph.raw_call_sites[0]
    assert call_site.func_name == "$add"
    assert len(call_site.input_args) == 1
    # Input args are stored as (param_name, ref_dict)
    param_name, ref = call_site.input_args[0]
    assert param_name == "a"
    assert isinstance(ref, dict)
    assert ref.get("name") == "&x"


def test_call_syntax_multiple_args():
    """AC4.1: Call with multiple named arguments."""
    parser = _get_parser()
    source = """
    $adder |> {
        &a <| pass
        &b <| pass
    }

    &x <| const, 3
    &y <| const, 4
    $adder a=&x, b=&y |> @out
    """

    graph = parse_and_lower(parser, source)

    assert len(graph.raw_call_sites) == 1
    call_site = graph.raw_call_sites[0]
    assert len(call_site.input_args) == 2
    # Check args are present with correct parameter names
    param_names = [arg[0] for arg in call_site.input_args]
    assert "a" in param_names
    assert "b" in param_names


def test_call_syntax_positional_args():
    """AC4.1: Call with positional arguments."""
    parser = _get_parser()
    source = """
    $func |> {
        &a <| pass
    }

    &x <| const, 1
    $func &x |> @result
    """

    graph = parse_and_lower(parser, source)

    assert len(graph.raw_call_sites) == 1
    call_site = graph.raw_call_sites[0]
    # Positional args are stored as (None, ref_dict) tuples
    assert len(call_site.input_args) == 1
    param_name, ref = call_site.input_args[0]
    assert param_name is None  # param_name is None for positional
    assert ref["name"] == "&x"  # source ref is &x


# AC4.2: Synthetic @ret rendezvous node creation
def test_synthetic_ret_node_creation():
    """AC4.2: Expand creates synthetic $func.@ret pass node."""
    source = """
    $add |> {
        &a <| pass
        &b <| pass
        &sum <| add
        &a |> &sum:L
        &b |> &sum:R
        &sum |> @ret
    }

    &x <| const, 3
    &y <| const, 4
    $add a=&x, b=&y |> @result
    """

    graph = parse_lower_expand(source)

    # Check that synthetic node was created
    assert "$add.@ret" in graph.nodes
    synthetic_ret = graph.nodes["$add.@ret"]
    assert synthetic_ret.opcode == RoutingOp.PASS


def test_trampoline_node_creation():
    """AC4.2: Expand creates per-call-site trampoline pass node."""
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 5
    $add a=&x |> @result
    """

    graph = parse_lower_expand(source)

    # The first trampoline found must exist and be a PASS node.
    tramp_name = _first_node_with_prefix(graph, "$add.__ret_trampoline_")
    assert tramp_name is not None, "No trampoline node found"
    assert graph.nodes[tramp_name].opcode == RoutingOp.PASS


# AC4.3: Named returns with dual outputs
def test_named_returns_multiple():
    """AC4.3: Multiple named @ret_name variants create separate synthetic nodes."""
    source = """
    $adder |> {
        &a <| pass
        &b <| pass
        &sum <| add
        &carry <| pass
        &a |> &sum:L
        &b |> &sum:R
        &sum |> @ret_sum
        &carry |> @ret_carry
    }

    &three <| const, 3
    &two <| const, 2
    $adder a=&three, b=&two |> sum=@s, carry=@c
    """

    graph = parse_lower_expand(source)

    # Check for both synthetic nodes
    assert "$adder.@ret_sum" in graph.nodes
    assert "$adder.@ret_carry" in graph.nodes

    # Both should be pass nodes
    assert graph.nodes["$adder.@ret_sum"].opcode == RoutingOp.PASS
    assert graph.nodes["$adder.@ret_carry"].opcode == RoutingOp.PASS


# AC4.4: Named output wiring
def test_named_output_wiring():
    """AC4.4: Call with sum=@dest wires trampoline to specified destination."""
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret_sum
    }

    &x <| const, 5
    $add a=&x |> sum=@my_output
    """

    graph = parse_lower_expand(source)

    # Verify trampoline exists
    assert any(
        "__ret_trampoline_" in node_name for node_name in graph.nodes
    ), "No trampoline node found"

    # Find the first edge from a trampoline to @my_output; it must carry
    # ctx_override.
    output_edge = next(
        (
            edge
            for edge in graph.edges
            if isinstance(edge.dest, str)
            and "my_output" in edge.dest
            and "trampoline" in str(edge.source)
        ),
        None,
    )
    assert output_edge is not None, "Trampoline to output edge not found"
    assert output_edge.ctx_override


# AC4.5: free_ctx auto-insertion
def test_free_ctx_auto_insertion():
    """AC4.5: free_ctx node auto-inserted on every return path.

    The inserted node is looked up by its ``$add.__free_frame_`` name prefix
    and must carry the FREE_FRAME opcode.
    """
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 5
    $add a=&x |> @result
    """

    graph = parse_lower_expand(source)

    # Check for free_frame node
    free_frame_name = _first_node_with_prefix(graph, "$add.__free_frame_")
    assert free_frame_name is not None, "No free_frame node found"
    assert graph.nodes[free_frame_name].opcode == RoutingOp.FREE_FRAME

    # Check that trampoline's dest_r wires to free_frame.  Endpoints are
    # passed through str() so the substring test behaves the same way as in
    # the other edge checks in this module.
    tramp_to_free = next(
        (
            edge
            for edge in graph.edges
            if "trampoline" in str(edge.source) and "free_frame" in str(edge.dest)
        ),
        None,
    )
    assert tramp_to_free is not None, "Trampoline to free_frame edge not found"
    assert tramp_to_free.source_port == Port.R  # Output from trampoline R port


# AC4.6: Multiple call sites get distinct contexts and trampolines
def test_multiple_call_sites():
    """AC4.6: Two calls to same function get separate trampolines and ctx slots."""
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 3
    &y <| const, 4
    $add a=&x |> @r1
    $add a=&y |> @r2
    """

    graph = parse_lower_expand(source)

    # Check that we have 2 CallSite entries
    assert len(graph.call_sites) == 2
    assert graph.call_sites[0].call_id == 0
    assert graph.call_sites[1].call_id == 1

    # Check that we have separate trampolines
    trampoline_names = [n for n in graph.nodes if "__ret_trampoline_" in n]
    assert len(trampoline_names) == 2, (
        f"Expected 2 trampolines, got {len(trampoline_names)}"
    )


# AC4.7: Cross-PE function calls
def test_cross_pe_function_calls():
    """AC4.7: Call from one PE to function on another PE.

    The program declares ``@system pe=2, sm=1``.  The assertions only check
    that the function region exists and that at least one edge carries
    ctx_override.
    """
    source = """
    @system pe=2, sm=1

    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 5
    $add a=&x |> @result
    """

    graph = parse_lower_expand(source)

    # Verify function region exists
    assert _find_function_region(graph, "$add") is not None

    # At least one edge must carry ctx_override.  This does not require the
    # edge to be an input edge: test_input_edges_use_inherit_not_ctx_override
    # asserts that only return-trampoline edges carry it.
    assert any(
        edge.ctx_override for edge in graph.edges
    ), "No edge with ctx_override found"


# AC4.9: Named arg not matching any label produces a CALL-category error
def test_undefined_argument_label():
    """AC4.9: Call with argument that doesn't match any label in function."""
    source = """
    $add |> {
        &a <| pass
    }

    &x <| const, 5
    $add b=&x |> @result
    """

    graph = parse_lower_expand(source)

    # Should have an error about argument 'b' not matching
    errors = graph.errors
    assert len(errors) > 0
    assert any(
        err.category == ErrorCategory.CALL and "b" in err.message
        for err in errors
    )


# AC4.10: Call to undefined function produces a CALL-category error
def test_undefined_function():
    """AC4.10: Call to non-existent function produces a CALL-category error."""
    source = """
    &x <| const, 5
    $nonexistent a=&x |> @result
    """

    graph = parse_lower_expand(source)

    # Should have an error about undefined function
    errors = graph.errors
    assert len(errors) > 0
    assert any(
        err.category == ErrorCategory.CALL and "undefined" in err.message
        for err in errors
    )


def test_call_site_metadata():
    """CallSite metadata correctly populated."""
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 5
    $add a=&x |> @result
    """

    graph = parse_lower_expand(source)

    assert len(graph.call_sites) == 1
    call_site = graph.call_sites[0]
    assert call_site.func_name == "$add"
    assert call_site.call_id == 0
    assert len(call_site.trampoline_nodes) > 0
    assert len(call_site.free_frame_nodes) > 0


def test_input_edges_use_inherit_not_ctx_override():
    """Input edges from call site to function parameters use INHERIT, not CHANGE_TAG.

    In the frame-based model, cross-context routing for input edges is handled by
    the FrameDest's act_id (which differs between caller and function activation).
    Only return trampolines use ctx_override/CHANGE_TAG because they decode a
    packed FrameDest from EXTRACT_TAG.
    """
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 5
    $add a=&x |> @result
    """

    graph = parse_lower_expand(source)

    # Since &x has const=5, a trampoline is inserted: &x -> trampoline -> $add.&a
    # Neither edge should have ctx_override — input routing uses INHERIT mode
    for edge in graph.edges:
        if "$add.&a" in str(edge.dest):
            assert not edge.ctx_override, (
                "Input edge to $add.&a should NOT have ctx_override "
                "(frame-based routing uses INHERIT with act_id in FrameDest)"
            )

    # Only return trampoline edges should have ctx_override
    ret_ctx_override = [e for e in graph.edges if e.ctx_override]
    assert len(ret_ctx_override) > 0, "Return trampoline should have ctx_override"
    for e in ret_ctx_override:
        assert "__ret_trampoline" in str(e.source), (
            f"Only return trampolines should have ctx_override, got: {e.source} -> {e.dest}"
        )


def test_shared_function_body():
    """Function body nodes and edges are shared across multiple call sites."""
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 3
    &y <| const, 4
    $add a=&x |> @r1
    $add a=&y |> @r2
    """

    graph = parse_lower_expand(source)

    # Verify that we have 2 call sites
    assert len(graph.call_sites) == 2

    # Function body nodes are in the function region, not top-level
    # Just verify the structure is correct
    func_region = _find_function_region(graph, "$add")
    assert func_region is not None
    # Body should have &a and &sum nodes
    assert "$add.&a" in func_region.body.nodes
    assert "$add.&sum" in func_region.body.nodes


if __name__ == "__main__":
    pytest.main([__file__, "-v"])