OR-1 dataflow CPU sketch
1"""FastAPI server for the OR1 Monitor with bidirectional WebSocket protocol.
2
3Exposes the SimulationBackend to a browser-based frontend via:
4- WebSocket /ws for bidirectional command/result communication
5- REST endpoints for non-WebSocket clients
6- Static file serving for the frontend
7"""
8
9from __future__ import annotations
10
11import asyncio
12import json
13import logging
14from contextlib import asynccontextmanager
15from pathlib import Path
16
17from fastapi import FastAPI, WebSocket, WebSocketDisconnect
18from fastapi.staticfiles import StaticFiles
19
20from cm_inst import MemOp, Port
21from asm.ir import IRGraph
22from monitor.backend import SimulationBackend
23from monitor.commands import (
24 ErrorResult, GraphLoaded, LoadCmd, ResetCmd, RunUntilCmd,
25 SendCmd, InjectCmd, StepEventCmd, StepResult, StepTickCmd,
26)
27from monitor.graph_json import graph_loaded_json, graph_to_monitor_json
28from tokens import DyadToken, SMToken
29
30logger = logging.getLogger(__name__)
31
32
class ConnectionManager:
    """Track connected WebSocket clients and fan messages out to all of them."""

    def __init__(self) -> None:
        # Registered (accepted) client sockets, in connection order.
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept the WebSocket handshake and register the connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket) -> None:
        """Unregister *websocket*; a no-op if it is not (or no longer) registered."""
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            # Already removed (e.g. by a failed broadcast) -- nothing to do.
            pass

    async def broadcast(self, message: dict) -> None:
        """Send *message* as JSON to every registered client.

        Iterates over a snapshot of the connection list: each ``send_json``
        await yields to the event loop, and other handlers may connect or
        disconnect meanwhile. Mutating the live list mid-iteration would make
        the loop skip clients.

        Clients whose send raises are unregistered afterwards (best effort:
        one broken socket must not prevent delivery to the others).
        """
        failed: list[WebSocket] = []
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception:
                failed.append(connection)
        for connection in failed:
            self.disconnect(connection)
62
63
def _dyad_token_from_json(cmd_json: dict) -> DyadToken:
    """Build a ``Port.L`` DyadToken from an ``inject``/``send`` command.

    Missing ``target``/``offset``/``act_id``/``data`` fields default to 0.
    """
    return DyadToken(
        target=cmd_json.get("target", 0),
        offset=cmd_json.get("offset", 0),
        act_id=cmd_json.get("act_id", 0),
        data=cmd_json.get("data", 0),
        port=Port.L,
    )


def _sm_token_from_json(cmd_json: dict) -> SMToken:
    """Build an SMToken from an ``inject_sm``/``send_sm`` command.

    ``op`` names a ``MemOp`` member (default ``"READ"``); an unknown name
    raises ``KeyError``. Missing ``target``/``addr``/``data`` default to 0.
    """
    op = MemOp[cmd_json.get("op", "READ")]
    return SMToken(
        target=cmd_json.get("target", 0),
        addr=cmd_json.get("addr", 0),
        op=op,
        flags=0,
        data=cmd_json.get("data", 0),
        ret=None,
    )


