OR-1 dataflow CPU sketch
1"""Tests for function call wiring in the expand pass.
2
Verifies acceptance criteria AC4.1–AC4.7, AC4.9 and AC4.10 (call syntax parsing, lowering, and wiring); AC4.8 has no dedicated test in this module.
4"""
5
6import pytest
7from asm import _get_parser
8from asm.expand import expand
9from asm.ir import IRRegion, RegionKind, CallSite
10from asm.errors import ErrorCategory
11from cm_inst import Port, RoutingOp
12from tests.pipeline import parse_and_lower
13
14
def parse_lower_expand(source: str):
    """Run the full front-end pipeline on *source*.

    Builds a fresh parser, parses and lowers the dfasm text into an IR graph,
    and returns whatever the expand pass produces from that graph.
    """
    lowered = parse_and_lower(_get_parser(), source)
    return expand(lowered)
20
21
22# AC4.1: Call syntax with ctx_override edges
def test_call_syntax_basic():
    """AC4.1: A call with one named argument is recorded as a raw call site.

    Only parsing and lowering run here (no expand), so the call is still in
    its unresolved ``raw_call_sites`` form.
    """
    source = """
    $add |> {
        &a <| pass
        &sum <| pass
    }

    &x <| const, 5
    $add a=&x |> @result
    """
    graph = parse_and_lower(_get_parser(), source)

    # Lowering must have captured exactly one call site.
    assert len(graph.raw_call_sites) == 1
    site = graph.raw_call_sites[0]
    assert site.func_name == "$add"

    # Each input arg is a (param_name, ref_dict) pair.
    assert len(site.input_args) == 1
    arg_name, arg_ref = site.input_args[0]
    assert arg_name == "a"
    assert isinstance(arg_ref, dict)
    assert arg_ref.get("name") == "&x"
47
48
def test_call_syntax_multiple_args():
    """AC4.1: Every named argument of a call is kept on the raw call site."""
    source = """
    $adder |> {
        &a <| pass
        &b <| pass
    }

    &x <| const, 3
    &y <| const, 4
    $adder a=&x, b=&y |> @out
    """
    graph = parse_and_lower(_get_parser(), source)

    assert len(graph.raw_call_sites) == 1
    site = graph.raw_call_sites[0]
    assert len(site.input_args) == 2

    # Only the parameter-name half of each (param_name, ref) pair matters here.
    seen_names = {arg[0] for arg in site.input_args}
    for expected in ("a", "b"):
        assert expected in seen_names
72
73
def test_call_syntax_positional_args():
    """AC4.1: A bare (positional) argument is stored without a parameter name."""
    source = """
    $func |> {
        &a <| pass
    }

    &x <| const, 1
    $func &x |> @result
    """
    graph = parse_and_lower(_get_parser(), source)

    assert len(graph.raw_call_sites) == 1
    site = graph.raw_call_sites[0]
    assert len(site.input_args) == 1

    # Positional arguments are stored as (None, ref_dict) pairs.
    only_arg = site.input_args[0]
    assert only_arg[0] is None
    assert only_arg[1]["name"] == "&x"
94
95
96# AC4.2: Synthetic @ret rendezvous node creation
def test_synthetic_ret_node_creation():
    """AC4.2: Expanding a call adds a synthetic ``$add.@ret`` PASS node."""
    source = """
    $add |> {
        &a <| pass
        &b <| pass
        &sum <| add
        &a |> &sum:L
        &b |> &sum:R
        &sum |> @ret
    }

    &x <| const, 3
    &y <| const, 4
    $add a=&x, b=&y |> @result
    """
    expanded = parse_lower_expand(source)

    # The rendezvous node is named after the function plus its @ret label.
    ret_name = "$add.@ret"
    assert ret_name in expanded.nodes
    assert expanded.nodes[ret_name].opcode == RoutingOp.PASS
