An AI agent built to do Ralph loops - plan mode for planning and ralph mode for implementing.
1use crate::llm::error::LlmError;
2use reqwest::Response as HttpResponse;
3use reqwest::header::RETRY_AFTER;
4use std::time::{Duration, SystemTime};
5use tokio::time::sleep;
6use tracing::{debug, warn};
7
/// Tunable parameters for the retry loop implemented by `with_retry`.
#[derive(Clone, Debug)]
pub struct RetryConfig {
    /// Attempt number at which retrying stops; attempts are counted from 1,
    /// and the operation is always invoked at least once.
    pub max_attempts: u32,
    /// Backoff used after the first failed attempt; doubled for every
    /// further attempt when the error carries no server-provided hint.
    pub base_delay: Duration,
    /// Cap applied to the backoff (or to the server hint) before jitter is added.
    pub max_delay: Duration,
    /// Fraction of the capped delay that may be added on top as jitter.
    /// Values `<= 0.0` disable jitter.
    pub jitter_ratio: f64,
}

impl Default for RetryConfig {
    /// Four attempts, 800 ms base backoff, 60 s cap and 20 % jitter.
    fn default() -> Self {
        let base_delay = Duration::from_millis(800);
        let max_delay = Duration::from_secs(60);

        RetryConfig {
            max_attempts: 4,
            base_delay,
            max_delay,
            jitter_ratio: 0.2,
        }
    }
}
26
27pub async fn with_retry<T, Fut>(
28 provider: &'static str,
29 cfg: &RetryConfig,
30 mut op: impl FnMut() -> Fut,
31) -> Result<T, LlmError>
32where
33 Fut: std::future::Future<Output = Result<T, LlmError>>,
34{
35 let mut attempt: u32 = 1;
36
37 loop {
38 match op().await {
39 Ok(v) => return Ok(v),
40 Err(err) => {
41 let is_retryable = err.is_retryable();
42 let can_retry = attempt < cfg.max_attempts && is_retryable;
43
44 if !can_retry {
45 if !is_retryable {
46 debug!(
47 provider = provider,
48 error_kind = ?err.kind,
49 "Non-retryable error"
50 );
51 }
52 return Err(err);
53 }
54
55 let delay = compute_delay(attempt, cfg, err.retry_after);
56
57 warn!(
58 provider = provider,
59 attempt = attempt,
60 max_attempts = cfg.max_attempts,
61 delay_ms = delay.as_millis() as u64,
62 error_kind = ?err.kind,
63 "Retrying after error"
64 );
65
66 if let Some(raw) = &err.raw {
67 debug!(provider = provider, raw = %raw, "Raw error details");
68 }
69
70 sleep(delay).await;
71 attempt += 1;
72 }
73 }
74 }
75}
76
77fn compute_delay(attempt: u32, cfg: &RetryConfig, server_hint: Option<Duration>) -> Duration {
78 let base_delay = if let Some(hint) = server_hint {
79 hint.min(cfg.max_delay)
80 } else {
81 let exp = 2u32.saturating_pow(attempt.saturating_sub(1));
82 cfg.base_delay.saturating_mul(exp).min(cfg.max_delay)
83 };
84
85 add_jitter(base_delay, cfg.jitter_ratio)
86}
87
/// Lengthens `delay` by a pseudo-random amount of at most `delay * jitter_ratio`.
///
/// The extra wait is a whole number of milliseconds and is never negative, so
/// the result is always `>= delay`. A `jitter_ratio` that is `<= 0.0` (or NaN),
/// or a jitter window shorter than one millisecond, returns `delay` unchanged.
/// All arithmetic saturates, so extreme ratios or delays cannot panic.
fn add_jitter(delay: Duration, jitter_ratio: f64) -> Duration {
    if jitter_ratio <= 0.0 {
        return delay;
    }

    // Float-to-integer `as` casts saturate and map NaN to 0, so this is
    // well-defined for any ratio; it may yield `u64::MAX`.
    let max_extra = (delay.as_millis() as f64 * jitter_ratio) as u64;
    if max_extra == 0 {
        return delay;
    }

    // The sub-second part of the wall clock is used as a cheap source of
    // variation; a clock set before the Unix epoch simply yields 0.
    let clock_nanos = u64::from(
        SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .subsec_nanos(),
    );

    // `max_extra + 1` would overflow when the cast above saturated (wrapping
    // to 0 in release builds and then dividing by zero), hence saturating_add.
    let extra_ms = clock_nanos % max_extra.saturating_add(1);

    delay.saturating_add(Duration::from_millis(extra_ms))
}
106
107pub fn parse_retry_after_header(resp: &HttpResponse) -> Option<Duration> {
108 let value = resp.headers().get(RETRY_AFTER)?.to_str().ok()?.trim();
109
110 if let Ok(secs) = value.parse::<u64>() {
111 return Some(Duration::from_secs(secs));
112 }
113
114 if let Ok(secs) = value.parse::<f64>()
115 && secs.is_finite()
116 && secs > 0.0
117 {
118 return Some(Duration::from_secs_f64(secs));
119 }
120
121 parse_http_date(value)
122}
123
124fn parse_http_date(value: &str) -> Option<Duration> {
125 use chrono::{DateTime, NaiveDateTime, Utc};
126
127 if let Ok(parsed) = DateTime::parse_from_rfc2822(value) {
128 let now = Utc::now();
129 let diff = parsed.signed_duration_since(now);
130 if diff.num_seconds() > 0 {
131 return Some(Duration::from_secs(diff.num_seconds() as u64));
132 }
133 return None;
134 }
135
136 let formats = ["%a, %d %b %Y %H:%M:%S GMT", "%A, %d-%b-%y %H:%M:%S GMT"];
137
138 for fmt in &formats {
139 if let Ok(naive) = NaiveDateTime::parse_from_str(value, fmt) {
140 let parsed = naive.and_utc();
141 let now = Utc::now();
142 let diff = parsed.signed_duration_since(now);
143 if diff.num_seconds() > 0 {
144 return Some(Duration::from_secs(diff.num_seconds() as u64));
145 }
146 }
147 }
148
149 None
150}
151
/// Extracts a retry delay from an error message containing `try again in <duration>`.
///
/// Matching is case-insensitive. The duration is one or more `<number><unit>`
/// segments written back to back (e.g. `45.622s`, `500ms`, `1m30s`,
/// `1h2m3s`); whitespace is allowed between a number and its unit
/// (`2 minutes`). Segments separated by whitespace are not chained, so only
/// the first one counts in `2 minutes 30 seconds`.
///
/// Recognised units, matched on the leading ASCII letters after the number:
/// * `ms…` / `milli…` are milliseconds,
/// * `h`, `hr…`, `hour…` are hours,
/// * any other word starting with `m` is minutes,
/// * everything else, including a missing unit, is seconds.
///
/// Returns `None` when the phrase is absent, no number follows it, or the
/// total is not strictly positive.
pub fn parse_retry_from_message(message: &str) -> Option<Duration> {
    const MARKER: &str = "try again in ";

    let lower = message.to_lowercase();
    let start = lower.find(MARKER)? + MARKER.len();
    let mut rest = &lower[start..];

    let mut total_ms = 0.0_f64;

    loop {
        let num_end = rest
            .find(|c: char| !c.is_ascii_digit() && c != '.')
            .unwrap_or(rest.len());
        let num = match rest[..num_end].parse::<f64>() {
            Ok(n) if n.is_finite() => n,
            // Not a number (e.g. a sentence-ending "."): stop accumulating.
            _ => break,
        };

        let after_num = rest[num_end..].trim_start();
        let unit_len = after_num
            .find(|c: char| !c.is_ascii_alphabetic())
            .unwrap_or(after_num.len());
        let unit = &after_num[..unit_len];

        // Millisecond spellings must be tested before the generic 'm' rule,
        // otherwise "milliseconds" would be read as minutes.
        let multiplier = if unit.starts_with("ms") || unit.starts_with("milli") {
            1.0
        } else if unit == "h" || unit.starts_with("hr") || unit.starts_with("hour") {
            3_600_000.0
        } else if unit.starts_with('m') {
            60_000.0
        } else {
            1000.0
        };

        total_ms += num * multiplier;
        rest = &after_num[unit_len..];

        // Continue only for compound durations such as "1m30s": a unit must
        // have been present and the next segment must follow immediately.
        if unit_len == 0 || !rest.starts_with(|c: char| c.is_ascii_digit()) {
            break;
        }
    }

    if total_ms > 0.0 {
        // Float-to-integer `as` saturates, so absurdly large values clamp
        // instead of wrapping.
        Some(Duration::from_millis(total_ms as u64))
    } else {
        None
    }
}
186
#[cfg(test)]
mod tests {
    use super::*;

    /// A fractional seconds value embedded in a sentence is picked up.
    #[test]
    fn test_parse_retry_from_message_seconds() {
        let parsed =
            parse_retry_from_message("Rate limit reached. Please try again in 45.622s.").unwrap();
        let millis = parsed.as_millis();
        assert!(millis >= 45000 && millis <= 46000);
    }

    /// The `ms` suffix is honoured and matching ignores case.
    #[test]
    fn test_parse_retry_from_message_ms() {
        let parsed = parse_retry_from_message("Try again in 500ms").unwrap();
        assert_eq!(parsed.as_millis(), 500);
    }

    /// Messages without the trigger phrase yield no hint.
    #[test]
    fn test_parse_retry_from_message_none() {
        assert!(parse_retry_from_message("Something went wrong").is_none());
    }
}
210}