OR-1 dataflow CPU sketch
1"""Tests for Enhancement 3: @ret wiring for macros (macro-enh.E3.*).
2
3Tests verify:
4- E3.1: Grammar accepts |> output list on macro_call_stmt
5- E3.2: Lower stores output_dests on IRMacroCall
6- E3.3: Expand rewrites @ret edges to concrete destinations (positional)
7- E3.4: Expand rewrites @ret_name edges to concrete destinations (named)
8- E3.5: Unmatched @ret marker produces MACRO error
9- E3.6: Macro with @ret but no |> at call site produces MACRO error
10- E3.7: Multiple @ret markers with mixed positional/named outputs
11- E3.8: @ret in nested macro invocation
12"""
13
import functools
from pathlib import Path

from lark import Lark

from asm.errors import ErrorCategory
from asm.expand import expand
from asm.ir import IRMacroCall
from asm.lower import lower
22
23
@functools.lru_cache(maxsize=1)
def _get_parser():
    """Return the dfasm Earley parser, built once and shared by all tests.

    Constructing a ``Lark`` instance compiles the whole grammar. Doing that on
    every ``parse_and_lower`` call repeated the work for each test, so the
    instance is cached after the first call. The same ``Lark`` object can be
    used for any number of ``parse`` calls.

    The grammar is read from ``dfasm.lark`` in the parent of this file's
    directory.
    """
    grammar_path = Path(__file__).parent.parent / "dfasm.lark"
    return Lark(
        # Explicit encoding so the result doesn't depend on the host locale.
        grammar_path.read_text(encoding="utf-8"),
        parser="earley",
        # Attach line/column metadata to tree nodes for the lowering pass.
        propagate_positions=True,
    )
31
32
def parse_and_lower(source: str):
    """Parse dfasm *source* with the shared parser and lower the tree to IR.

    Returns whatever ``lower`` produces for the parse tree (the tests treat it
    as a graph with ``errors`` and ``macro_calls``).
    """
    return lower(_get_parser().parse(source))
37
38
def parse_lower_expand(source: str):
    """Parse and lower *source*, then run macro expansion on the result.

    Returns the graph produced by ``expand``.
    """
    return expand(parse_and_lower(source))
42
43
class TestE31_GrammarAcceptsOutputList:
    """E3.1: Grammar accepts |> output list on macro_call_stmt.

    Note: #macro |> &dest (no args, single positional output) is ambiguous
    with plain_edge syntax and parses as plain_edge. Disambiguation requires
    either named outputs (name=&dest) or at least one argument before |>.

    Every test asserts that exactly one macro call was recorded. That shows
    the statement was lowered as a macro call, not some other statement form.
    """

    def test_macro_call_with_named_output_no_args(self):
        """#macro |> name=&dest parses as macro_call_stmt (named output disambiguates)."""
        source = """
        @system pe=1, sm=1
        #simple |> {
            &g <| pass
            &g |> @ret_out
        }
        &sink <| add
        #simple |> out=&sink
        """
        graph = parse_and_lower(source)
        assert not graph.errors
        assert len(graph.macro_calls) == 1

    def test_macro_call_with_named_outputs(self):
        """#macro |> name=&dest, name2=&dest2 parses as a single macro call."""
        source = """
        @system pe=1, sm=1
        #dual |> {
            &g <| pass
            &g |> @ret_body
            &g |> @ret_exit:R
        }
        &body_sink <| add
        &exit_sink <| add
        #dual |> body=&body_sink, exit=&exit_sink
        """
        graph = parse_and_lower(source)
        assert not graph.errors
        # Without this check the test passes even if the call is not
        # recorded as a macro call.
        assert len(graph.macro_calls) == 1

    def test_macro_call_with_args_and_positional_output(self):
        """#macro arg |> &dest parses as macro_call_stmt (args disambiguate)."""
        source = """
        @system pe=1, sm=1
        #with_arg val |> {
            &g <| const, ${val}
            &g |> @ret
        }
        &sink <| add
        #with_arg 42 |> &sink
        """
        graph = parse_and_lower(source)
        assert not graph.errors
        assert len(graph.macro_calls) == 1
