"""FastAPI server for the OR1 Monitor with a bidirectional WebSocket protocol.

Exposes the SimulationBackend to a browser-based frontend via:

- WebSocket ``/ws`` for bidirectional command/result communication
- REST endpoints for non-WebSocket clients
- Static file serving for the frontend
"""

from __future__ import annotations

import asyncio
import json
import logging
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles

from cm_inst import MemOp, Port
from asm.ir import IRGraph
from monitor.backend import SimulationBackend
from monitor.commands import (
    ErrorResult,
    GraphLoaded,
    LoadCmd,
    ResetCmd,
    RunUntilCmd,
    SendCmd,
    InjectCmd,
    StepEventCmd,
    StepResult,
    StepTickCmd,
)
from monitor.graph_json import graph_loaded_json, graph_to_monitor_json
from tokens import DyadToken, SMToken

logger = logging.getLogger(__name__)

# Last positional argument passed to ``backend.send_command`` -- presumably a
# timeout in seconds (confirm against SimulationBackend). ``run_until`` is
# given a larger value than every other command.
_DEFAULT_TIMEOUT = 10.0
_RUN_UNTIL_TIMEOUT = 30.0


class ConnectionManager:
    """Manages WebSocket connections and broadcasts to all connected clients."""

    def __init__(self) -> None:
        self.active_connections: list[WebSocket] = []

    async def connect(self, websocket: WebSocket) -> None:
        """Accept and register a new WebSocket connection."""
        await websocket.accept()
        self.active_connections.append(websocket)

    def disconnect(self, websocket: WebSocket) -> None:
        """Unregister a WebSocket; a no-op if it is not registered."""
        try:
            self.active_connections.remove(websocket)
        except ValueError:
            pass  # Already removed (e.g. by a failed broadcast).

    async def broadcast(self, message: dict) -> None:
        """Send ``message`` to every connected client.

        Clients whose send raises are unregistered after the pass completes.
        """
        disconnected: list[WebSocket] = []
        # Iterate over a snapshot: every ``await`` yields to the event loop,
        # where other handlers may connect/disconnect and mutate the list.
        for connection in list(self.active_connections):
            try:
                await connection.send_json(message)
            except Exception:
                disconnected.append(connection)
        for connection in disconnected:
            self.disconnect(connection)


def _dyad_token_from_json(cmd_json: dict) -> DyadToken:
    """Build a left-port DyadToken from a client command.

    Reads ``target``, ``offset``, ``act_id`` and ``data``; each defaults to 0.
    """
    return DyadToken(
        target=cmd_json.get("target", 0),
        offset=cmd_json.get("offset", 0),
        act_id=cmd_json.get("act_id", 0),
        data=cmd_json.get("data", 0),
        port=Port.L,
    )


def _sm_token_from_json(cmd_json: dict) -> SMToken:
    """Build an SMToken from a client command.

    Reads ``target``, ``addr`` and ``data`` (default 0) and ``op`` (a MemOp
    member name, default ``"READ"``).

    Raises:
        ValueError: if ``op`` does not name a MemOp member.
    """
    op_name = cmd_json.get("op", "READ")
    try:
        op = MemOp[op_name]
    except KeyError:
        # A bare KeyError would surface to the client as just "'NAME'".
        raise ValueError(f"Unknown memory op: {op_name!r}") from None
    return SMToken(
        target=cmd_json.get("target", 0),
        addr=cmd_json.get("addr", 0),
        op=op,
        flags=0,
        data=cmd_json.get("data", 0),
        ret=None,
    )