120
121
def test_trampoline_node_creation():
    """AC4.2: Expanding a call adds a per-call-site return trampoline PASS node."""
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 5
    $add a=&x |> @result
    """
    expanded = parse_lower_expand(source)

    # Pick the first node carrying the trampoline prefix, if any.
    tramp_name = next(
        (name for name in expanded.nodes
         if name.startswith("$add.__ret_trampoline_")),
        None,
    )
    assert tramp_name is not None, "No trampoline node found"
    assert expanded.nodes[tramp_name].opcode == RoutingOp.PASS
147
148
149# AC4.3: Named returns with dual outputs
def test_named_returns_multiple():
    """AC4.3: Each ``@ret_<name>`` variant gets its own synthetic PASS node."""
    source = """
    $adder |> {
        &a <| pass
        &b <| pass
        &sum <| add
        &carry <| pass
        &a |> &sum:L
        &b |> &sum:R
        &sum |> @ret_sum
        &carry |> @ret_carry
    }

    &three <| const, 3
    &two <| const, 2
    $adder a=&three, b=&two |> sum=@s, carry=@c
    """
    expanded = parse_lower_expand(source)

    ret_names = ("$adder.@ret_sum", "$adder.@ret_carry")

    # First confirm both rendezvous nodes exist ...
    for ret_name in ret_names:
        assert ret_name in expanded.nodes

    # ... then that each one is a plain PASS node.
    for ret_name in ret_names:
        assert expanded.nodes[ret_name].opcode == RoutingOp.PASS
178
179
180# AC4.4: Named output wiring
def test_named_output_wiring():
    """AC4.4: ``sum=@dest`` wires the return trampoline to that destination.

    The trampoline -> ``@my_output`` edge must carry ``ctx_override``, as
    return-trampoline edges do.
    """
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret_sum
    }

    &x <| const, 5
    $add a=&x |> sum=@my_output
    """
    graph = parse_lower_expand(source)

    assert any(
        "__ret_trampoline_" in node_name for node_name in graph.nodes
    ), "No trampoline node found"

    # First edge from a trampoline into @my_output. edge.dest is guarded with
    # isinstance because it is not assumed to always be a string.
    tramp_to_output = next(
        (
            edge
            for edge in graph.edges
            if isinstance(edge.dest, str)
            and "my_output" in edge.dest
            and "trampoline" in str(edge.source)
        ),
        None,
    )
    assert tramp_to_output is not None, "Trampoline to output edge not found"
    # Truthiness check, matching how ctx_override is tested elsewhere in this file.
    assert tramp_to_output.ctx_override, (
        f"Trampoline to output edge lacks ctx_override: "
        f"{tramp_to_output.source} -> {tramp_to_output.dest}"
    )
215
216
217# AC4.5: free_ctx auto-insertion
def test_free_ctx_auto_insertion():
    """AC4.5: A FREE_FRAME node is auto-inserted on every return path.

    Checks that expansion creates a ``$add.__free_frame_*`` node with the
    FREE_FRAME opcode, and that a trampoline feeds it from its R output port.
    """
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 5
    $add a=&x |> @result
    """
    graph = parse_lower_expand(source)

    free_frame_name = next(
        (name for name in graph.nodes if name.startswith("$add.__free_frame_")),
        None,
    )
    assert free_frame_name is not None, "No free_frame node found"
    assert graph.nodes[free_frame_name].opcode == RoutingOp.FREE_FRAME

    # Stringify endpoints before substring matching: other tests in this file
    # guard edge.dest with isinstance()/str(), so it may not be a plain str.
    tramp_to_free = next(
        (
            edge
            for edge in graph.edges
            if "trampoline" in str(edge.source) and "free_frame" in str(edge.dest)
        ),
        None,
    )
    assert tramp_to_free is not None, "Trampoline to free_frame edge not found"
    # The trampoline's R output (dest_r) is the port wired to free_frame.
    assert tramp_to_free.source_port == Port.R
253
254
255# AC4.6: Multiple call sites get distinct contexts and trampolines
def test_multiple_call_sites():
    """AC4.6: Two calls to one function get distinct call ids and trampolines."""
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 3
    &y <| const, 4
    $add a=&x |> @r1
    $add a=&y |> @r2
    """
    graph = parse_lower_expand(source)

    # One CallSite per call, numbered in source order.
    assert len(graph.call_sites) == 2
    for expected_id in (0, 1):
        assert graph.call_sites[expected_id].call_id == expected_id

    # Each call site must have received its own return trampoline.
    tramp_count = sum(1 for name in graph.nodes if "__ret_trampoline_" in name)
    assert tramp_count == 2, f"Expected 2 trampolines, got {tramp_count}"
284
285
286# AC4.7: Cross-PE function calls
def test_cross_pe_function_calls():
    """AC4.7: Call from one PE to a function on another PE.

    Declares ``@system pe=2, sm=1`` and checks that the call still expands:
    the ``$add`` FUNCTION region exists, and the return path out of the callee
    has a ``ctx_override`` edge leaving a return trampoline.

    NOTE(review): actual PE placement is not asserted here.
    """
    source = """
    @system pe=2, sm=1

    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 5
    $add a=&x |> @result
    """
    graph = parse_lower_expand(source)

    func_region = next(
        (
            region
            for region in graph.regions
            if region.kind == RegionKind.FUNCTION and region.tag == "$add"
        ),
        None,
    )
    assert func_region is not None

    # Input edges into the callee use INHERIT routing and carry no
    # ctx_override (see test_input_edges_use_inherit_not_ctx_override).
    # The context switch back to the caller therefore has to show up on
    # an edge leaving a return trampoline.
    assert any(
        edge.ctx_override and "__ret_trampoline" in str(edge.source)
        for edge in graph.edges
    ), "Return trampoline edge with ctx_override not found"
