An AI agent built to do Ralph loops - plan mode for planning and ralph mode for implementing.
at main 210 lines 6.0 kB view raw
1use crate::llm::error::LlmError; 2use reqwest::Response as HttpResponse; 3use reqwest::header::RETRY_AFTER; 4use std::time::{Duration, SystemTime}; 5use tokio::time::sleep; 6use tracing::{debug, warn}; 7 8#[derive(Debug, Clone)] 9pub struct RetryConfig { 10 pub max_attempts: u32, 11 pub base_delay: Duration, 12 pub max_delay: Duration, 13 pub jitter_ratio: f64, 14} 15 16impl Default for RetryConfig { 17 fn default() -> Self { 18 Self { 19 max_attempts: 4, 20 base_delay: Duration::from_millis(800), 21 max_delay: Duration::from_secs(60), 22 jitter_ratio: 0.2, 23 } 24 } 25} 26 27pub async fn with_retry<T, Fut>( 28 provider: &'static str, 29 cfg: &RetryConfig, 30 mut op: impl FnMut() -> Fut, 31) -> Result<T, LlmError> 32where 33 Fut: std::future::Future<Output = Result<T, LlmError>>, 34{ 35 let mut attempt: u32 = 1; 36 37 loop { 38 match op().await { 39 Ok(v) => return Ok(v), 40 Err(err) => { 41 let is_retryable = err.is_retryable(); 42 let can_retry = attempt < cfg.max_attempts && is_retryable; 43 44 if !can_retry { 45 if !is_retryable { 46 debug!( 47 provider = provider, 48 error_kind = ?err.kind, 49 "Non-retryable error" 50 ); 51 } 52 return Err(err); 53 } 54 55 let delay = compute_delay(attempt, cfg, err.retry_after); 56 57 warn!( 58 provider = provider, 59 attempt = attempt, 60 max_attempts = cfg.max_attempts, 61 delay_ms = delay.as_millis() as u64, 62 error_kind = ?err.kind, 63 "Retrying after error" 64 ); 65 66 if let Some(raw) = &err.raw { 67 debug!(provider = provider, raw = %raw, "Raw error details"); 68 } 69 70 sleep(delay).await; 71 attempt += 1; 72 } 73 } 74 } 75} 76 77fn compute_delay(attempt: u32, cfg: &RetryConfig, server_hint: Option<Duration>) -> Duration { 78 let base_delay = if let Some(hint) = server_hint { 79 hint.min(cfg.max_delay) 80 } else { 81 let exp = 2u32.saturating_pow(attempt.saturating_sub(1)); 82 cfg.base_delay.saturating_mul(exp).min(cfg.max_delay) 83 }; 84 85 add_jitter(base_delay, cfg.jitter_ratio) 86} 87 88fn add_jitter(delay: Duration, jitter_ratio: f64) -> Duration { 89 if jitter_ratio <= 0.0 { 90 return delay; 91 } 92 93 let max_extra = (delay.as_millis() as f64 * jitter_ratio) as u64; 94 if max_extra == 0 { 95 return delay; 96 } 97 98 let extra_ms = SystemTime::now() 99 .duration_since(SystemTime::UNIX_EPOCH) 100 .unwrap_or_default() 101 .subsec_nanos() as u64 102 % (max_extra + 1); 103 104 delay + Duration::from_millis(extra_ms) 105} 106 107pub fn parse_retry_after_header(resp: &HttpResponse) -> Option<Duration> { 108 let value = resp.headers().get(RETRY_AFTER)?.to_str().ok()?.trim(); 109 110 if let Ok(secs) = value.parse::<u64>() { 111 return Some(Duration::from_secs(secs)); 112 } 113 114 if let Ok(secs) = value.parse::<f64>() 115 && secs.is_finite() 116 && secs > 0.0 117 { 118 return Some(Duration::from_secs_f64(secs)); 119 } 120 121 parse_http_date(value) 122} 123 124fn parse_http_date(value: &str) -> Option<Duration> { 125 use chrono::{DateTime, NaiveDateTime, Utc}; 126 127 if let Ok(parsed) = DateTime::parse_from_rfc2822(value) { 128 let now = Utc::now(); 129 let diff = parsed.signed_duration_since(now); 130 if diff.num_seconds() > 0 { 131 return Some(Duration::from_secs(diff.num_seconds() as u64)); 132 } 133 return None; 134 } 135 136 let formats = ["%a, %d %b %Y %H:%M:%S GMT", "%A, %d-%b-%y %H:%M:%S GMT"]; 137 138 for fmt in &formats { 139 if let Ok(naive) = NaiveDateTime::parse_from_str(value, fmt) { 140 let parsed = naive.and_utc(); 141 let now = Utc::now(); 142 let diff = parsed.signed_duration_since(now); 143 if diff.num_seconds() > 0 { 144 return Some(Duration::from_secs(diff.num_seconds() as u64)); 145 } 146 } 147 } 148 149 None 150} 151 152pub fn parse_retry_from_message(message: &str) -> Option<Duration> { 153 let lower = message.to_lowercase(); 154 155 if let Some(idx) = lower.find("try again in ") { 156 let tail = &lower[idx + "try again in ".len()..]; 157 158 let num_end = tail 159 .find(|c: char| !c.is_ascii_digit() && c != '.') 160 .unwrap_or(tail.len()); 161 let num_str = tail[..num_end].trim(); 162 163 if let Ok(num) = num_str.parse::<f64>() 164 && num.is_finite() 165 && num > 0.0 166 { 167 let unit_start = num_end; 168 let unit_tail = tail[unit_start..].trim_start(); 169 170 let multiplier = if unit_tail.starts_with("ms") { 171 1.0 172 } else if unit_tail.starts_with('s') || unit_tail.starts_with("second") { 173 1000.0 174 } else if unit_tail.starts_with('m') && !unit_tail.starts_with("ms") { 175 60_000.0 176 } else { 177 1000.0 178 }; 179 180 return Some(Duration::from_millis((num * multiplier) as u64)); 181 } 182 } 183 184 None 185} 186 187#[cfg(test)] 188mod tests { 189 use super::*; 190 191 #[test] 192 fn test_parse_retry_from_message_seconds() { 193 let msg = "Rate limit reached. Please try again in 45.622s."; 194 let dur = parse_retry_from_message(msg).unwrap(); 195 assert!(dur.as_millis() >= 45000 && dur.as_millis() <= 46000); 196 } 197 198 #[test] 199 fn test_parse_retry_from_message_ms() { 200 let msg = "Try again in 500ms"; 201 let dur = parse_retry_from_message(msg).unwrap(); 202 assert_eq!(dur.as_millis(), 500); 203 } 204 205 #[test] 206 fn test_parse_retry_from_message_none() { 207 let msg = "Something went wrong"; 208 assert!(parse_retry_from_message(msg).is_none()); 209 } 210}