def _build_command(cmd_type: object, cmd_json: dict) -> tuple[object, float] | None:
    """Translate a client command (other than ``load_file``) into a backend command.

    Args:
        cmd_type: Value of the command's ``"cmd"`` field.
        cmd_json: The full decoded command object.

    Returns:
        ``(command, timeout)`` for ``backend.send_command``, or None if
        ``cmd_type`` is not a known command.

    Raises:
        ValueError: for an ``inject_sm``/``send_sm`` with an unknown ``op``.
    """
    if cmd_type == "load":
        return LoadCmd(source=cmd_json.get("source", "")), _DEFAULT_TIMEOUT
    if cmd_type == "step_tick":
        return StepTickCmd(), _DEFAULT_TIMEOUT
    if cmd_type == "step_event":
        return StepEventCmd(), _DEFAULT_TIMEOUT
    if cmd_type == "run_until":
        return RunUntilCmd(until=cmd_json.get("until", 0.0)), _RUN_UNTIL_TIMEOUT
    if cmd_type == "inject":
        return InjectCmd(token=_dyad_token_from_json(cmd_json)), _DEFAULT_TIMEOUT
    if cmd_type == "send":
        # Same as inject but respects backpressure.
        return SendCmd(token=_dyad_token_from_json(cmd_json)), _DEFAULT_TIMEOUT
    if cmd_type == "inject_sm":
        return InjectCmd(token=_sm_token_from_json(cmd_json)), _DEFAULT_TIMEOUT
    if cmd_type == "send_sm":
        # SM token sent with backpressure.
        return SendCmd(token=_sm_token_from_json(cmd_json)), _DEFAULT_TIMEOUT
    if cmd_type == "reset":
        return ResetCmd(reload=cmd_json.get("reload", False)), _DEFAULT_TIMEOUT
    return None


