//! Shared HTTP client with per-host GCRA rate limiting.
//!
//! [`ThrottledClient`] wraps `reqwest::Client` and implements
//! `jacquard_common::http_client::HttpClient`, so any function that accepts
//! a generic `C: HttpClient` (including all XRPC call sites via the blanket
//! `XrpcExt` impl) works with it directly.
//!
//! Rate limiting is keyed by the hostname of the outgoing request. Once the
//! resync pipeline resolves DIDs to PDS endpoints before fetching, workers
//! will naturally be rate-limited against the correct PDS host.

use std::num::NonZeroU32;
use std::sync::Arc;
use std::time::Duration;

use dashmap::DashMap;
use governor::{DefaultDirectRateLimiter, Quota, RateLimiter};
use jacquard_common::http_client::{HttpClient, HttpClientExt};
use jacquard_common::stream::{ByteStream, StreamError};

/// Maximum random jitter added to each throttle wait, in milliseconds.
/// Spreads out wakeups so throttled workers don't stampede the same host.
const THROTTLE_JITTER_MS: f32 = 16.;

/// State shared across all clones of a [`ThrottledClient`].
struct Shared {
    /// Duration of one token at the configured rate (= 1s / rate).
    token_interval: Duration,
    /// One GCRA limiter per hostname. Entries are never evicted; the number
    /// of distinct hosts contacted is assumed small enough that unbounded
    /// growth of this map is acceptable.
    limiters: DashMap<String, Arc<DefaultDirectRateLimiter>>,
    /// Per-host quota, kept for creating new limiters on demand.
    quota: Quota,
}

impl Shared {
    /// Return the limiter for `host`, creating and caching one on first use.
    ///
    /// Uses the dashmap entry API so concurrent first requests to the same
    /// host race safely: only one limiter is ever inserted per host.
    fn get_or_create_limiter(&self, host: &str) -> Arc<DefaultDirectRateLimiter> {
        Arc::clone(
            self.limiters
                .entry(host.to_string())
                .or_insert_with(|| Arc::new(RateLimiter::direct(self.quota)))
                .value(),
        )
    }

    /// One token interval plus up to [`THROTTLE_JITTER_MS`] of random jitter.
    fn jittered_interval(&self) -> Duration {
        let secs = fastrand::f32() * THROTTLE_JITTER_MS / 1000.0;
        self.token_interval + Duration::from_secs_f32(secs)
    }
}

/// HTTP client that applies per-host GCRA rate limiting to every request.
///
/// Implements [`HttpClient`], so it plugs in anywhere `reqwest::Client` was
/// used (via the `jacquard-common` trait). All clones share the same limiter
/// pool, so a single set of per-host buckets is maintained across the full
/// worker pool.
#[derive(Clone)] pub struct ThrottledClient { inner: reqwest::Client, shared: Arc, } impl ThrottledClient { pub fn new(rate_per_second: NonZeroU32) -> Self { let inner = reqwest::Client::builder() .user_agent(concat!( "microcosm lightrail/v", env!("CARGO_PKG_VERSION"), ", https://tangled.org/microcosm.blue/lightrail" )) .build() .expect("failed to build HTTP client"); let quota = Quota::per_second(rate_per_second); Self { inner, shared: Arc::new(Shared { token_interval: Duration::from_secs(1) / rate_per_second.get(), limiters: DashMap::new(), quota, }), } } } /// Build the shared HTTP client used for all outbound ATProto requests. pub fn build_client(rate_per_sec: NonZeroU32) -> ThrottledClient { ThrottledClient::new(rate_per_sec) } impl HttpClient for ThrottledClient { type Error = reqwest::Error; async fn send_http( &self, request: http::Request>, ) -> Result>, Self::Error> { let (parts, body) = request.into_parts(); if let Some(host) = parts.uri.host() { let limiter = self.shared.get_or_create_limiter(host); while limiter.check().is_err() { metrics::gauge!("lightrail_http_host_throttling").increment(1); // i think we should be limiter.until_ready_with_jitter().await! 
tokio::time::sleep(self.shared.jittered_interval()).await; metrics::gauge!("lightrail_http_host_throttling").decrement(1); } } let mut req = self .inner .request(parts.method, parts.uri.to_string()) .body(body); for (name, value) in &parts.headers { req = req.header(name, value); } metrics::gauge!("lightrail_http_requests_in_flight").increment(1); let resp = req.send().await; metrics::gauge!("lightrail_http_requests_in_flight").decrement(1); let resp = resp?; let status = resp.status(); let mut builder = http::Response::builder().status(status); for (name, value) in resp.headers() { builder = builder.header(name, value); } let body = resp.bytes().await?.to_vec(); Ok(builder.body(body).expect("failed to build response")) } } impl HttpClientExt for ThrottledClient { async fn send_http_streaming( &self, request: http::Request>, ) -> Result, Self::Error> { let (parts, body) = request.into_parts(); if let Some(host) = parts.uri.host() { let limiter = self.shared.get_or_create_limiter(host); while limiter.check().is_err() { metrics::gauge!("lightrail_http_host_throttling").increment(1); tokio::time::sleep(self.shared.jittered_interval()).await; metrics::gauge!("lightrail_http_host_throttling").decrement(1); } } metrics::gauge!("lightrail_http_requests_in_flight").increment(1); // decremented in get_repo (sketttch) self.inner .send_http_streaming(http::Request::from_parts(parts, body)) .await } #[cfg(not(target_arch = "wasm32"))] async fn send_http_bidirectional( &self, parts: http::request::Parts, body: S, ) -> Result, Self::Error> where S: n0_future::Stream> + Send + 'static, { if let Some(host) = parts.uri.host() { let limiter = self.shared.get_or_create_limiter(host); while limiter.check().is_err() { metrics::gauge!("lightrail_http_host_throttling").increment(1); tokio::time::sleep(self.shared.jittered_interval()).await; metrics::gauge!("lightrail_http_host_throttling").decrement(1); } } metrics::gauge!("lightrail_http_requests_in_flight").increment(1); // decremented 
in get_repo (sketttch) self.inner.send_http_bidirectional(parts, body).await } }