97
98
class TestE32_LowerStoresOutputDests:
    """E3.2: Lower stores output_dests on IRMacroCall."""

    def test_output_dests_populated_with_arg(self):
        """IRMacroCall.output_dests holds the one positional destination (arg disambiguates)."""
        source = """
        @system pe=1, sm=1
        #simple val |> {
            &g <| const, ${val}
            &g |> @ret
        }
        &sink <| add
        #simple 1 |> &sink
        """
        graph = parse_and_lower(source)
        assert not graph.errors
        assert len(graph.macro_calls) == 1
        call = graph.macro_calls[0]
        assert isinstance(call, IRMacroCall)
        # Exactly one destination follows |>, so exactly one must be stored.
        assert len(call.output_dests) == 1

    def test_named_output_dests(self):
        """Named outputs stored as {"name": ..., "ref": ...} dicts."""
        source = """
        @system pe=1, sm=1
        #dual |> {
            &g <| pass
        }
        &sink <| add
        #dual |> body=&sink
        """
        graph = parse_and_lower(source)
        assert not graph.errors
        # Check the count first so a missing call fails clearly, not with IndexError.
        assert len(graph.macro_calls) == 1
        call = graph.macro_calls[0]
        assert isinstance(call, IRMacroCall)
        assert len(call.output_dests) == 1
        output = call.output_dests[0]
        assert isinstance(output, dict)
        assert output.get("name") == "body"
        assert "ref" in output
136
137
class TestE33_ExpandRewritesPositionalRet:
    """E3.3: Expand rewrites @ret edges to concrete destinations (positional).

    These tests look for edges whose source contains the expanded node name,
    i.e. the macro name plus an invocation index such as ``#simple_0.&g``.
    """

    def test_bare_ret_rewritten(self):
        """@ret in macro body becomes concrete &sink after expansion."""
        source = """
        @system pe=1, sm=1
        #simple val |> {
            &g <| const, ${val}
            &g |> @ret
        }
        &sink <| add
        #simple 1 |> &sink
        """
        graph = parse_lower_expand(source)
        assert not graph.errors
        # Find the edge from the expanded &g to &sink.
        ret_edges = [
            e for e in graph.edges
            if e.dest == "&sink" and "#simple_0.&g" in e.source
        ]
        assert len(ret_edges) == 1

    def test_bare_ret_with_named_output(self):
        """Named @ret_out marker (macro takes no args) resolves to out=&sink.

        Despite the test name, the body uses @ret_out rather than bare @ret.
        Without args, the positional form ``#simple |> &sink`` would parse as
        plain_edge, so a named output is used to reach the macro-call path.
        """
        source = """
        @system pe=1, sm=1
        #simple |> {
            &g <| pass
            &g |> @ret_out
        }
        &sink <| add
        #simple |> out=&sink
        """
        graph = parse_lower_expand(source)
        assert not graph.errors
        ret_edges = [
            e for e in graph.edges
            if e.dest == "&sink" and "#simple_0.&g" in e.source
        ]
        assert len(ret_edges) == 1
