An AI agent built to run Ralph loops — plan mode for planning and Ralph mode for implementing.
at new-directions 184 lines 5.9 kB view raw
1use crate::agent::AgentId; 2use crate::graph::GraphNode; 3use anyhow::{Result, anyhow}; 4use async_trait::async_trait; 5use std::collections::HashMap; 6use std::path::PathBuf; 7use std::sync::Mutex; 8use tokio::sync::{broadcast, mpsc}; 9use tracing::warn; 10 11/// Messages exchanged between workers and the orchestrator. 12/// 13/// Worker->Orchestrator variants include `agent_id` so the orchestrator 14/// can identify the sender without relying on channel metadata. 15#[derive(Debug, Clone)] 16pub enum WorkerMessage { 17 // Worker -> Orchestrator 18 ProgressReport { 19 agent_id: AgentId, 20 turn: usize, 21 summary: String, 22 }, 23 TaskCompleted { 24 agent_id: AgentId, 25 task_id: String, 26 summary: String, 27 }, 28 TaskBlocked { 29 agent_id: AgentId, 30 task_id: String, 31 reason: String, 32 }, 33 NeedsDecision { 34 agent_id: AgentId, 35 task_id: String, 36 decision: GraphNode, 37 }, 38 NodeCreated { 39 agent_id: AgentId, 40 parent_id: String, 41 node: GraphNode, 42 }, 43 44 // Orchestrator -> Worker 45 Cancel { 46 reason: String, 47 }, 48 AdditionalContext { 49 content: String, 50 }, 51 52 // Worker <-> Worker (review flow) 53 ReviewRequest { 54 work_package_id: String, 55 changed_files: Vec<PathBuf>, 56 }, 57 ReviewFeedback { 58 approved: bool, 59 comments: Vec<String>, 60 }, 61} 62 63/// Trait for the in-memory message bus. 64/// 65/// The bus provides best-effort, fire-and-forget delivery. If any message 66/// is lost, correctness is preserved because the orchestrator discovers 67/// the same information by querying the DB on its next scheduling pass. 68#[async_trait] 69pub trait MessageBus: Send + Sync { 70 /// Send a message to a specific agent's channel. 71 async fn send(&self, to: &AgentId, msg: WorkerMessage) -> Result<()>; 72 73 /// Broadcast a message to all subscribers. 74 async fn broadcast(&self, msg: WorkerMessage) -> Result<()>; 75 76 /// Create a subscription for an agent. 
Returns a receiver that gets 77 /// both targeted messages (via send) and broadcast messages. 78 fn subscribe(&self, agent_id: &AgentId) -> mpsc::Receiver<WorkerMessage>; 79 80 /// Remove a subscriber, cleaning up resources. 81 fn remove_subscriber(&self, agent_id: &AgentId); 82} 83 84/// In-memory message bus using tokio broadcast + per-agent mpsc channels. 85/// 86/// Architecture: broadcast channel for fan-out + per-agent mpsc for targeted 87/// delivery. Each subscriber gets a unified mpsc receiver that aggregates 88/// both targeted sends and forwarded broadcast messages. 89pub struct TokioMessageBus { 90 broadcast_tx: broadcast::Sender<WorkerMessage>, 91 /// Per-agent mpsc senders. Protected by Mutex because subscribe/send are 92 /// called from different tasks but never held across await points. 93 agent_channels: Mutex<HashMap<AgentId, mpsc::Sender<WorkerMessage>>>, 94 /// Channel capacity for per-agent mpsc channels. 95 agent_channel_capacity: usize, 96} 97 98impl TokioMessageBus { 99 /// Create a new TokioMessageBus. 100 /// 101 /// `broadcast_capacity` — buffer size for the broadcast channel (default: 64). 102 /// `agent_channel_capacity` — buffer size for per-agent mpsc channels (default: 32). 
103 pub fn new(broadcast_capacity: usize, agent_channel_capacity: usize) -> Self { 104 let (broadcast_tx, _) = broadcast::channel(broadcast_capacity); 105 Self { 106 broadcast_tx, 107 agent_channels: Mutex::new(HashMap::new()), 108 agent_channel_capacity, 109 } 110 } 111} 112 113impl Default for TokioMessageBus { 114 fn default() -> Self { 115 Self::new(64, 32) 116 } 117} 118 119#[async_trait] 120impl MessageBus for TokioMessageBus { 121 async fn send(&self, to: &AgentId, msg: WorkerMessage) -> Result<()> { 122 let sender = { 123 let channels = self 124 .agent_channels 125 .lock() 126 .map_err(|e| anyhow!("lock poisoned: {}", e))?; 127 channels.get(to).cloned() 128 }; 129 match sender { 130 Some(tx) => { 131 if tx.send(msg).await.is_err() { 132 warn!(agent_id = %to, "agent channel closed, message dropped"); 133 } 134 Ok(()) 135 } 136 None => Err(anyhow!("agent {} is not subscribed", to)), 137 } 138 } 139 140 async fn broadcast(&self, msg: WorkerMessage) -> Result<()> { 141 // Ignore SendError — means no active receivers, which is fine 142 let _ = self.broadcast_tx.send(msg); 143 Ok(()) 144 } 145 146 fn subscribe(&self, agent_id: &AgentId) -> mpsc::Receiver<WorkerMessage> { 147 let (tx, rx) = mpsc::channel(self.agent_channel_capacity); 148 149 // Store the sender for targeted messages 150 { 151 let mut channels = self.agent_channels.lock().expect("lock poisoned"); 152 channels.insert(agent_id.clone(), tx.clone()); 153 } 154 155 // Spawn a forwarding task: broadcast -> agent's mpsc 156 let mut broadcast_rx = self.broadcast_tx.subscribe(); 157 tokio::spawn(async move { 158 loop { 159 match broadcast_rx.recv().await { 160 Ok(msg) => { 161 if tx.send(msg).await.is_err() { 162 // Receiver dropped — exit forwarding loop 163 break; 164 } 165 } 166 Err(broadcast::error::RecvError::Lagged(n)) => { 167 warn!(lagged = n, "broadcast receiver lagged, skipping messages"); 168 continue; 169 } 170 Err(broadcast::error::RecvError::Closed) => { 171 break; 172 } 173 } 174 } 175 }); 
176 177 rx 178 } 179 180 fn remove_subscriber(&self, agent_id: &AgentId) { 181 let mut channels = self.agent_channels.lock().expect("lock poisoned"); 182 channels.remove(agent_id); 183 } 184}