An AI agent built to run Ralph loops — a plan mode for planning and a ralph mode for implementing.
at new-directions 270 lines 12 kB view raw
1use crate::agent::{AgentContext, AgentId, AgentOutcome, AgentProfile}; 2use crate::context::{ContextBudget, ContextBuilder}; 3use crate::llm::{LlmClient, Message, ResponseContent}; 4use crate::message::{MessageBus, WorkerMessage}; 5use crate::tools::ToolRegistry; 6use anyhow::Result; 7use std::sync::Arc; 8 9/// Configuration for the AgentRuntime 10#[derive(Clone)] 11pub struct RuntimeConfig { 12 /// Maximum number of turns to run (default: 100) 13 pub max_turns: usize, 14 /// Maximum consecutive LLM failures before blocking (default: 3) 15 pub max_consecutive_llm_failures: usize, 16 /// Maximum consecutive tool failures before blocking (default: 3) 17 pub max_consecutive_tool_failures: usize, 18 /// Token budget for this run (default: 200_000) 19 pub token_budget: usize, 20 /// Warning threshold as percentage of budget (default: 80) 21 pub token_budget_warning_pct: u8, 22 /// Optional message bus for check-in reports (None for single-agent mode) 23 pub message_bus: Option<Arc<dyn MessageBus>>, 24 /// Agent ID for check-in reports (None for single-agent mode) 25 pub agent_id: Option<AgentId>, 26 /// Check-in interval in turns (default: 10) 27 pub check_in_interval: usize, 28} 29 30impl std::fmt::Debug for RuntimeConfig { 31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 32 f.debug_struct("RuntimeConfig") 33 .field("max_turns", &self.max_turns) 34 .field( 35 "max_consecutive_llm_failures", 36 &self.max_consecutive_llm_failures, 37 ) 38 .field( 39 "max_consecutive_tool_failures", 40 &self.max_consecutive_tool_failures, 41 ) 42 .field("token_budget", &self.token_budget) 43 .field("token_budget_warning_pct", &self.token_budget_warning_pct) 44 .field("message_bus", &self.message_bus.is_some()) 45 .field("agent_id", &self.agent_id) 46 .field("check_in_interval", &self.check_in_interval) 47 .finish() 48 } 49} 50 51impl Default for RuntimeConfig { 52 fn default() -> Self { 53 Self { 54 max_turns: 100, 55 max_consecutive_llm_failures: 3, 56 
max_consecutive_tool_failures: 3, 57 token_budget: 200_000, 58 token_budget_warning_pct: 80, 59 message_bus: None, 60 agent_id: None, 61 check_in_interval: 10, 62 } 63 } 64} 65 66/// The agentic loop: LLM call -> tool execution -> repeat 67pub struct AgentRuntime { 68 client: Arc<dyn LlmClient>, 69 tools: ToolRegistry, 70 #[allow(dead_code)] 71 profile: AgentProfile, 72 config: RuntimeConfig, 73} 74 75impl AgentRuntime { 76 /// Create a new AgentRuntime 77 pub fn new( 78 client: Arc<dyn LlmClient>, 79 tools: ToolRegistry, 80 profile: AgentProfile, 81 config: RuntimeConfig, 82 ) -> Self { 83 Self { 84 client, 85 tools, 86 profile, 87 config, 88 } 89 } 90 91 /// Run the agentic loop 92 pub async fn run(&self, ctx: AgentContext) -> Result<AgentOutcome> { 93 let budget = ContextBudget::default(); 94 let system_prompt = ContextBuilder::build_system_prompt_with_budget(&ctx, &budget); 95 let mut messages = vec![Message::system(system_prompt)]; 96 let mut cumulative_tokens: usize = 0; 97 let mut warned_about_budget = false; 98 let mut consecutive_llm_failures = 0; 99 let mut consecutive_tool_failures = 0; 100 let mut turn = 0; 101 102 loop { 103 // Check turn limit 104 if turn >= self.config.max_turns { 105 return Ok(AgentOutcome::Completed { 106 summary: format!("Turn limit reached after {} turns", self.config.max_turns), 107 tokens_used: cumulative_tokens, 108 }); 109 } 110 turn += 1; 111 112 // Send check-in progress report if configured 113 if let (Some(bus), Some(id)) = (&self.config.message_bus, &self.config.agent_id) 114 && turn > 1 115 && (turn - 1) % self.config.check_in_interval == 0 116 { 117 let _ = bus 118 .send( 119 &"orchestrator".to_string(), 120 WorkerMessage::ProgressReport { 121 agent_id: id.clone(), 122 turn: turn - 1, 123 summary: format!("Turn {}: processing", turn - 1), 124 }, 125 ) 126 .await; 127 } 128 129 // Check token budget warning threshold 130 let token_warning_threshold = 131 (self.config.token_budget * self.config.token_budget_warning_pct 
as usize) / 100; 132 if cumulative_tokens >= token_warning_threshold && !warned_about_budget { 133 warned_about_budget = true; 134 messages.push(Message::system( 135 "You are approaching your token budget. Wrap up your current work and signal completion.".to_string() 136 )); 137 } 138 139 // Check token budget exhausted 140 if cumulative_tokens >= self.config.token_budget { 141 return Ok(AgentOutcome::TokenBudgetExhausted { 142 summary: "Token budget exhausted".to_string(), 143 tokens_used: cumulative_tokens, 144 }); 145 } 146 147 // Call LLM 148 let tool_definitions = self.tools.definitions(); 149 let response = match self.client.chat(messages.clone(), &tool_definitions).await { 150 Ok(resp) => { 151 consecutive_llm_failures = 0; 152 resp 153 } 154 Err(e) => { 155 consecutive_llm_failures += 1; 156 if consecutive_llm_failures >= self.config.max_consecutive_llm_failures { 157 return Ok(AgentOutcome::Blocked { 158 reason: format!( 159 "LLM failures: {} consecutive failures ({:?})", 160 self.config.max_consecutive_llm_failures, e 161 ), 162 }); 163 } 164 // Send error back to LLM for self-correction 165 messages.push(Message::assistant(format!("Error: {}", e))); 166 continue; 167 } 168 }; 169 170 // Track token usage 171 if let Some(input_tokens) = response.input_tokens { 172 cumulative_tokens += input_tokens; 173 } 174 if let Some(output_tokens) = response.output_tokens { 175 cumulative_tokens += output_tokens; 176 } 177 178 // Process response content 179 match response.content { 180 ResponseContent::Text(text) => { 181 messages.push(Message::assistant(text)); 182 } 183 ResponseContent::ToolCalls(tool_calls) => { 184 // Add the assistant's tool calls to the message history 185 let tool_calls_json = serde_json::to_string(&tool_calls)?; 186 messages.push(Message::assistant(tool_calls_json)); 187 188 // Execute each tool 189 for tool_call in tool_calls { 190 // Check for signal_completion 191 if tool_call.name == "signal_completion" { 192 if let Some(tool) = 
self.tools.get(&tool_call.name) { 193 match tool.execute(tool_call.parameters).await { 194 Ok(output) => { 195 if output.contains("SIGNAL:complete") { 196 // Extract message from output 197 let message = output 198 .strip_prefix("SIGNAL:complete:") 199 .unwrap_or("Task completed") 200 .to_string(); 201 return Ok(AgentOutcome::Completed { 202 summary: message, 203 tokens_used: cumulative_tokens, 204 }); 205 } else if output.contains("SIGNAL:blocked") { 206 let reason = output 207 .strip_prefix("SIGNAL:blocked:") 208 .unwrap_or("Task blocked") 209 .to_string(); 210 return Ok(AgentOutcome::Blocked { reason }); 211 } 212 } 213 Err(e) => { 214 consecutive_tool_failures += 1; 215 let error_msg = format!("Tool execution failed: {}", e); 216 messages.push(Message::tool_result( 217 tool_call.id.clone(), 218 error_msg, 219 )); 220 } 221 } 222 } 223 continue; 224 } 225 226 // Execute regular tool 227 match self.tools.get(&tool_call.name) { 228 Some(tool) => match tool.execute(tool_call.parameters).await { 229 Ok(output) => { 230 consecutive_tool_failures = 0; 231 messages.push(Message::tool_result(tool_call.id, output)); 232 } 233 Err(e) => { 234 consecutive_tool_failures += 1; 235 if consecutive_tool_failures 236 >= self.config.max_consecutive_tool_failures 237 { 238 return Ok(AgentOutcome::Blocked { 239 reason: format!( 240 "Tool failures: {} consecutive failures", 241 self.config.max_consecutive_tool_failures 242 ), 243 }); 244 } 245 let error_msg = format!("Tool error: {}", e); 246 messages.push(Message::tool_result(tool_call.id, error_msg)); 247 } 248 }, 249 None => { 250 consecutive_tool_failures += 1; 251 if consecutive_tool_failures 252 >= self.config.max_consecutive_tool_failures 253 { 254 return Ok(AgentOutcome::Blocked { 255 reason: format!( 256 "Tool failures: {} consecutive failures", 257 self.config.max_consecutive_tool_failures 258 ), 259 }); 260 } 261 let error_msg = format!("Unknown tool: {}", tool_call.name); 262 
messages.push(Message::tool_result(tool_call.id, error_msg)); 263 } 264 } 265 } 266 } 267 } 268 } 269 } 270}