An AI agent built to run Ralph loops — plan mode for planning and ralph mode for implementing.
at new-directions 85 lines 2.5 kB view raw
1use super::{LlmClient, Message, Response, ResponseContent, ToolCall, ToolDefinition}; 2use async_trait::async_trait; 3use std::collections::VecDeque; 4use std::sync::{Arc, Mutex}; 5 6type RecordedCalls = Vec<(Vec<Message>, Vec<ToolDefinition>)>; 7type MockResponseQueue = VecDeque<(ResponseContent, Option<String>)>; 8 9pub struct MockLlmClient { 10 responses: Arc<Mutex<MockResponseQueue>>, 11 recorded_calls: Arc<Mutex<RecordedCalls>>, 12 token_counts: Arc<Mutex<Option<(usize, usize)>>>, // (input_tokens, output_tokens) 13} 14 15impl MockLlmClient { 16 pub fn new() -> Self { 17 Self { 18 responses: Arc::new(Mutex::new(VecDeque::new())), 19 recorded_calls: Arc::new(Mutex::new(Vec::new())), 20 token_counts: Arc::new(Mutex::new(None)), 21 } 22 } 23 24 pub fn queue_text_response(&self, text: &str) { 25 self.responses.lock().unwrap().push_back(( 26 ResponseContent::Text(text.to_string()), 27 Some("end_turn".to_string()), 28 )); 29 } 30 31 pub fn queue_tool_call(&self, name: &str, params: serde_json::Value) { 32 self.responses.lock().unwrap().push_back(( 33 ResponseContent::ToolCalls(vec![ToolCall { 34 id: format!("call_{}", uuid::Uuid::new_v4()), 35 name: name.to_string(), 36 parameters: params, 37 }]), 38 Some("tool_use".to_string()), 39 )); 40 } 41 42 pub fn set_token_counts(&self, input: usize, output: usize) { 43 *self.token_counts.lock().unwrap() = Some((input, output)); 44 } 45 46 pub fn get_recorded_calls(&self) -> Vec<(Vec<Message>, Vec<ToolDefinition>)> { 47 self.recorded_calls.lock().unwrap().clone() 48 } 49} 50 51impl Default for MockLlmClient { 52 fn default() -> Self { 53 Self::new() 54 } 55} 56 57#[async_trait] 58impl LlmClient for MockLlmClient { 59 async fn chat( 60 &self, 61 messages: Vec<Message>, 62 tools: &[ToolDefinition], 63 ) -> anyhow::Result<Response> { 64 self.recorded_calls 65 .lock() 66 .unwrap() 67 .push((messages, tools.to_vec())); 68 69 let (content, stop_reason) = self 70 .responses 71 .lock() 72 .unwrap() 73 .pop_front() 74 .ok_or_else(|| 
anyhow::anyhow!("No more mock responses queued"))?; 75 76 let token_counts = *self.token_counts.lock().unwrap(); 77 78 Ok(Response { 79 content, 80 stop_reason, 81 input_tokens: token_counts.map(|(i, _)| i), 82 output_tokens: token_counts.map(|(_, o)| o), 83 }) 84 } 85}