def _execute_command(backend: SimulationBackend, cmd_json: dict):
    """Translate one decoded client command into a backend call (blocking).

    Intended to run in a thread executor: it may read a file from disk and it
    blocks inside ``backend.send_command``.

    Args:
        backend: Backend that executes the command.
        cmd_json: Decoded command object; ``cmd_json["cmd"]`` selects the command.

    Returns:
        Whatever ``backend.send_command`` returns, or an ``ErrorResult`` for an
        unknown command or an unreadable ``load_file`` path.

    Raises:
        Exceptions from token construction (e.g. ``KeyError`` for an unknown
        ``MemOp`` name) or from ``backend.send_command``.
    """
    cmd_type = cmd_json.get("cmd")
    timeout = 10.0
    if cmd_type == "load":
        cmd = LoadCmd(source=cmd_json.get("source", ""))
    elif cmd_type == "load_file":
        # NOTE(review): this reads any server-side path the client names.
        # Restrict it to an allow-listed directory if the server is reachable
        # by untrusted clients.
        try:
            source = Path(cmd_json.get("path", "")).read_text()
        except (OSError, TypeError, ValueError) as e:
            # Only file-reading failures get this label; backend errors
            # propagate to the caller's generic handler.
            return ErrorResult(message=f"Failed to read file: {e}")
        cmd = LoadCmd(source=source)
    elif cmd_type == "step_tick":
        cmd = StepTickCmd()
    elif cmd_type == "step_event":
        cmd = StepEventCmd()
    elif cmd_type == "run_until":
        cmd = RunUntilCmd(until=cmd_json.get("until", 0.0))
        timeout = 30.0  # run_until gets a longer timeout than single-step commands
    elif cmd_type == "inject":
        cmd = InjectCmd(token=_dyad_token_from_json(cmd_json))
    elif cmd_type == "send":
        # Same token as "inject", but SendCmd respects backpressure.
        cmd = SendCmd(token=_dyad_token_from_json(cmd_json))
    elif cmd_type == "inject_sm":
        cmd = InjectCmd(token=_sm_token_from_json(cmd_json))
    elif cmd_type == "send_sm":
        cmd = SendCmd(token=_sm_token_from_json(cmd_json))
    elif cmd_type == "reset":
        cmd = ResetCmd(reload=cmd_json.get("reload", False))
    else:
        return ErrorResult(message=f"Unknown command: {cmd_type}")
    return backend.send_command(cmd, timeout)


