Microservice to bring 2FA to self hosted PDSes

feat: add rate limit bypass for IPs and tokens #11

Summary

Add GATEKEEPER_RATE_LIMIT_BYPASS_IPS (comma-separated) and GATEKEEPER_RATE_LIMIT_BYPASS_KEY (checked via x-ratelimit-bypass header) env vars to exempt specific requests from rate limiting Replace tower_governor with custom middleware using the governor crate directly, since tower_governor has no built-in bypass mechanism All 4 rate-limited routes (sign-in, createSession, createAccount, gate/signup) use the new middleware Motivation

Certain trusted IPs (e.g. the PDS itself at 148.251.49.115) were being rate-limited, producing "Rate limit exceeded for smart IP" log noise. The PDS already supports PDS_RATE_LIMIT_BYPASS_KEY and PDS_RATE_LIMIT_BYPASS_IPS โ€” this brings the same pattern to gatekeeper.

Test plan

11 new unit tests in src/rate_limit.rs covering IP extraction, bypass by IP, bypass by token, unconfigured default, env parsing, 429 enforcement, and bypass passthrough All 37 tests pass (cargo test) Manual smoke test: hit a rate-limited endpoint with/without bypass header/IP

Note: vibe coded with Claude Code

Labels

None yet.

Participants 2
AT URI
at://did:plc:autcqcg4hsvgdf3hwt4cvci3/sh.tangled.repo.pull/3mfm5bh5fba22
+345 -137
Diff #0
+1 -79
Cargo.lock
··· 1103 1103 "percent-encoding", 1104 1104 ] 1105 1105 1106 - [[package]] 1107 - name = "forwarded-header-value" 1108 - version = "0.1.1" 1109 - source = "registry+https://github.com/rust-lang/crates.io-index" 1110 - checksum = "8835f84f38484cc86f110a805655697908257fb9a7af005234060891557198e9" 1111 - dependencies = [ 1112 - "nonempty", 1113 - "thiserror 1.0.69", 1114 - ] 1115 - 1116 1106 [[package]] 1117 1107 name = "fs_extra" 1118 1108 version = "1.3.0" ··· 1597 1587 "webpki-roots 1.0.5", 1598 1588 ] 1599 1589 1600 - [[package]] 1601 - name = "hyper-timeout" 1602 - version = "0.5.2" 1603 - source = "registry+https://github.com/rust-lang/crates.io-index" 1604 - checksum = "2b90d566bffbce6a75bd8b09a05aa8c2cb1fabb6cb348f8840c9e4c90a0d83b0" 1605 - dependencies = [ 1606 - "hyper", 1607 - "hyper-util", 1608 - "pin-project-lite", 1609 - "tokio", 1610 - "tower-service", 1611 - ] 1612 - 1613 1590 [[package]] 1614 1591 name = "hyper-util" 1615 1592 version = "0.1.19" ··· 2331 2308 "memchr", 2332 2309 ] 2333 2310 2334 - [[package]] 2335 - name = "nonempty" 2336 - version = "0.7.0" 2337 - source = "registry+https://github.com/rust-lang/crates.io-index" 2338 - checksum = "e9e591e719385e6ebaeb5ce5d3887f7d5676fceca6411d1925ccc95745f3d6f7" 2339 - 2340 2311 [[package]] 2341 2312 name = "nonzero_ext" 2342 2313 version = "0.3.0" ··· 2569 2540 "chrono", 2570 2541 "dashmap", 2571 2542 "dotenvy", 2543 + "governor", 2572 2544 "handlebars", 2573 2545 "hex", 2574 2546 "html-escape", ··· 2591 2563 "tokio", 2592 2564 "tower", 2593 2565 "tower-http", 2594 - "tower_governor", 2595 2566 "tracing", 2596 2567 "tracing-subscriber", 2597 2568 "urlencoding", ··· 4133 4104 "tokio", 4134 4105 ] 4135 4106 4136 - [[package]] 4137 - name = "tonic" 4138 - version = "0.14.2" 4139 - source = "registry+https://github.com/rust-lang/crates.io-index" 4140 - checksum = "eb7613188ce9f7df5bfe185db26c5814347d110db17920415cf2fbcad85e7203" 4141 - dependencies = [ 4142 - "async-trait", 4143 - "axum", 4144 - "base64", 4145 - "bytes", 4146 - "h2", 4147 - "http", 4148 - "http-body", 4149 - "http-body-util", 4150 - "hyper", 4151 - "hyper-timeout", 4152 - "hyper-util", 4153 - "percent-encoding", 4154 - "pin-project", 4155 - "socket2", 4156 - "sync_wrapper", 4157 - "tokio", 4158 - "tokio-stream", 4159 - "tower", 4160 - "tower-layer", 4161 - "tower-service", 4162 - "tracing", 4163 - ] 4164 - 4165 4107 [[package]] 4166 4108 name = "tower" 4167 4109 version = "0.5.2" ··· 4170 4112 dependencies = [ 4171 4113 "futures-core", 4172 4114 "futures-util", 4173 - "indexmap 2.12.1", 4174 4115 "pin-project-lite", 4175 - "slab", 4176 4116 "sync_wrapper", 4177 4117 "tokio", 4178 - "tokio-util", 4179 4118 "tower-layer", 4180 4119 "tower-service", 4181 4120 "tracing", ··· 4216 4155 source = "registry+https://github.com/rust-lang/crates.io-index" 4217 4156 checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" 4218 4157 4219 - [[package]] 4220 - name = "tower_governor" 4221 - version = "0.8.0" 4222 - source = "registry+https://github.com/rust-lang/crates.io-index" 4223 - checksum = "44de9b94d849d3c46e06a883d72d408c2de6403367b39df2b1c9d9e7b6736fe6" 4224 - dependencies = [ 4225 - "axum", 4226 - "forwarded-header-value", 4227 - "governor", 4228 - "http", 4229 - "pin-project", 4230 - "thiserror 2.0.17", 4231 - "tonic", 4232 - "tower", 4233 - "tracing", 4234 - ] 4235 - 4236 4158 [[package]] 4237 4159 name = "tracing" 4238 4160 version = "0.1.44"
+1 -1
Cargo.toml
··· 15 15 tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } 16 16 hyper-util = { version = "0.1.19", features = ["client", "client-legacy"] } 17 17 tower-http = { version = "0.6", features = ["cors", "compression-zstd"] } 18 - tower_governor = { version = "0.8.0", features = ["axum", "tracing"] } 19 18 hex = "0.4" 20 19 jwt-compact = { version = "0.8.0", features = ["es256k"] } 21 20 scrypt = "0.11" ··· 39 38 josekit = "0.10.3" 40 39 dashmap = "6.1" 41 40 tower = "0.5" 41 + governor = "0.10"
+34 -57
src/main.rs
··· 27 27 use std::sync::Arc; 28 28 use std::time::Duration; 29 29 use std::{env, net::SocketAddr}; 30 - use tower_governor::{ 31 - GovernorLayer, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, 32 - }; 30 + use crate::rate_limit::{RateLimitBypass, RateLimitState, rate_limit_middleware}; 33 31 use tower_http::{ 34 32 compression::CompressionLayer, 35 33 cors::{Any, CorsLayer}, ··· 38 36 use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 39 37 40 38 mod auth; 39 + mod rate_limit; 41 40 mod gate; 42 41 pub mod helpers; 43 42 mod middleware; ··· 285 284 }; 286 285 287 286 // Rate limiting 288 - //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 289 - let captcha_governor_conf = GovernorConfigBuilder::default() 290 - .per_second(60) 291 - .burst_size(5) 292 - .key_extractor(SmartIpKeyExtractor) 293 - .finish() 294 - .expect("failed to create governor config for create session. this should not happen and is a bug"); 295 - 296 - // Create a second config with the same settings for the other endpoint 297 - let sign_in_governor_conf = GovernorConfigBuilder::default() 298 - .per_second(60) 299 - .burst_size(5) 300 - .key_extractor(SmartIpKeyExtractor) 301 - .finish() 302 - .expect( 303 - "failed to create governor config for sign in. this should not happen and is a bug", 304 - ); 287 + let bypass = Arc::new(RateLimitBypass::from_env()); 305 288 306 - let create_account_limiter_time: Option<String> = 307 - env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok(); 308 - let create_account_limiter_burst: Option<String> = 309 - env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok(); 310 - 311 - //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally 312 - let mut create_account_governor_conf = GovernorConfigBuilder::default(); 313 - if create_account_limiter_time.is_some() { 314 - let time = create_account_limiter_time 315 - .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set") 316 - .parse::<u64>() 317 - .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer"); 318 - create_account_governor_conf.per_second(time); 319 - } 289 + let default_quota = governor::Quota::per_second(std::num::NonZeroU32::new(5).expect("nonzero")) 290 + .allow_burst(std::num::NonZeroU32::new(5).expect("nonzero")); 320 291 321 - if create_account_limiter_burst.is_some() { 322 - let burst = create_account_limiter_burst 323 - .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set") 324 - .parse::<u32>() 325 - .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer"); 326 - create_account_governor_conf.burst_size(burst); 327 - } 292 + let captcha_limiter = Arc::new(governor::RateLimiter::keyed(default_quota)); 293 + let sign_in_limiter = Arc::new(governor::RateLimiter::keyed(default_quota)); 328 294 329 - let create_account_governor_conf = create_account_governor_conf 330 - .key_extractor(SmartIpKeyExtractor) 331 - .finish().expect( 332 - "failed to create governor config for create account. this should not happen and is a bug", 295 + let create_account_per_sec: u32 = env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND") 296 + .ok() 297 + .and_then(|s| s.parse().ok()) 298 + .unwrap_or(60); 299 + let create_account_burst: u32 = env::var("GATEKEEPER_CREATE_ACCOUNT_BURST") 300 + .ok() 301 + .and_then(|s| s.parse().ok()) 302 + .unwrap_or(5); 303 + let create_account_quota = governor::Quota::per_second( 304 + std::num::NonZeroU32::new(create_account_per_sec).expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be > 0"), 305 + ) 306 + .allow_burst( 307 + std::num::NonZeroU32::new(create_account_burst).expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be > 0"), 333 308 ); 309 + let create_account_limiter = Arc::new(governor::RateLimiter::keyed(create_account_quota)); 334 310 335 - let captcha_governor_limiter = captcha_governor_conf.limiter().clone(); 336 - let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 337 - let create_account_governor_limiter = create_account_governor_conf.limiter().clone(); 311 + let captcha_rl = RateLimitState { bypass: bypass.clone(), limiter: captcha_limiter.clone() }; 312 + let sign_in_rl = RateLimitState { bypass: bypass.clone(), limiter: sign_in_limiter.clone() }; 313 + let create_account_rl = RateLimitState { bypass: bypass.clone(), limiter: create_account_limiter.clone() }; 338 314 339 - let sign_in_governor_layer = GovernorLayer::new(sign_in_governor_conf); 315 + let captcha_limiter_cleanup = captcha_limiter.clone(); 316 + let sign_in_limiter_cleanup = sign_in_limiter.clone(); 317 + let create_account_limiter_cleanup = create_account_limiter.clone(); 340 318 341 319 let interval = Duration::from_secs(60); 342 - // a separate background task to clean up 343 320 std::thread::spawn(move || { 344 321 loop { 345 322 std::thread::sleep(interval); 346 - captcha_governor_limiter.retain_recent(); 347 - sign_in_governor_limiter.retain_recent(); 348 - create_account_governor_limiter.retain_recent(); 323 + captcha_limiter_cleanup.retain_recent(); 324 + sign_in_limiter_cleanup.retain_recent(); 325 + create_account_limiter_cleanup.retain_recent(); 349 326 } 350 327 }); 351 328 ··· 367 344 ) 368 345 .route( 369 346 "/@atproto/oauth-provider/~api/sign-in", 370 - post(sign_in).layer(sign_in_governor_layer.clone()), 347 + post(sign_in).layer(ax_middleware::from_fn_with_state(sign_in_rl.clone(), rate_limit_middleware)), 371 348 ) 372 349 .route( 373 350 "/xrpc/com.atproto.server.createSession", 374 - post(create_session.layer(sign_in_governor_layer)), 351 + post(create_session).layer(ax_middleware::from_fn_with_state(sign_in_rl, rate_limit_middleware)), 375 352 ) 376 353 .route( 377 354 "/xrpc/com.atproto.server.createAccount", 378 - post(create_account).layer(GovernorLayer::new(create_account_governor_conf)), 355 + post(create_account).layer(ax_middleware::from_fn_with_state(create_account_rl, rate_limit_middleware)), 379 356 ); 380 357 381 358 if state.app_config.use_captcha { 382 359 app = app.route( 383 360 "/gate/signup", 384 - get(get_gate).post(post_gate.layer(GovernorLayer::new(captcha_governor_conf))), 361 + get(get_gate).post(post_gate.layer(ax_middleware::from_fn_with_state(captcha_rl, rate_limit_middleware))), 385 362 ); 386 363 } 387 364
+309
src/rate_limit.rs
··· 1 + use std::collections::HashSet; 2 + use std::net::{IpAddr, SocketAddr}; 3 + use std::sync::Arc; 4 + use std::env; 5 + 6 + use axum::body::Body; 7 + use axum::extract::State; 8 + use axum::http::{Request, StatusCode}; 9 + use axum::middleware::Next; 10 + use axum::response::{IntoResponse, Response}; 11 + use governor::clock::Clock; 12 + 13 + pub type KeyedRateLimiter = governor::RateLimiter< 14 + String, 15 + governor::state::keyed::DefaultKeyedStateStore<String>, 16 + governor::clock::DefaultClock, 17 + >; 18 + 19 + #[derive(Clone)] 20 + pub struct RateLimitState { 21 + pub bypass: Arc<RateLimitBypass>, 22 + pub limiter: Arc<KeyedRateLimiter>, 23 + } 24 + 25 + pub async fn rate_limit_middleware( 26 + State(state): State<RateLimitState>, 27 + req: Request<Body>, 28 + next: Next, 29 + ) -> Response { 30 + let connect_info = req 31 + .extensions() 32 + .get::<SocketAddr>() 33 + .copied() 34 + .unwrap_or_else(|| SocketAddr::from(([127, 0, 0, 1], 0))); 35 + 36 + let ip = extract_client_ip(&req, &connect_info); 37 + 38 + if state.bypass.should_bypass(&ip, &req) { 39 + return next.run(req).await; 40 + } 41 + 42 + match state.limiter.check_key(&ip.to_string()) { 43 + Ok(_) => next.run(req).await, 44 + Err(not_until) => { 45 + let wait = not_until.wait_time_from(governor::clock::DefaultClock::default().now()); 46 + let secs = wait.as_secs(); 47 + ( 48 + StatusCode::TOO_MANY_REQUESTS, 49 + [(axum::http::header::RETRY_AFTER, secs.to_string())], 50 + axum::Json(serde_json::json!({ 51 + "error": "RateLimitExceeded", 52 + "message": "Too Many Requests" 53 + })), 54 + ) 55 + .into_response() 56 + } 57 + } 58 + } 59 + 60 + pub struct RateLimitBypass { 61 + pub bypass_ips: HashSet<IpAddr>, 62 + pub bypass_key: Option<String>, 63 + } 64 + 65 + impl RateLimitBypass { 66 + pub fn from_env() -> Self { 67 + let bypass_key = env::var("GATEKEEPER_RATE_LIMIT_BYPASS_KEY").ok(); 68 + let bypass_ips = env::var("GATEKEEPER_RATE_LIMIT_BYPASS_IPS") 69 + .ok() 70 + .map(|val| { 71 + val.split(',') 72 + .filter_map(|s| s.trim().split('/').next()?.parse::<IpAddr>().ok()) 73 + .collect() 74 + }) 75 + .unwrap_or_default(); 76 + Self { bypass_ips, bypass_key } 77 + } 78 + 79 + pub fn should_bypass<B>(&self, ip: &IpAddr, req: &Request<B>) -> bool { 80 + if self.bypass_ips.contains(ip) { 81 + return true; 82 + } 83 + if let Some(ref key) = self.bypass_key { 84 + if let Some(val) = req.headers().get("x-ratelimit-bypass") { 85 + if let Ok(v) = val.to_str() { 86 + return v == key; 87 + } 88 + } 89 + } 90 + false 91 + } 92 + } 93 + 94 + pub fn extract_client_ip<B>(req: &Request<B>, connect_info: &SocketAddr) -> IpAddr { 95 + if let Some(xff) = req.headers().get("x-forwarded-for") { 96 + if let Ok(val) = xff.to_str() { 97 + if let Some(first) = val.split(',').next() { 98 + if let Ok(ip) = first.trim().parse::<IpAddr>() { 99 + return ip; 100 + } 101 + } 102 + } 103 + } 104 + if let Some(xri) = req.headers().get("x-real-ip") { 105 + if let Ok(val) = xri.to_str() { 106 + if let Ok(ip) = val.trim().parse::<IpAddr>() { 107 + return ip; 108 + } 109 + } 110 + } 111 + connect_info.ip() 112 + } 113 + 114 + #[cfg(test)] 115 + mod tests { 116 + use super::*; 117 + 118 + fn fallback() -> SocketAddr { 119 + "127.0.0.1:1234".parse().unwrap() 120 + } 121 + 122 + #[test] 123 + fn extracts_first_ip_from_x_forwarded_for() { 124 + let req = Request::builder() 125 + .header("x-forwarded-for", "203.0.113.50, 70.41.3.18") 126 + .body(()) 127 + .unwrap(); 128 + assert_eq!( 129 + extract_client_ip(&req, &fallback()), 130 + "203.0.113.50".parse::<IpAddr>().unwrap() 131 + ); 132 + } 133 + 134 + #[test] 135 + fn extracts_ip_from_x_real_ip() { 136 + let req = Request::builder() 137 + .header("x-real-ip", "198.51.100.1") 138 + .body(()) 139 + .unwrap(); 140 + assert_eq!( 141 + extract_client_ip(&req, &fallback()), 142 + "198.51.100.1".parse::<IpAddr>().unwrap() 143 + ); 144 + } 145 + 146 + #[test] 147 + fn falls_back_to_connect_info() { 148 + let req = Request::builder().body(()).unwrap(); 149 + assert_eq!( 150 + extract_client_ip(&req, &fallback()), 151 + "127.0.0.1".parse::<IpAddr>().unwrap() 152 + ); 153 + } 154 + 155 + #[test] 156 + fn bypass_matching_ip() { 157 + let bypass = RateLimitBypass { 158 + bypass_ips: HashSet::from(["203.0.113.50".parse().unwrap()]), 159 + bypass_key: None, 160 + }; 161 + let req = Request::builder().body(()).unwrap(); 162 + let ip: IpAddr = "203.0.113.50".parse().unwrap(); 163 + assert!(bypass.should_bypass(&ip, &req)); 164 + } 165 + 166 + #[test] 167 + fn bypass_matching_token() { 168 + let bypass = RateLimitBypass { 169 + bypass_ips: HashSet::new(), 170 + bypass_key: Some("secret123".into()), 171 + }; 172 + let req = Request::builder() 173 + .header("x-ratelimit-bypass", "secret123") 174 + .body(()) 175 + .unwrap(); 176 + let ip: IpAddr = "10.0.0.1".parse().unwrap(); 177 + assert!(bypass.should_bypass(&ip, &req)); 178 + } 179 + 180 + #[test] 181 + fn no_bypass_wrong_token() { 182 + let bypass = RateLimitBypass { 183 + bypass_ips: HashSet::new(), 184 + bypass_key: Some("secret123".into()), 185 + }; 186 + let req = Request::builder() 187 + .header("x-ratelimit-bypass", "wrong") 188 + .body(()) 189 + .unwrap(); 190 + let ip: IpAddr = "10.0.0.1".parse().unwrap(); 191 + assert!(!bypass.should_bypass(&ip, &req)); 192 + } 193 + 194 + #[test] 195 + fn no_bypass_when_unconfigured() { 196 + let bypass = RateLimitBypass { 197 + bypass_ips: HashSet::new(), 198 + bypass_key: None, 199 + }; 200 + let req = Request::builder().body(()).unwrap(); 201 + let ip: IpAddr = "10.0.0.1".parse().unwrap(); 202 + assert!(!bypass.should_bypass(&ip, &req)); 203 + } 204 + 205 + #[test] 206 + fn from_env_parses_ips_and_key() { 207 + unsafe { 208 + std::env::set_var("GATEKEEPER_RATE_LIMIT_BYPASS_KEY", "mysecret"); 209 + std::env::set_var("GATEKEEPER_RATE_LIMIT_BYPASS_IPS", "10.0.0.1, 192.168.1.1"); 210 + } 211 + let bypass = RateLimitBypass::from_env(); 212 + unsafe { 213 + std::env::remove_var("GATEKEEPER_RATE_LIMIT_BYPASS_KEY"); 214 + std::env::remove_var("GATEKEEPER_RATE_LIMIT_BYPASS_IPS"); 215 + } 216 + 217 + assert_eq!(bypass.bypass_key, Some("mysecret".into())); 218 + assert!(bypass.bypass_ips.contains(&"10.0.0.1".parse::<IpAddr>().unwrap())); 219 + assert!(bypass.bypass_ips.contains(&"192.168.1.1".parse::<IpAddr>().unwrap())); 220 + assert_eq!(bypass.bypass_ips.len(), 2); 221 + } 222 + 223 + #[test] 224 + fn no_bypass_non_matching_ip() { 225 + let bypass = RateLimitBypass { 226 + bypass_ips: HashSet::from(["203.0.113.50".parse().unwrap()]), 227 + bypass_key: None, 228 + }; 229 + let req = Request::builder().body(()).unwrap(); 230 + let ip: IpAddr = "10.0.0.1".parse().unwrap(); 231 + assert!(!bypass.should_bypass(&ip, &req)); 232 + } 233 + 234 + #[tokio::test] 235 + async fn rate_limit_returns_429_when_exceeded() { 236 + use axum::{Router, routing::get, body::Body, http::StatusCode}; 237 + use tower::ServiceExt; 238 + 239 + let bypass = Arc::new(RateLimitBypass { 240 + bypass_ips: HashSet::new(), 241 + bypass_key: None, 242 + }); 243 + 244 + let limiter = Arc::new( 245 + governor::RateLimiter::keyed(governor::Quota::per_second(std::num::NonZeroU32::new(1).unwrap())) 246 + ); 247 + 248 + let state = RateLimitState { 249 + bypass: bypass, 250 + limiter: limiter, 251 + }; 252 + 253 + let app = Router::new() 254 + .route("/test", get(|| async { "ok" })) 255 + .layer(axum::middleware::from_fn_with_state( 256 + state, 257 + rate_limit_middleware, 258 + )); 259 + 260 + // First request should pass 261 + let req = Request::builder() 262 + .uri("/test") 263 + .extension(std::net::SocketAddr::from(([10, 0, 0, 1], 1234))) 264 + .body(Body::empty()) 265 + .unwrap(); 266 + let resp = app.clone().oneshot(req).await.unwrap(); 267 + assert_eq!(resp.status(), StatusCode::OK); 268 + 269 + // Second request from same IP should be rate limited 270 + let req = Request::builder() 271 + .uri("/test") 272 + .extension(std::net::SocketAddr::from(([10, 0, 0, 1], 1234))) 273 + .body(Body::empty()) 274 + .unwrap(); 275 + let resp = app.oneshot(req).await.unwrap(); 276 + assert_eq!(resp.status(), StatusCode::TOO_MANY_REQUESTS); 277 + } 278 + 279 + #[tokio::test] 280 + async fn bypass_ip_never_gets_429() { 281 + use axum::{Router, routing::get, body::Body, http::StatusCode}; 282 + use tower::ServiceExt; 283 + 284 + let bypass = Arc::new(RateLimitBypass { 285 + bypass_ips: HashSet::from(["10.0.0.1".parse().unwrap()]), 286 + bypass_key: None, 287 + }); 288 + 289 + let limiter = Arc::new( 290 + governor::RateLimiter::keyed(governor::Quota::per_second(std::num::NonZeroU32::new(1).unwrap())) 291 + ); 292 + 293 + let state = RateLimitState { bypass, limiter }; 294 + 295 + let app = Router::new() 296 + .route("/test", get(|| async { "ok" })) 297 + .layer(axum::middleware::from_fn_with_state(state, rate_limit_middleware)); 298 + 299 + for _ in 0..5 { 300 + let req = Request::builder() 301 + .uri("/test") 302 + .extension(std::net::SocketAddr::from(([10, 0, 0, 1], 1234))) 303 + .body(Body::empty()) 304 + .unwrap(); 305 + let resp = app.clone().oneshot(req).await.unwrap(); 306 + assert_eq!(resp.status(), StatusCode::OK); 307 + } 308 + } 309 + }

History

1 round 1 comment
sign up or login to add to the discussion
1 commit
expand
feat: add rate limit bypass for IPs and tokens
merge conflicts detected
expand
  • Cargo.lock:2591
  • Cargo.toml:15
  • src/main.rs:27
expand 1 comment

I'm going to chew a bit more on this one since it could have some bigger changes if that works for you. I think the original reason for this PR may be covered now for eurosky (the eu-hauler need for it)?

I want to see a bit more what the PDS does on some of these endpoints and see if I also need to add rate limiting per did as well for server side applications and try and decide the best path on it