321
322
# AC4.9: Named arg not matching any label produces a CALL-category error
def test_undefined_argument_label():
    """AC4.9: A named argument matching no label in the function is a CALL error.

    ``$add`` only defines ``&a``, so the ``b=&x`` argument must be reported
    in an error that mentions ``b``.
    """
    import re

    source = """
    $add |> {
        &a <| pass
    }

    &x <| const, 5
    $add b=&x |> @result
    """
    graph = parse_lower_expand(source)

    errors = graph.errors
    assert len(errors) > 0
    # Match "b" as a standalone word. A plain substring test ("b" in message)
    # would also succeed on ordinary words such as "label", making it vacuous.
    mentions_b = re.compile(r"\bb\b")
    assert any(
        err.category == ErrorCategory.CALL and mentions_b.search(err.message)
        for err in errors
    )
344
345
# AC4.10: Call to undefined function produces a CALL-category error
def test_undefined_function():
    """AC4.10: A call to a non-existent function produces a CALL error.

    The error message must say the function is undefined; the match is
    case-insensitive so capitalisation of the message does not matter.
    """
    source = """
    &x <| const, 5
    $nonexistent a=&x |> @result
    """
    graph = parse_lower_expand(source)

    errors = graph.errors
    assert len(errors) > 0
    assert any(
        err.category == ErrorCategory.CALL and "undefined" in err.message.lower()
        for err in errors
    )
363
364
def test_call_site_metadata():
    """Expanding a single call yields one fully populated CallSite record."""
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 5
    $add a=&x |> @result
    """
    graph = parse_lower_expand(source)

    assert len(graph.call_sites) == 1
    site = graph.call_sites[0]

    # Identity of the call ...
    assert site.func_name == "$add"
    assert site.call_id == 0

    # ... and the nodes the expander generated for it.
    assert site.trampoline_nodes
    assert site.free_frame_nodes
386
387
def test_input_edges_use_inherit_not_ctx_override():
    """Input edges from call site to function parameters use INHERIT, not CHANGE_TAG.

    In the frame-based model, cross-context routing for input edges is handled by
    the FrameDest's act_id (which differs between caller and function activation).
    Only return trampolines use ctx_override/CHANGE_TAG because they decode a
    packed FrameDest from EXTRACT_TAG.
    """
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 5
    $add a=&x |> @result
    """
    graph = parse_lower_expand(source)

    # &x may reach $add.&a through an inserted trampoline (&x -> trampoline ->
    # $add.&a). No edge ending at $add.&a may carry ctx_override, because
    # input routing uses INHERIT mode.
    for edge in graph.edges:
        if "$add.&a" in str(edge.dest):
            assert not edge.ctx_override, (
                f"Input edge {edge.source} -> {edge.dest} should NOT have "
                "ctx_override (frame-based routing uses INHERIT with act_id "
                "in FrameDest)"
            )

    # Only return-trampoline edges may carry ctx_override, and at least one must.
    ret_ctx_override = [edge for edge in graph.edges if edge.ctx_override]
    assert ret_ctx_override, "Return trampoline should have ctx_override"
    for edge in ret_ctx_override:
        # str() keeps this consistent with the dest check above.
        assert "__ret_trampoline" in str(edge.source), (
            f"Only return trampolines should have ctx_override, "
            f"got: {edge.source} -> {edge.dest}"
        )
425
426
def test_shared_function_body():
    """Two calls to ``$add`` still leave a single ``$add`` FUNCTION region.

    The region's body must hold the callee's ``$add.&a`` and ``$add.&sum``
    nodes. Only node membership is checked here; edges are not inspected.
    """
    source = """
    $add |> {
        &a <| pass
        &sum <| add
        &sum |> @ret
    }

    &x <| const, 3
    &y <| const, 4
    $add a=&x |> @r1
    $add a=&y |> @r2
    """
    graph = parse_lower_expand(source)

    assert len(graph.call_sites) == 2

    # Body nodes live inside the function region rather than at top level.
    add_region = next(
        (
            region
            for region in graph.regions
            if region.kind == RegionKind.FUNCTION and region.tag == "$add"
        ),
        None,
    )
    assert add_region is not None

    body_nodes = add_region.body.nodes
    for node_name in ("$add.&a", "$add.&sum"):
        assert node_name in body_nodes
459
460
if __name__ == "__main__":
    # Propagate pytest's exit status so running this file directly reports
    # failures to the shell / CI instead of always exiting 0.
    raise SystemExit(pytest.main([__file__, "-v"]))