179
180
class TestE34_ExpandRewritesNamedRet:
    """E3.4: Expand rewrites @ret_name edges to concrete destinations (named)."""

    def test_named_ret_body(self):
        """@ret_body in macro resolves to body=&body_sink output."""
        source = """
        @system pe=1, sm=1
        #dual |> {
            &g <| pass
            &g |> @ret_body
            &g |> @ret_exit:R
        }
        &body_sink <| add
        &exit_sink <| add
        #dual |> body=&body_sink, exit=&exit_sink
        """
        graph = parse_lower_expand(source)
        assert not graph.errors
        # Both markers come from the same node, so collect all of its destinations.
        g_edges = [e for e in graph.edges if "#dual_0.&g" in e.source]
        dests = {e.dest for e in g_edges}
        assert "&body_sink" in dests
        assert "&exit_sink" in dests

    def test_mixed_named_outputs(self):
        """Each @ret_name marker resolves to its own matching named output."""
        source = """
        @system pe=1, sm=1
        #triple |> {
            &a <| pass
            &b <| pass
            &c <| pass
            &a |> @ret_x
            &b |> @ret_y
            &c |> @ret_z
        }
        &x_sink <| add
        &y_sink <| add
        &z_sink <| add
        #triple |> x=&x_sink, y=&y_sink, z=&z_sink
        """
        graph = parse_lower_expand(source)
        assert not graph.errors
        # Check the pairing, not just that every sink is reached somewhere.
        # Otherwise a swap such as &a -> &y_sink would go unnoticed.
        expected_wiring = {
            "#triple_0.&a": "&x_sink",
            "#triple_0.&b": "&y_sink",
            "#triple_0.&c": "&z_sink",
        }
        for src, dest in expected_wiring.items():
            dests = {e.dest for e in graph.edges if src in e.source}
            assert dest in dests, f"{src} not wired to {dest}; got {dests}"
228
229
class TestE35_UnmatchedRetError:
    """E3.5: Unmatched @ret marker produces MACRO error.

    Assertions search all MACRO errors instead of only the first one. The
    tests check that the relevant error exists, not the order errors are
    emitted in.
    """

    def test_unmatched_named_ret(self):
        """@ret_missing has no matching named output -> MACRO error."""
        source = """
        @system pe=1, sm=1
        #bad |> {
            &g <| pass
            &g |> @ret_missing
        }
        &sink <| add
        #bad |> wrong_name=&sink
        """
        graph = parse_lower_expand(source)
        macro_errors = [e for e in graph.errors if e.category == ErrorCategory.MACRO]
        messages = [e.message for e in macro_errors]
        assert macro_errors, "expected at least one MACRO error"
        # Other MACRO errors might precede this one, e.g. one about the unused
        # wrong_name output if the expander reports it.
        assert any("@ret_missing" in m for m in messages), messages

    def test_extra_ret_marker(self):
        """Macro has @ret_body and @ret_exit but call only provides body output."""
        source = """
        @system pe=1, sm=1
        #dual |> {
            &a <| pass
            &b <| pass
            &a |> @ret_body
            &b |> @ret_exit
        }
        &sink <| add
        #dual |> body=&sink
        """
        graph = parse_lower_expand(source)
        macro_errors = [e for e in graph.errors if e.category == ErrorCategory.MACRO]
        messages = [e.message for e in macro_errors]
        assert macro_errors, "expected at least one MACRO error"
        assert any("@ret_exit" in m for m in messages), messages
266
267
class TestE36_NoOutputWiringError:
    """E3.6: @ret markers without matching call-site outputs produce MACRO errors.

    Two cases are covered:
    - the call has no |> output list at all;
    - the body uses a bare @ret but the call site supplies only named outputs.
    """

    def test_ret_without_output_wiring(self):
        """Macro body uses @ret but call has no |> -> error."""
        source = """
        @system pe=1, sm=1
        #needs_output |> {
            &g <| pass
            &g |> @ret
        }
        #needs_output
        """
        graph = parse_lower_expand(source)
        macro_errors = [e for e in graph.errors if e.category == ErrorCategory.MACRO]
        messages = [e.message for e in macro_errors]
        assert macro_errors, "expected at least one MACRO error"
        # Order-independent: any MACRO error mentioning outputs or @ret counts.
        assert any("output" in m.lower() or "@ret" in m for m in messages), messages

    def test_bare_ret_with_only_named_outputs(self):
        """Bare @ret with only named outputs at call site -> MACRO error."""
        source = """
        @system pe=1, sm=1
        #test val |> {
            &g <| const, ${val}
            &g |> @ret
        }
        &sink <| add
        #test 1 |> out=&sink
        """
        graph = parse_lower_expand(source)
        macro_errors = [e for e in graph.errors if e.category == ErrorCategory.MACRO]
        messages = [e.message for e in macro_errors]
        assert macro_errors, "expected at least one MACRO error"
        assert any("@ret" in m for m in messages), messages
