# OR-1 dataflow CPU sketch
# (captured from branch main: 381 lines, 15 kB)
"""FastAPI server for the OR1 Monitor with bidirectional WebSocket protocol.

Exposes the SimulationBackend to a browser-based frontend via:
- WebSocket /ws for bidirectional command/result communication
- REST endpoints for non-WebSocket clients
- Static file serving for the frontend
"""

from __future__ import annotations

import asyncio
import json
import logging
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles

from cm_inst import MemOp, Port
from asm.ir import IRGraph
from monitor.backend import SimulationBackend
from monitor.commands import (
    ErrorResult, GraphLoaded, LoadCmd, ResetCmd, RunUntilCmd,
    SendCmd, InjectCmd, StepEventCmd, StepResult, StepTickCmd,
)
from monitor.graph_json import graph_loaded_json, graph_to_monitor_json
from tokens import DyadToken, SMToken

logger = logging.getLogger(__name__)

# Timeouts (seconds) passed to SimulationBackend.send_command.
_DEFAULT_TIMEOUT = 10.0
_RUN_UNTIL_TIMEOUT = 30.0


class ConnectionManager:
    """Manages WebSocket connections and broadcasts to all connected clients."""

    def __init__(self) -> None:
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept and register a new WebSocket connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket) -> None:
        """Unregister a WebSocket; a no-op if it is not registered."""
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            # Already removed (e.g. dropped by a failed broadcast).
            pass

    async def broadcast(self, message: dict) -> None:
        """Send *message* to every connected client.

        Clients whose send fails are unregistered afterwards. The loop walks a
        snapshot of the connection list: each ``send_json`` awaits, and other
        tasks may connect or disconnect clients meanwhile. Mutating the live
        list during iteration would silently skip clients.
        """
        disconnected: list[WebSocket] = []
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception:
                # Best-effort delivery: a dead client must not abort the
                # broadcast to the others.
                logger.debug("Dropping client after failed send", exc_info=True)
                disconnected.append(connection)
        for conn in disconnected:
            self.disconnect(conn)


async def _dispatch(
    backend: SimulationBackend,
    command: object,
    timeout: float = _DEFAULT_TIMEOUT,
):
    """Run ``backend.send_command(command, timeout)`` in the default executor.

    ``send_command`` is called synchronously with a timeout, so it is kept off
    the event loop to avoid stalling other connections.
    """
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, backend.send_command, command, timeout)