def create_app(backend: SimulationBackend) -> FastAPI:
    """Create and configure the FastAPI application.

    Args:
        backend: SimulationBackend instance (shared with REPL if both running)

    Returns:
        Configured FastAPI application
    """
    manager = ConnectionManager()
    # Cached "graph loaded" message, sent to newly connected clients and
    # returned by GET /state; None when no program is loaded.
    current_graph_loaded: dict | None = None
    # IRGraph of the loaded program, needed to serialize step results.
    current_ir_graph: IRGraph | None = None

    def _to_response(result, *, is_reset: bool) -> dict:
        """Serialize *result*; serializer failures become an error message."""
        try:
            return _result_to_json(result, current_ir_graph, is_reset)
        except Exception as e:
            logger.exception("Error serializing result")
            return {
                "type": "error",
                "message": f"Serialization error: {e}",
            }

    def _remember(result, *, reset_without_reload: bool) -> None:
        """Update the cached program state after a command has completed."""
        nonlocal current_graph_loaded, current_ir_graph
        if isinstance(result, GraphLoaded):
            current_ir_graph = result.ir_graph
            current_graph_loaded = graph_loaded_json(
                result.ir_graph, result.snapshot
            )
        elif isinstance(result, StepResult) and reset_without_reload:
            current_ir_graph = None
            current_graph_loaded = None

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Lifespan context manager for startup/shutdown."""
        yield

    app = FastAPI(lifespan=lifespan)

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket) -> None:
        """WebSocket endpoint for bidirectional command/result communication.

        On connect, sends the cached graph state if a program is loaded.
        Each incoming text frame must be a JSON object with a ``cmd`` field.
        The command runs on the backend in a thread executor, and its result
        is broadcast to all connected clients. A malformed frame gets an error
        reply to the sender only; the connection stays open.
        """
        await manager.connect(websocket)
        try:
            if current_graph_loaded is not None:
                await websocket.send_json(current_graph_loaded)

            loop = asyncio.get_running_loop()
            while True:
                raw = await websocket.receive_text()
                try:
                    cmd_json = json.loads(raw)
                except json.JSONDecodeError as e:
                    await websocket.send_json(
                        {"type": "error", "message": f"Invalid JSON: {e}"}
                    )
                    continue
                if not isinstance(cmd_json, dict):
                    await websocket.send_json(
                        {"type": "error", "message": "Command must be a JSON object"}
                    )
                    continue

                cmd_type = cmd_json.get("cmd")
                try:
                    result = await loop.run_in_executor(
                        None, _execute_command, backend, cmd_json
                    )
                except Exception as e:
                    logger.exception("Error processing command %s", cmd_type)
                    result = ErrorResult(message=str(e))

                if result is None:
                    continue
                is_reset = cmd_type == "reset"
                response = _to_response(result, is_reset=is_reset)
                # Update the cache before broadcasting so a client connecting
                # during the broadcast receives the new state on connect.
                _remember(
                    result,
                    reset_without_reload=is_reset and not cmd_json.get("reload", False),
                )
                await manager.broadcast(response)
        except WebSocketDisconnect:
            pass
        except Exception:
            logger.exception("WebSocket error")
        finally:
            # Always unregister, including on task cancellation.
            manager.disconnect(websocket)

    @app.post("/load")
    async def load_endpoint(request_data: dict) -> dict:
        """REST endpoint for loading a program.

        Body: {"source": "...dfasm source..."}
        """
        source = request_data.get("source", "")
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, backend.send_command,
            LoadCmd(source=source), 10.0
        )
        response = _to_response(result, is_reset=False)
        _remember(result, reset_without_reload=False)
        await manager.broadcast(response)
        return response

    @app.post("/reset")
    async def reset_endpoint(request_data: dict | None = None) -> dict:
        """REST endpoint for reset.

        Body: {"reload": false}
        """
        reload = request_data.get("reload", False) if request_data else False
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, backend.send_command,
            ResetCmd(reload=reload), 10.0
        )
        response = _to_response(result, is_reset=True)
        _remember(result, reset_without_reload=not reload)
        await manager.broadcast(response)
        return response

    @app.get("/state")
    async def state_endpoint() -> dict:
        """REST endpoint for the cached program state.

        NOTE(review): the cache is refreshed only when a command returns
        GraphLoaded (and cleared on reset without reload). After step commands,
        this therefore returns the snapshot taken at load time -- confirm that
        is intended.
        """
        if current_graph_loaded is None:
            return {"error": "No program loaded"}
        return current_graph_loaded

    # Mount static files for frontend
    frontend_dir = Path(__file__).parent / "frontend"
    if (frontend_dir / "dist").exists():
        app.mount(
            "/dist",
            StaticFiles(directory=str(frontend_dir / "dist")),
            name="dist",
        )
    if frontend_dir.exists():
        app.mount(
            "/",
            StaticFiles(directory=str(frontend_dir), html=True),
            name="frontend",
        )

    return app
338
339
def _result_to_json(
    result: GraphLoaded | StepResult | ErrorResult,
    ir_graph: IRGraph | None,
    is_reset: bool = False,
) -> dict:
    """Turn a backend result dataclass into a JSON-serializable message.

    Args:
        result: Result object returned by the backend.
        ir_graph: IRGraph of the currently loaded program; used to render
            step results.
        is_reset: Whether *result* came from a reset command.

    Returns:
        A dict ready to be sent to clients. Unrecognized result types yield
        an ``"error"`` message instead of raising.
    """
    if isinstance(result, GraphLoaded):
        return graph_loaded_json(result.ir_graph, result.snapshot)

    if isinstance(result, StepResult):
        if result.snapshot is not None:
            return graph_to_monitor_json(ir_graph, result.snapshot, list(result.events))
        # No snapshot: a reset without reload reports finished at time 0,
        # which is a normal outcome. Anything else means nothing is loaded.
        clean_reset = is_reset and result.finished and result.sim_time == 0.0
        if clean_reset:
            return {
                "type": "reset",
                "sim_time": 0.0,
                "message": "Simulation reset",
            }
        return {
            "type": "error",
            "message": "No program loaded",
        }

    if isinstance(result, ErrorResult):
        return {
            "type": "error",
            "message": result.message,
            "errors": result.errors,
        }

    unknown_name = type(result).__name__
    return {
        "type": "error",
        "message": f"Unknown result type: {unknown_name}",
    }
381 }