302
303
class TestE37_MultipleRetMarkers:
    """E3.7: Multiple @ret markers with mixed positional/named outputs.

    NOTE(review): the test below uses only named markers and named outputs.
    A call mixing a bare @ret (positional) with @ret_name is not covered.
    """

    def test_named_and_positional_ret(self):
        """Macro with @ret_main and @ret_extra resolves each to its own named output.

        Despite the test name, only named outputs are exercised.
        """
        source = """
        @system pe=1, sm=1
        #mixed |> {
            &a <| pass
            &b <| pass
            &a |> @ret_main
            &b |> @ret_extra
        }
        &main_sink <| add
        &extra_sink <| add
        #mixed |> main=&main_sink, extra=&extra_sink
        """
        graph = parse_lower_expand(source)
        assert not graph.errors
        # Verify each marker went to its own sink, not merely that both sinks
        # appear somewhere in the graph.
        expected_wiring = {
            "#mixed_0.&a": "&main_sink",
            "#mixed_0.&b": "&extra_sink",
        }
        for src, dest in expected_wiring.items():
            dests = {e.dest for e in graph.edges if src in e.source}
            assert dest in dests, f"{src} not wired to {dest}; got {dests}"
326
327
class TestE38_NestedMacroRet:
    """E3.8: @ret in nested macro invocation."""

    def test_nested_macro_ret_resolves(self):
        """Inner macro's @ret resolves independently from outer macro's @ret."""
        source = """
        @system pe=1, sm=1
        #inner val |> {
            &i <| const, ${val}
            &i |> @ret
        }
        #outer val |> {
            &o <| pass
            #inner ${val} |> &o
            &o |> @ret_out
        }
        &final_sink <| add
        #outer 1 |> out=&final_sink
        """
        graph = parse_lower_expand(source)
        assert not graph.errors
        # The outer @ret_out should resolve to &final_sink, from the expanded &o.
        outer_ret_edges = [
            e for e in graph.edges
            if e.dest == "&final_sink" and "#outer_0.&o" in e.source
        ]
        assert len(outer_ret_edges) >= 1
        # The inner @ret should resolve to the outer instance's &o node, not to
        # the outer call's own output.
        inner_ret_edges = [e for e in graph.edges if "#outer_0.&o" in e.dest]
        assert inner_ret_edges, [(e.source, e.dest) for e in graph.edges]

    def test_inner_ret_failure_does_not_cascade_to_outer(self):
        """Inner macro @ret failure should not cause spurious outer errors.

        Uses named output (@ret_out / out=&sink) to disambiguate from plain_edge.
        """
        source = """
        @system pe=1, sm=1
        #inner val |> {
            &i <| const, ${val}
            &i |> @ret
        }
        #outer |> {
            &o <| pass
            #inner 1
            &o |> @ret_out
        }
        &sink <| add
        #outer |> out=&sink
        """
        graph = parse_lower_expand(source)
        macro_errors = [e for e in graph.errors if e.category == ErrorCategory.MACRO]
        # Only the inner call (which has no |> wiring) should report an error.
        assert len(macro_errors) == 1, (
            f"Expected 1 error (inner #inner), got {len(macro_errors)}: "
            f"{[e.message for e in macro_errors]}"
        )
        assert "#inner" in macro_errors[0].message
        # The outer @ret_out must still be wired despite the inner failure.
        sink_edges = [e for e in graph.edges if e.dest == "&sink"]
        assert len(sink_edges) >= 1