An AI agent built to run Ralph loops: plan mode for planning and Ralph mode for implementing.
1use super::{LlmClient, Message, Response, ResponseContent, ToolCall, ToolDefinition};
2use async_trait::async_trait;
3use std::collections::VecDeque;
4use std::sync::{Arc, Mutex};
5
6type RecordedCalls = Vec<(Vec<Message>, Vec<ToolDefinition>)>;
7type MockResponseQueue = VecDeque<(ResponseContent, Option<String>)>;
8
9pub struct MockLlmClient {
10 responses: Arc<Mutex<MockResponseQueue>>,
11 recorded_calls: Arc<Mutex<RecordedCalls>>,
12 token_counts: Arc<Mutex<Option<(usize, usize)>>>, // (input_tokens, output_tokens)
13}
14
15impl MockLlmClient {
16 pub fn new() -> Self {
17 Self {
18 responses: Arc::new(Mutex::new(VecDeque::new())),
19 recorded_calls: Arc::new(Mutex::new(Vec::new())),
20 token_counts: Arc::new(Mutex::new(None)),
21 }
22 }
23
24 pub fn queue_text_response(&self, text: &str) {
25 self.responses.lock().unwrap().push_back((
26 ResponseContent::Text(text.to_string()),
27 Some("end_turn".to_string()),
28 ));
29 }
30
31 pub fn queue_tool_call(&self, name: &str, params: serde_json::Value) {
32 self.responses.lock().unwrap().push_back((
33 ResponseContent::ToolCalls(vec![ToolCall {
34 id: format!("call_{}", uuid::Uuid::new_v4()),
35 name: name.to_string(),
36 parameters: params,
37 }]),
38 Some("tool_use".to_string()),
39 ));
40 }
41
42 pub fn set_token_counts(&self, input: usize, output: usize) {
43 *self.token_counts.lock().unwrap() = Some((input, output));
44 }
45
46 pub fn get_recorded_calls(&self) -> Vec<(Vec<Message>, Vec<ToolDefinition>)> {
47 self.recorded_calls.lock().unwrap().clone()
48 }
49}
50
51impl Default for MockLlmClient {
52 fn default() -> Self {
53 Self::new()
54 }
55}
56
57#[async_trait]
58impl LlmClient for MockLlmClient {
59 async fn chat(
60 &self,
61 messages: Vec<Message>,
62 tools: &[ToolDefinition],
63 ) -> anyhow::Result<Response> {
64 self.recorded_calls
65 .lock()
66 .unwrap()
67 .push((messages, tools.to_vec()));
68
69 let (content, stop_reason) = self
70 .responses
71 .lock()
72 .unwrap()
73 .pop_front()
74 .ok_or_else(|| anyhow::anyhow!("No more mock responses queued"))?;
75
76 let token_counts = *self.token_counts.lock().unwrap();
77
78 Ok(Response {
79 content,
80 stop_reason,
81 input_tokens: token_counts.map(|(i, _)| i),
82 output_tokens: token_counts.map(|(_, o)| o),
83 })
84 }
85}