def create_app(backend: SimulationBackend) -> FastAPI:
    """Create and configure the FastAPI application.

    Args:
        backend: SimulationBackend instance (shared with REPL if both running)

    Returns:
        Configured FastAPI application
    """
    manager = ConnectionManager()
    # Last graph_loaded payload; replayed to each client as it connects.
    current_graph_loaded: dict | None = None
    # IRGraph of the loaded program; passed to _result_to_json for StepResults.
    current_ir_graph: IRGraph | None = None

    @asynccontextmanager
    async def lifespan(app: FastAPI):
        """Lifespan context manager for startup/shutdown."""
        yield

    app = FastAPI(lifespan=lifespan)

    async def dispatch(cmd_type: object, cmd_json: dict):
        """Run one client command on the backend (in a worker thread).

        Returns the backend's result, or an ErrorResult for an unknown command
        or an unreadable ``load_file`` path. Other exceptions propagate.
        """
        loop = asyncio.get_running_loop()
        if cmd_type == "load_file":
            # NOTE(review): reads any path the server process can access, as
            # supplied by the client. Restrict this to a project directory if
            # the server is ever reachable by untrusted clients.
            path = cmd_json.get("path", "")
            try:
                # Read off the event loop; only file errors are labelled as such.
                source = await loop.run_in_executor(None, Path(path).read_text)
            except Exception as e:
                return ErrorResult(message=f"Failed to read file: {e}")
            built = (LoadCmd(source=source), _DEFAULT_TIMEOUT)
        else:
            built = _build_command(cmd_type, cmd_json)
            if built is None:
                return ErrorResult(message=f"Unknown command: {cmd_type}")
        command, timeout = built
        return await loop.run_in_executor(
            None, backend.send_command, command, timeout
        )

    def serialize_and_cache(result, *, is_reset: bool, reload: bool = False) -> dict:
        """Serialize a backend result and refresh the cached graph state.

        Callers invoke this before broadcasting, so a client that connects
        while the broadcast is in progress is sent the new state.

        Args:
            result: Result returned by ``backend.send_command``.
            is_reset: True if the result came from a reset command.
            reload: The reset command's ``reload`` flag (ignored otherwise).

        Returns:
            JSON-serializable dict for clients; an error dict if
            serialization failed.
        """
        nonlocal current_graph_loaded, current_ir_graph
        try:
            response = _result_to_json(result, current_ir_graph, is_reset)
            serialized = True
        except Exception as e:
            logger.exception("Error serializing result")
            response = {
                "type": "error",
                "message": f"Serialization error: {e}",
            }
            serialized = False

        if isinstance(result, GraphLoaded):
            current_ir_graph = result.ir_graph
            # For GraphLoaded, _result_to_json returns exactly
            # graph_loaded_json(result.ir_graph, result.snapshot); reuse it
            # rather than rebuilding. Don't cache an error payload.
            current_graph_loaded = response if serialized else None
        elif is_reset and not reload and isinstance(result, StepResult):
            # Reset without reload leaves no program loaded.
            current_ir_graph = None
            current_graph_loaded = None
        return response

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket) -> None:
        """WebSocket endpoint for bidirectional command/result communication.

        On connect, sends the current graph state if loaded. Receives JSON
        commands and dispatches them to the backend. Broadcasts results to
        all connected clients. A message that is not a JSON object gets an
        error reply to the sender only; the connection stays open.
        """
        await manager.connect(websocket)
        try:
            if current_graph_loaded is not None:
                await websocket.send_json(current_graph_loaded)

            while True:
                raw = await websocket.receive_text()
                try:
                    cmd_json = json.loads(raw)
                except json.JSONDecodeError as e:
                    await websocket.send_json(
                        {"type": "error", "message": f"Invalid JSON: {e}"}
                    )
                    continue
                if not isinstance(cmd_json, dict):
                    await websocket.send_json(
                        {"type": "error", "message": "Command must be a JSON object"}
                    )
                    continue

                cmd_type = cmd_json.get("cmd")
                try:
                    result = await dispatch(cmd_type, cmd_json)
                except Exception as e:
                    logger.exception("Error processing command %s", cmd_type)
                    result = ErrorResult(message=str(e))

                if result is None:
                    continue

                response = serialize_and_cache(
                    result,
                    is_reset=cmd_type == "reset",
                    reload=bool(cmd_json.get("reload", False)),
                )
                await manager.broadcast(response)
        except WebSocketDisconnect:
            pass
        except Exception:
            logger.exception("WebSocket error")
        finally:
            # Also runs on task cancellation, so no dead socket stays registered.
            manager.disconnect(websocket)

    @app.post("/load")
    async def load_endpoint(request_data: dict) -> dict:
        """REST endpoint for loading a program.

        Body: {"source": "...dfasm source..."}
        """
        source = request_data.get("source", "")
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, backend.send_command, LoadCmd(source=source), _DEFAULT_TIMEOUT
        )
        response = serialize_and_cache(result, is_reset=False)
        await manager.broadcast(response)
        return response

    @app.post("/reset")
    async def reset_endpoint(request_data: dict | None = None) -> dict:
        """REST endpoint for reset.

        Body: {"reload": false}
        """
        reload = request_data.get("reload", False) if request_data else False
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None, backend.send_command, ResetCmd(reload=reload), _DEFAULT_TIMEOUT
        )
        response = serialize_and_cache(result, is_reset=True, reload=bool(reload))
        await manager.broadcast(response)
        return response

    @app.get("/state")
    async def state_endpoint() -> dict:
        """REST endpoint for current state snapshot."""
        if current_graph_loaded is None:
            return {"error": "No program loaded"}
        return current_graph_loaded

    # Mount static files for the frontend (after the routes above, so the
    # catch-all "/" mount does not shadow them).
    frontend_dir = Path(__file__).parent / "frontend"
    if (frontend_dir / "dist").exists():
        app.mount(
            "/dist",
            StaticFiles(directory=str(frontend_dir / "dist")),
            name="dist",
        )
    if frontend_dir.exists():
        app.mount(
            "/",
            StaticFiles(directory=str(frontend_dir), html=True),
            name="frontend",
        )

    return app


def _result_to_json(
    result: GraphLoaded | StepResult | ErrorResult,
    ir_graph: IRGraph | None,
    is_reset: bool = False,
) -> dict:
    """Convert a command result to JSON for transmission.

    Args:
        result: Result dataclass from backend
        ir_graph: Current IRGraph (used to serialize StepResult snapshots)
        is_reset: True if this result is from a reset command

    Returns:
        JSON-serializable dict
    """
    if isinstance(result, GraphLoaded):
        return graph_loaded_json(result.ir_graph, result.snapshot)
    if isinstance(result, StepResult):
        # Reset without reload yields no snapshot; that is not an error.
        if result.snapshot is None:
            if is_reset and result.finished and result.sim_time == 0.0:
                return {
                    "type": "reset",
                    "sim_time": 0.0,
                    "message": "Simulation reset",
                }
            return {
                "type": "error",
                "message": "No program loaded",
            }
        return graph_to_monitor_json(ir_graph, result.snapshot, list(result.events))
    if isinstance(result, ErrorResult):
        return {
            "type": "error",
            "message": result.message,
            "errors": result.errors,
        }
    return {
        "type": "error",
        "message": f"Unknown result type: {type(result).__name__}",
    }