async def _read_source_file(path: str) -> str:
    """Read a source file as UTF-8 text without blocking the event loop."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, Path(path).read_text, "utf-8")


def _dyad_token_from(cmd_json: dict) -> DyadToken:
    """Build a DyadToken (always on ``Port.L``) from an inject/send command."""
    return DyadToken(
        target=cmd_json.get("target", 0),
        offset=cmd_json.get("offset", 0),
        act_id=cmd_json.get("act_id", 0),
        data=cmd_json.get("data", 0),
        port=Port.L,
    )


def _sm_token_from(cmd_json: dict) -> SMToken:
    """Build an SMToken from an inject_sm/send_sm command.

    Raises:
        KeyError: if ``op`` is not the name of a ``MemOp`` member.
    """
    return SMToken(
        target=cmd_json.get("target", 0),
        addr=cmd_json.get("addr", 0),
        op=MemOp[cmd_json.get("op", "READ")],
        flags=0,
        data=cmd_json.get("data", 0),
        ret=None,
    )


async def _execute_ws_command(
    backend: SimulationBackend, cmd_type: object, cmd_json: dict
):
    """Translate one WebSocket command into a backend call and return its result.

    Unknown commands and unreadable ``load_file`` paths yield an
    ``ErrorResult``. Any other exception (backend timeout, bad ``op`` name,
    ...) propagates to the caller.
    """
    if cmd_type == "load":
        return await _dispatch(backend, LoadCmd(source=cmd_json.get("source", "")))

    if cmd_type == "load_file":
        # NOTE(review): any connected client can make the server read any file
        # this process can access. Restrict to an allowed directory if the
        # server is ever exposed beyond localhost.
        try:
            source = await _read_source_file(cmd_json.get("path", ""))
        except (OSError, ValueError) as e:
            # ValueError covers UnicodeDecodeError and "embedded null byte".
            return ErrorResult(message=f"Failed to read file: {e}")
        return await _dispatch(backend, LoadCmd(source=source))

    if cmd_type == "step_tick":
        return await _dispatch(backend, StepTickCmd())

    if cmd_type == "step_event":
        return await _dispatch(backend, StepEventCmd())

    if cmd_type == "run_until":
        return await _dispatch(
            backend,
            RunUntilCmd(until=cmd_json.get("until", 0.0)),
            _RUN_UNTIL_TIMEOUT,
        )

    if cmd_type == "inject":
        return await _dispatch(backend, InjectCmd(token=_dyad_token_from(cmd_json)))

    if cmd_type == "send":
        # Same as inject but respects backpressure.
        return await _dispatch(backend, SendCmd(token=_dyad_token_from(cmd_json)))

    if cmd_type == "inject_sm":
        return await _dispatch(backend, InjectCmd(token=_sm_token_from(cmd_json)))

    if cmd_type == "send_sm":
        # SM token with backpressure.
        return await _dispatch(backend, SendCmd(token=_sm_token_from(cmd_json)))

    if cmd_type == "reset":
        return await _dispatch(
            backend, ResetCmd(reload=cmd_json.get("reload", False))
        )

    return ErrorResult(message=f"Unknown command: {cmd_type}")


def _safe_result_to_json(
    result: GraphLoaded | StepResult | ErrorResult,
    ir_graph: IRGraph | None,
    is_reset: bool = False,
) -> dict:
    """Like ``_result_to_json`` but returns an error dict instead of raising."""
    try:
        return _result_to_json(result, ir_graph, is_reset)
    except Exception as e:
        logger.exception("Error serializing result")
        return {
            "type": "error",
            "message": f"Serialization error: {e}",
        }


def create_app(backend: SimulationBackend) -> FastAPI:
    """Create and configure the FastAPI application.

    Args:
        backend: SimulationBackend instance (shared with REPL if both running)

    Returns:
        Configured FastAPI application
    """
    manager = ConnectionManager()
    # Last GraphLoaded payload, sent to newly connected WebSocket clients and
    # returned by GET /state.
    current_graph_loaded: dict | None = None
    # IRGraph of the loaded program, needed to serialize StepResults.
    current_ir_graph: IRGraph | None = None

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Lifespan context manager for startup/shutdown (currently a no-op)."""
        yield

    app = FastAPI(lifespan=lifespan)

    async def publish(result, *, is_reset: bool, reload: bool) -> dict:
        """Serialize *result*, update the cached graph state, then broadcast.

        The cache is updated *before* broadcasting so that a client connecting
        while the broadcast is in flight receives the new state on connect.

        Returns:
            The JSON response that was broadcast.
        """
        nonlocal current_graph_loaded, current_ir_graph
        # Serialize against the graph that produced this result, before the
        # cache is possibly replaced below.
        response = _safe_result_to_json(result, current_ir_graph, is_reset)
        if isinstance(result, GraphLoaded):
            current_ir_graph = result.ir_graph
            current_graph_loaded = graph_loaded_json(
                result.ir_graph, result.snapshot
            )
        elif is_reset and not reload and isinstance(result, StepResult):
            # Reset without reload unloads the program: forget cached state.
            current_ir_graph = None
            current_graph_loaded = None
        await manager.broadcast(response)
        return response

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket) -> None:
        """WebSocket endpoint for bidirectional command/result communication.

        On connect, sends the current graph state if loaded. Then receives
        JSON commands (``{"cmd": ..., ...}``) and dispatches them to the
        backend. Results are broadcast to all connected clients. Malformed
        messages get an error reply on this socket only and do not close it.
        """
        await manager.connect(websocket)
        try:
            # Send initial state if available.
            if current_graph_loaded is not None:
                await websocket.send_json(current_graph_loaded)

            while True:
                raw = await websocket.receive_text()
                try:
                    cmd_json = json.loads(raw)
                except json.JSONDecodeError as e:
                    await websocket.send_json(_safe_result_to_json(
                        ErrorResult(message=f"Invalid JSON: {e}"), None
                    ))
                    continue
                if not isinstance(cmd_json, dict):
                    await websocket.send_json(_safe_result_to_json(
                        ErrorResult(message="Command must be a JSON object"), None
                    ))
                    continue

                cmd_type = cmd_json.get("cmd")
                try:
                    result = await _execute_ws_command(backend, cmd_type, cmd_json)
                except Exception as e:
                    logger.exception("Error processing command %s", cmd_type)
                    result = ErrorResult(message=str(e))

                is_reset = cmd_type == "reset"
                await publish(
                    result,
                    is_reset=is_reset,
                    reload=is_reset and bool(cmd_json.get("reload", False)),
                )
        except WebSocketDisconnect:
            pass
        except Exception:
            logger.exception("WebSocket error")
        finally:
            manager.disconnect(websocket)

    @app.post("/load")
    async def load_endpoint(request_data: dict) -> dict:
        """REST endpoint for loading a program.

        Body: {"source": "...dfasm source..."}
        """
        source = request_data.get("source", "")
        try:
            result = await _dispatch(backend, LoadCmd(source=source))
        except Exception as e:
            logger.exception("Error processing command %s", "load")
            result = ErrorResult(message=str(e))
        return await publish(result, is_reset=False, reload=False)

    @app.post("/reset")
    async def reset_endpoint(request_data: dict | None = None) -> dict:
        """REST endpoint for reset.

        Body: {"reload": false}
        """
        reload = request_data.get("reload", False) if request_data else False
        try:
            result = await _dispatch(backend, ResetCmd(reload=reload))
        except Exception as e:
            logger.exception("Error processing command %s", "reset")
            result = ErrorResult(message=str(e))
        return await publish(result, is_reset=True, reload=bool(reload))

    @app.get("/state")
    async def state_endpoint() -> dict:
        """REST endpoint for current state snapshot."""
        if current_graph_loaded is None:
            return {"error": "No program loaded"}
        return current_graph_loaded

    # Mount static files for frontend. "/dist" is mounted first so it takes
    # precedence over the catch-all "/" mount.
    frontend_dir = Path(__file__).parent / "frontend"
    if (frontend_dir / "dist").exists():
        app.mount(
            "/dist",
            StaticFiles(directory=str(frontend_dir / "dist")),
            name="dist",
        )
    if frontend_dir.exists():
        app.mount(
            "/",
            StaticFiles(directory=str(frontend_dir), html=True),
            name="frontend",
        )

    return app


def _result_to_json(
    result: GraphLoaded | StepResult | ErrorResult,
    ir_graph: IRGraph | None,
    is_reset: bool = False,
) -> dict:
    """Convert a command result to JSON for transmission.

    Args:
        result: Result dataclass from backend
        ir_graph: Current IRGraph (for context in responses)
        is_reset: True if this result is from a reset command

    Returns:
        JSON-serializable dict
    """
    if isinstance(result, GraphLoaded):
        return graph_loaded_json(result.ir_graph, result.snapshot)

    elif isinstance(result, StepResult):
        # Handle reset without reload: snapshot is None but this is not an error
        if result.snapshot is None:
            if is_reset and result.finished and result.sim_time == 0.0:
                return {
                    "type": "reset",
                    "sim_time": 0.0,
                    "message": "Simulation reset",
                }
            else:
                return {
                    "type": "error",
                    "message": "No program loaded",
                }
        return graph_to_monitor_json(ir_graph, result.snapshot, list(result.events))

    elif isinstance(result, ErrorResult):
        return {
            "type": "error",
            "message": result.message,
            "errors": result.errors,
        }

    else:
        return {
            "type": "error",
            "message": f"Unknown result type: {type(result).__name__}",
        }