# OR-1 dataflow CPU sketch
1"""Convert IRGraph to JSON-serialisable structure for the frontend.
2
3Produces a flat graph representation with all nodes, edges, regions,
4errors, and metadata needed for both logical and physical views.
5Synthesizes SM nodes and edges from MemOp instructions and data definitions.
6"""
7
8from __future__ import annotations
9
10from typing import Any
11
12from cm_inst import FrameDest, MemOp
13from asm.ir import (
14 IRNode, IREdge, IRGraph, IRRegion, RegionKind,
15 SourceLoc, ResolvedDest,
16 collect_all_nodes_and_edges, collect_all_data_defs,
17)
18from asm.errors import AssemblyError
19from asm.opcodes import OP_TO_MNEMONIC
20from dfgraph.pipeline import PipelineResult
21from dfgraph.categories import OpcodeCategory, CATEGORY_COLOURS
22
23
# Prefix used to build the IDs of synthesised SM (structure-memory) nodes,
# e.g. "__sm_3"; keeps them distinct from names of real IR nodes.
SM_NODE_PREFIX = "__sm_"
25
26
27def _serialise_loc(loc: SourceLoc) -> dict[str, Any]:
28 return {
29 "line": loc.line,
30 "column": loc.column,
31 "end_line": loc.end_line,
32 "end_column": loc.end_column,
33 }
34
35
36def _serialise_frame_dest(dest: FrameDest) -> dict[str, Any]:
37 return {
38 "target_pe": dest.target_pe,
39 "offset": dest.offset,
40 "act_id": dest.act_id,
41 "port": dest.port.name,
42 "token_kind": dest.token_kind.name,
43 }
44
45
def _serialise_node(node: IRNode, error_node_names: set[str]) -> dict[str, Any]:
    """Serialise one IR node, tagging it with its category colour and error state."""
    # Imported here (as in the module's original layout) rather than at the top.
    from dfgraph.categories import categorise

    cat = categorise(node.opcode)
    return {
        "id": node.name,
        "opcode": OP_TO_MNEMONIC[node.opcode],
        "category": cat.value,
        "colour": CATEGORY_COLOURS[cat],
        "const": node.const,
        "pe": node.pe,
        "iram_offset": node.iram_offset,
        "act_id": node.act_id,
        "has_error": node.name in error_node_names,
        "loc": _serialise_loc(node.loc),
    }
63
64
65def _serialise_edge(edge: IREdge, all_nodes: dict[str, IRNode],
66 error_lines: set[int]) -> dict[str, Any]:
67 result: dict[str, Any] = {
68 "source": edge.source,
69 "target": edge.dest,
70 "port": edge.port.name,
71 "source_port": edge.source_port.name if edge.source_port else None,
72 "has_error": edge.loc.line in error_lines,
73 }
74
75 source_node = all_nodes.get(edge.source)
76 if source_node:
77 if (isinstance(source_node.dest_l, ResolvedDest)
78 and source_node.dest_l.name == edge.dest
79 and source_node.dest_l.frame_dest is not None):
80 result["frame_dest"] = _serialise_frame_dest(source_node.dest_l.frame_dest)
81 elif (isinstance(source_node.dest_r, ResolvedDest)
82 and source_node.dest_r.name == edge.dest
83 and source_node.dest_r.frame_dest is not None):
84 result["frame_dest"] = _serialise_frame_dest(source_node.dest_r.frame_dest)
85
86 return result
87
88
89def _serialise_error(error: AssemblyError) -> dict[str, Any]:
90 return {
91 "line": error.loc.line,
92 "column": error.loc.column,
93 "category": error.category.value,
94 "message": error.message,
95 "suggestions": error.suggestions,
96 }
97
98
99def _serialise_region(region: IRRegion) -> dict[str, Any]:
100 node_ids = list(region.body.nodes.keys())
101 for sub_region in region.body.regions:
102 node_ids.extend(sub_region.body.nodes.keys())
103
104 return {
105 "tag": region.tag,
106 "kind": region.kind.value,
107 "node_ids": node_ids,
108 }
109
110
111def _collect_error_node_names(errors: list[AssemblyError],
112 all_nodes: dict[str, IRNode]) -> set[str]:
113 error_lines: set[int] = {e.loc.line for e in errors}
114 return {
115 name for name, node in all_nodes.items()
116 if node.loc.line in error_lines
117 }
118
119
def _collect_referenced_sm_ids(
    all_nodes: dict[str, IRNode],
    graph: IRGraph,
) -> set[int]:
    """Collect SM IDs referenced by MemOp nodes or data definitions."""
    referenced = {
        node.sm_id
        for node in all_nodes.values()
        if isinstance(node.opcode, MemOp) and node.sm_id is not None
    }
    referenced.update(
        data_def.sm_id
        for data_def in collect_all_data_defs(graph)
        if data_def.sm_id is not None
    )
    return referenced
