lightweight com.atproto.sync.listReposByCollection
at main 157 lines 5.8 kB view raw
1//! Shared HTTP client with per-host GCRA rate limiting. 2//! 3//! [`ThrottledClient`] wraps `reqwest::Client` and implements 4//! `jacquard_common::http_client::HttpClient`, so any function that accepts 5//! a generic `C: HttpClient` (including all XRPC call sites via the blanket 6//! `XrpcExt` impl) works with it directly. 7//! 8//! Rate limiting is keyed by the hostname of the outgoing request. Once the 9//! resync pipeline resolves DIDs to PDS endpoints before fetching, workers 10//! will naturally be rate-limited against the correct PDS host. 11 12use std::num::NonZeroU32; 13use std::sync::Arc; 14use std::time::Duration; 15 16use dashmap::DashMap; 17use governor::{DefaultDirectRateLimiter, Quota, RateLimiter}; 18use jacquard_common::http_client::{HttpClient, HttpClientExt}; 19use jacquard_common::stream::{ByteStream, StreamError}; 20 21/// State shared across all clones of a [`ThrottledClient`]. 22struct Shared { 23 /// Duration of one token at the configured rate (= 1s / rate). 24 token_interval: Duration, 25 /// One GCRA limiter per hostname. Entries are never evicted; the number 26 /// of distinct hosts contacted is small enough to be unbounded in memory. 27 limiters: DashMap<String, Arc<DefaultDirectRateLimiter>>, 28 /// Per-host quota, kept for creating new limiters on demand. 29 quota: Quota, 30} 31 32impl Shared { 33 fn get_or_create_limiter(&self, host: &str) -> Arc<DefaultDirectRateLimiter> { 34 Arc::clone( 35 self.limiters 36 .entry(host.to_string()) 37 .or_insert_with(|| Arc::new(RateLimiter::direct(self.quota))) 38 .value(), 39 ) 40 } 41} 42 43/// HTTP client that applies per-host GCRA rate limiting to every request. 44/// 45/// Implements [`HttpClient`], so it plugs in anywhere `reqwest::Client` was 46/// used (via the `jacquard-common` trait). All clones share the same limiter 47/// pool, so a single set of per-host buckets is maintained across the full 48/// worker pool. 49#[derive(Clone)] 50pub struct ThrottledClient { 51 inner: reqwest::Client, 52 shared: Arc<Shared>, 53} 54 55impl ThrottledClient { 56 pub fn new(rate_per_second: NonZeroU32) -> Self { 57 let inner = reqwest::Client::builder() 58 .user_agent(concat!( 59 "microcosm lightrail/v", 60 env!("CARGO_PKG_VERSION"), 61 ", https://tangled.org/microcosm.blue/lightrail" 62 )) 63 .build() 64 .expect("failed to build HTTP client"); 65 let quota = Quota::per_second(rate_per_second); 66 Self { 67 inner, 68 shared: Arc::new(Shared { 69 token_interval: Duration::from_secs(1) / rate_per_second.get(), 70 limiters: DashMap::new(), 71 quota, 72 }), 73 } 74 } 75} 76 77/// Build the shared HTTP client used for all outbound ATProto requests. 78pub fn build_client(rate_per_sec: NonZeroU32) -> ThrottledClient { 79 ThrottledClient::new(rate_per_sec) 80} 81 82impl HttpClient for ThrottledClient { 83 type Error = reqwest::Error; 84 85 async fn send_http( 86 &self, 87 request: http::Request<Vec<u8>>, 88 ) -> Result<http::Response<Vec<u8>>, Self::Error> { 89 let (parts, body) = request.into_parts(); 90 91 if let Some(host) = parts.uri.host() { 92 let limiter = self.shared.get_or_create_limiter(host); 93 while limiter.check().is_err() { 94 metrics::gauge!("lightrail_http_host_thorottling").increment(1); 95 tokio::time::sleep(self.shared.token_interval).await; 96 metrics::gauge!("lightrail_http_host_thorottling").decrement(1); 97 } 98 } 99 100 let mut req = self 101 .inner 102 .request(parts.method, parts.uri.to_string()) 103 .body(body); 104 for (name, value) in &parts.headers { 105 req = req.header(name, value); 106 } 107 let resp = req.send().await?; 108 109 let status = resp.status(); 110 let mut builder = http::Response::builder().status(status); 111 for (name, value) in resp.headers() { 112 builder = builder.header(name, value); 113 } 114 let body = resp.bytes().await?.to_vec(); 115 Ok(builder.body(body).expect("failed to build response")) 116 } 117} 118 119impl HttpClientExt for ThrottledClient { 120 async fn send_http_streaming( 121 &self, 122 request: http::Request<Vec<u8>>, 123 ) -> Result<http::Response<ByteStream>, Self::Error> { 124 let (parts, body) = request.into_parts(); 125 if let Some(host) = parts.uri.host() { 126 let limiter = self.shared.get_or_create_limiter(host); 127 while limiter.check().is_err() { 128 metrics::gauge!("lightrail_http_host_thorottling").increment(1); 129 tokio::time::sleep(self.shared.token_interval).await; 130 metrics::gauge!("lightrail_http_host_thorottling").decrement(1); 131 } 132 } 133 self.inner 134 .send_http_streaming(http::Request::from_parts(parts, body)) 135 .await 136 } 137 138 #[cfg(not(target_arch = "wasm32"))] 139 async fn send_http_bidirectional<S>( 140 &self, 141 parts: http::request::Parts, 142 body: S, 143 ) -> Result<http::Response<ByteStream>, Self::Error> 144 where 145 S: n0_future::Stream<Item = Result<bytes::Bytes, StreamError>> + Send + 'static, 146 { 147 if let Some(host) = parts.uri.host() { 148 let limiter = self.shared.get_or_create_limiter(host); 149 while limiter.check().is_err() { 150 metrics::gauge!("lightrail_http_host_thorottling").increment(1); 151 tokio::time::sleep(self.shared.token_interval).await; 152 metrics::gauge!("lightrail_http_host_thorottling").decrement(1); 153 } 154 } 155 self.inner.send_http_bidirectional(parts, body).await 156 } 157}