An AI agent built to do Ralph loops - plan mode for planning and ralph mode for implementing.
at new-directions 147 lines 4.2 kB view raw
1pub mod agents; 2pub mod graph; 3pub mod projects; 4pub mod search; 5 6use crate::db::Database; 7use crate::graph::store::GraphStore; 8use crate::message::MessageBus; 9use crate::project::ProjectStore; 10use axum::Json; 11use axum::http::StatusCode; 12use axum::response::{IntoResponse, Response}; 13use std::collections::HashMap; 14use std::sync::{Arc, Mutex}; 15use tokio::sync::broadcast; 16 17/// WebSocket event types matching the architecture specification. 18/// 19/// Emission sources: 20/// - AgentSpawned, AgentProgress, AgentCompleted: Bridged from WorkerMessage via MessageBus 21/// - NodeCreated: Bridged from WorkerMessage::NodeCreated via MessageBus 22/// - NodeStatusChanged: Deferred — requires hooks in GraphStore::update_node 23/// - EdgeCreated: Deferred — requires hooks in graph mutation 24/// - SessionEnded: Deferred — requires orchestrator to emit directly via ws_tx 25/// - ToolExecution: Deferred — requires AgentRuntime to emit tool calls 26/// - OrchestratorStateChanged: Deferred — requires orchestrator state machine 27/// 28/// In Phase 3, only events bridged from the MessageBus are emitted (the first 4). 29#[derive(Debug, Clone, serde::Serialize)] 30#[serde(tag = "type", rename_all = "snake_case")] 31pub enum WsEvent { 32 AgentSpawned { 33 agent_id: String, 34 profile: String, 35 goal_id: String, 36 }, 37 AgentProgress { 38 agent_id: String, 39 turn: usize, 40 summary: String, 41 }, 42 AgentCompleted { 43 agent_id: String, 44 outcome_type: String, 45 summary: String, 46 tokens_used: Option<usize>, 47 }, 48 NodeCreated { 49 #[serde(flatten)] 50 node: crate::graph::GraphNode, 51 parent_id: Option<String>, 52 }, 53 NodeStatusChanged { 54 node_id: String, 55 node_type: String, 56 old_status: String, 57 new_status: String, 58 }, 59 EdgeCreated { 60 #[serde(flatten)] 61 edge: crate::graph::GraphEdge, 62 }, 63 SessionEnded { 64 session_id: String, 65 handoff_notes: Option<String>, 66 }, 67 ToolExecution { 68 agent_id: String, 69 tool: String, 70 args: serde_json::Value, 71 result: String, 72 }, 73 OrchestratorStateChanged { 74 goal_id: String, 75 state: String, 76 }, 77} 78 79/// Lightweight handle for managing running orchestrators 80pub struct OrchestratorHandle { 81 pub goal_id: String, 82 pub project_id: String, 83 pub cancel_token: tokio_util::sync::CancellationToken, 84 pub started_at: chrono::DateTime<chrono::Utc>, 85} 86 87#[derive(Clone)] 88pub struct AppState { 89 pub db: Database, 90 pub graph_store: Arc<dyn GraphStore>, 91 pub project_store: ProjectStore, 92 pub message_bus: Arc<dyn MessageBus>, 93 pub ws_tx: broadcast::Sender<WsEvent>, 94 pub orchestrators: Arc<Mutex<HashMap<String, OrchestratorHandle>>>, 95} 96 97impl AppState { 98 pub fn new( 99 db: Database, 100 graph_store: Arc<dyn GraphStore>, 101 message_bus: Arc<dyn MessageBus>, 102 ) -> Self { 103 let (ws_tx, _) = broadcast::channel(256); 104 Self { 105 project_store: ProjectStore::new(db.clone()), 106 db, 107 graph_store, 108 message_bus, 109 ws_tx, 110 orchestrators: Arc::new(Mutex::new(HashMap::new())), 111 } 112 } 113} 114 115pub enum ApiError { 116 NotFound(String), 117 BadRequest(String), 118 Conflict(String), 119 Internal(String), 120} 121 122impl IntoResponse for ApiError { 123 fn into_response(self) -> Response { 124 let (status, error_type, message) = match self { 125 ApiError::NotFound(msg) => (StatusCode::NOT_FOUND, "not found", msg), 126 ApiError::BadRequest(msg) => (StatusCode::BAD_REQUEST, "bad request", msg), 127 ApiError::Conflict(msg) => (StatusCode::CONFLICT, "conflict", msg), 128 ApiError::Internal(msg) => { 129 tracing::error!("Internal error: {}", msg); 130 (StatusCode::INTERNAL_SERVER_ERROR, "internal error", msg) 131 } 132 }; 133 134 let body = serde_json::json!({ 135 "error": error_type, 136 "message": message, 137 }); 138 139 (status, Json(body)).into_response() 140 } 141} 142 143impl From<anyhow::Error> for ApiError { 144 fn from(err: anyhow::Error) -> Self { 145 ApiError::Internal(err.to_string()) 146 } 147}