133
134
def _build_sm_label(
    sm_id: int,
    all_nodes: dict[str, IRNode],
    graph: IRGraph,
) -> str:
    """Build a multi-line label for an SM node listing its referenced cells.

    The first line names the SM; each subsequent line shows one cell address
    with the memory-op mnemonics and/or initial value that touch it.
    """
    cell_ops: dict[int, list[str]] = {}

    # Memory operations addressing a concrete cell (via `const`) of this SM.
    for node in all_nodes.values():
        targets_this_sm = isinstance(node.opcode, MemOp) and node.sm_id == sm_id
        if targets_this_sm and node.const is not None:
            cell_ops.setdefault(node.const, []).append(OP_TO_MNEMONIC[node.opcode])

    # Static data definitions that initialise a cell of this SM.
    for data_def in collect_all_data_defs(graph):
        if data_def.sm_id == sm_id and data_def.cell_addr is not None:
            cell_ops.setdefault(data_def.cell_addr, []).append(f"init={data_def.value}")

    label_lines = [f"SM {sm_id}"]
    label_lines.extend(
        f"[{addr}] {', '.join(cell_ops[addr])}" for addr in sorted(cell_ops)
    )
    return "\n".join(label_lines)
162
163
def _synthesize_sm_nodes(
    sm_ids: set[int],
    all_nodes: dict[str, IRNode],
    graph: IRGraph,
) -> list[dict[str, Any]]:
    """Create synthetic graph nodes for each referenced SM instance."""
    category = OpcodeCategory.STRUCTURE_MEMORY
    colour = CATEGORY_COLOURS[category]

    nodes: list[dict[str, Any]] = []
    for sm_id in sorted(sm_ids):
        nodes.append({
            "id": f"{SM_NODE_PREFIX}{sm_id}",
            "opcode": "sm",
            "label": _build_sm_label(sm_id, all_nodes, graph),
            "category": category.value,
            "colour": colour,
            "const": None,
            "pe": None,
            "iram_offset": None,
            "act_id": None,
            "has_error": False,
            # Synthetic nodes carry a dummy location: there is no source line.
            "loc": {"line": 0, "column": 0, "end_line": None, "end_column": None},
            "sm_id": sm_id,
            "synthetic": True,
        })
    return nodes
189
190
191def _synthesize_sm_edges(
192 all_nodes: dict[str, IRNode],
193) -> list[dict[str, Any]]:
194 """Create synthetic edges between MemOp nodes and their target SM nodes.
195
196 Produces:
197 - Request edge: MemOp node → SM node (the memory operation request)
198 - Return edge: SM node → destination node (if a return route exists)
199 """
200 edges: list[dict[str, Any]] = []
201 for node in all_nodes.values():
202 if not isinstance(node.opcode, MemOp) or node.sm_id is None:
203 continue
204
205 sm_node_id = f"{SM_NODE_PREFIX}{node.sm_id}"
206
207 # Request edge: instruction → SM
208 edges.append({
209 "source": node.name,
210 "target": sm_node_id,
211 "port": "REQ",
212 "source_port": None,
213 "has_error": False,
214 "synthetic": True,
215 })
216
217 # Return edge: SM → requesting node (data flows back to the reader)
218 if isinstance(node.dest_l, ResolvedDest):
219 edges.append({
220 "source": sm_node_id,
221 "target": node.name,
222 "port": "RET",
223 "source_port": None,
224 "has_error": False,
225 "synthetic": True,
226 })
227
228 return edges
229
230
def graph_to_json(result: PipelineResult) -> dict[str, Any]:
    """Convert a pipeline result into the frontend's flat graph payload.

    When the pipeline produced no graph (e.g. a parse failure), an empty
    payload carrying only the stage and parse error is returned.
    """
    stage = result.stage.value

    if result.graph is None:
        # No graph at all: ship an empty payload so the frontend clears its view.
        return {
            "type": "graph_update",
            "stage": stage,
            "nodes": [],
            "edges": [],
            "regions": [],
            "errors": [],
            "parse_error": result.parse_error,
            "metadata": {"stage": stage, "pe_count": 0, "sm_count": 0},
        }

    graph = result.graph
    all_nodes, all_edges = collect_all_nodes_and_edges(graph)
    error_lines: set[int] = {e.loc.line for e in result.errors}
    error_node_names = _collect_error_node_names(result.errors, all_nodes)

    nodes_json = [_serialise_node(n, error_node_names) for n in all_nodes.values()]
    edges_json = [_serialise_edge(e, all_nodes, error_lines) for e in all_edges]

    # SM instances are not IR nodes; synthesise nodes and edges for them so
    # the frontend can render memory traffic alongside the dataflow graph.
    sm_ids = _collect_referenced_sm_ids(all_nodes, graph)
    nodes_json += _synthesize_sm_nodes(sm_ids, all_nodes, graph)
    edges_json += _synthesize_sm_edges(all_nodes)

    # Only FUNCTION regions are shipped as grouping info.
    regions_json = [
        _serialise_region(region)
        for region in graph.regions
        if region.kind == RegionKind.FUNCTION
    ]

    system = graph.system
    return {
        "type": "graph_update",
        "stage": stage,
        "nodes": nodes_json,
        "edges": edges_json,
        "regions": regions_json,
        "errors": [_serialise_error(e) for e in result.errors],
        "parse_error": None,
        "metadata": {
            "stage": stage,
            "pe_count": system.pe_count if system else 0,
            "sm_count": system.sm_count if system else 0,
        },
    }