I've been saying "PDSes seem easy enough, they're what, some CRUD to a db? I can do that in my sleep". well i'm sleeping rn so let's go
at main 9.4 kB view raw
1use axum::{ 2 Json, 3 body::Body, 4 extract::ConnectInfo, 5 http::{HeaderMap, Request, StatusCode}, 6 middleware::Next, 7 response::{IntoResponse, Response}, 8}; 9use governor::{ 10 Quota, RateLimiter, 11 clock::DefaultClock, 12 state::{InMemoryState, NotKeyed, keyed::DefaultKeyedStateStore}, 13}; 14use std::{net::SocketAddr, num::NonZeroU32, sync::Arc}; 15 16pub type KeyedRateLimiter = RateLimiter<String, DefaultKeyedStateStore<String>, DefaultClock>; 17pub type GlobalRateLimiter = RateLimiter<NotKeyed, InMemoryState, DefaultClock>; 18 19#[derive(Clone)] 20pub struct RateLimiters { 21 pub login: Arc<KeyedRateLimiter>, 22 pub oauth_token: Arc<KeyedRateLimiter>, 23 pub oauth_authorize: Arc<KeyedRateLimiter>, 24 pub password_reset: Arc<KeyedRateLimiter>, 25 pub account_creation: Arc<KeyedRateLimiter>, 26 pub refresh_session: Arc<KeyedRateLimiter>, 27 pub reset_password: Arc<KeyedRateLimiter>, 28 pub oauth_par: Arc<KeyedRateLimiter>, 29 pub oauth_introspect: Arc<KeyedRateLimiter>, 30 pub app_password: Arc<KeyedRateLimiter>, 31 pub email_update: Arc<KeyedRateLimiter>, 32 pub totp_verify: Arc<KeyedRateLimiter>, 33 pub handle_update: Arc<KeyedRateLimiter>, 34 pub handle_update_daily: Arc<KeyedRateLimiter>, 35 pub verification_check: Arc<KeyedRateLimiter>, 36} 37 38impl Default for RateLimiters { 39 fn default() -> Self { 40 Self::new() 41 } 42} 43 44impl RateLimiters { 45 pub fn new() -> Self { 46 Self { 47 login: Arc::new(RateLimiter::keyed(Quota::per_minute( 48 NonZeroU32::new(10).unwrap(), 49 ))), 50 oauth_token: Arc::new(RateLimiter::keyed(Quota::per_minute( 51 NonZeroU32::new(30).unwrap(), 52 ))), 53 oauth_authorize: Arc::new(RateLimiter::keyed(Quota::per_minute( 54 NonZeroU32::new(10).unwrap(), 55 ))), 56 password_reset: Arc::new(RateLimiter::keyed(Quota::per_hour( 57 NonZeroU32::new(5).unwrap(), 58 ))), 59 account_creation: Arc::new(RateLimiter::keyed(Quota::per_hour( 60 NonZeroU32::new(10).unwrap(), 61 ))), 62 refresh_session: Arc::new(RateLimiter::keyed(Quota::per_minute( 63 NonZeroU32::new(60).unwrap(), 64 ))), 65 reset_password: Arc::new(RateLimiter::keyed(Quota::per_minute( 66 NonZeroU32::new(10).unwrap(), 67 ))), 68 oauth_par: Arc::new(RateLimiter::keyed(Quota::per_minute( 69 NonZeroU32::new(30).unwrap(), 70 ))), 71 oauth_introspect: Arc::new(RateLimiter::keyed(Quota::per_minute( 72 NonZeroU32::new(30).unwrap(), 73 ))), 74 app_password: Arc::new(RateLimiter::keyed(Quota::per_minute( 75 NonZeroU32::new(10).unwrap(), 76 ))), 77 email_update: Arc::new(RateLimiter::keyed(Quota::per_hour( 78 NonZeroU32::new(5).unwrap(), 79 ))), 80 totp_verify: Arc::new(RateLimiter::keyed( 81 Quota::with_period(std::time::Duration::from_secs(60)) 82 .unwrap() 83 .allow_burst(NonZeroU32::new(5).unwrap()), 84 )), 85 handle_update: Arc::new(RateLimiter::keyed( 86 Quota::with_period(std::time::Duration::from_secs(30)) 87 .unwrap() 88 .allow_burst(NonZeroU32::new(10).unwrap()), 89 )), 90 handle_update_daily: Arc::new(RateLimiter::keyed( 91 Quota::with_period(std::time::Duration::from_secs(1728)) 92 .unwrap() 93 .allow_burst(NonZeroU32::new(50).unwrap()), 94 )), 95 verification_check: Arc::new(RateLimiter::keyed(Quota::per_minute( 96 NonZeroU32::new(60).unwrap(), 97 ))), 98 } 99 } 100 101 pub fn with_login_limit(mut self, per_minute: u32) -> Self { 102 self.login = Arc::new(RateLimiter::keyed(Quota::per_minute( 103 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()), 104 ))); 105 self 106 } 107 108 pub fn with_oauth_token_limit(mut self, per_minute: u32) -> Self { 109 self.oauth_token = Arc::new(RateLimiter::keyed(Quota::per_minute( 110 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(30).unwrap()), 111 ))); 112 self 113 } 114 115 pub fn with_oauth_authorize_limit(mut self, per_minute: u32) -> Self { 116 self.oauth_authorize = Arc::new(RateLimiter::keyed(Quota::per_minute( 117 NonZeroU32::new(per_minute).unwrap_or(NonZeroU32::new(10).unwrap()), 118 ))); 119 self 120 } 121 122 pub fn with_password_reset_limit(mut self, per_hour: u32) -> Self { 123 self.password_reset = Arc::new(RateLimiter::keyed(Quota::per_hour( 124 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()), 125 ))); 126 self 127 } 128 129 pub fn with_account_creation_limit(mut self, per_hour: u32) -> Self { 130 self.account_creation = Arc::new(RateLimiter::keyed(Quota::per_hour( 131 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(10).unwrap()), 132 ))); 133 self 134 } 135 136 pub fn with_email_update_limit(mut self, per_hour: u32) -> Self { 137 self.email_update = Arc::new(RateLimiter::keyed(Quota::per_hour( 138 NonZeroU32::new(per_hour).unwrap_or(NonZeroU32::new(5).unwrap()), 139 ))); 140 self 141 } 142} 143 144pub fn extract_client_ip(headers: &HeaderMap, addr: Option<SocketAddr>) -> String { 145 if let Some(forwarded) = headers.get("x-forwarded-for") 146 && let Ok(value) = forwarded.to_str() 147 && let Some(first_ip) = value.split(',').next() 148 { 149 return first_ip.trim().to_string(); 150 } 151 152 if let Some(real_ip) = headers.get("x-real-ip") 153 && let Ok(value) = real_ip.to_str() 154 { 155 return value.trim().to_string(); 156 } 157 158 addr.map(|a| a.ip().to_string()) 159 .unwrap_or_else(|| "unknown".to_string()) 160} 161 162fn rate_limit_response() -> Response { 163 ( 164 StatusCode::TOO_MANY_REQUESTS, 165 Json(serde_json::json!({ 166 "error": "RateLimitExceeded", 167 "message": "Too many requests. Please try again later." 168 })), 169 ) 170 .into_response() 171} 172 173pub async fn login_rate_limit( 174 ConnectInfo(addr): ConnectInfo<SocketAddr>, 175 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 176 request: Request<Body>, 177 next: Next, 178) -> Response { 179 let client_ip = extract_client_ip(request.headers(), Some(addr)); 180 181 if limiters.login.check_key(&client_ip).is_err() { 182 tracing::warn!(ip = %client_ip, "Login rate limit exceeded"); 183 return rate_limit_response(); 184 } 185 186 next.run(request).await 187} 188 189pub async fn oauth_token_rate_limit( 190 ConnectInfo(addr): ConnectInfo<SocketAddr>, 191 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 192 request: Request<Body>, 193 next: Next, 194) -> Response { 195 let client_ip = extract_client_ip(request.headers(), Some(addr)); 196 197 if limiters.oauth_token.check_key(&client_ip).is_err() { 198 tracing::warn!(ip = %client_ip, "OAuth token rate limit exceeded"); 199 return rate_limit_response(); 200 } 201 202 next.run(request).await 203} 204 205pub async fn password_reset_rate_limit( 206 ConnectInfo(addr): ConnectInfo<SocketAddr>, 207 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 208 request: Request<Body>, 209 next: Next, 210) -> Response { 211 let client_ip = extract_client_ip(request.headers(), Some(addr)); 212 213 if limiters.password_reset.check_key(&client_ip).is_err() { 214 tracing::warn!(ip = %client_ip, "Password reset rate limit exceeded"); 215 return rate_limit_response(); 216 } 217 218 next.run(request).await 219} 220 221pub async fn account_creation_rate_limit( 222 ConnectInfo(addr): ConnectInfo<SocketAddr>, 223 axum::extract::State(limiters): axum::extract::State<Arc<RateLimiters>>, 224 request: Request<Body>, 225 next: Next, 226) -> Response { 227 let client_ip = extract_client_ip(request.headers(), Some(addr)); 228 229 if limiters.account_creation.check_key(&client_ip).is_err() { 230 tracing::warn!(ip = %client_ip, "Account creation rate limit exceeded"); 231 return rate_limit_response(); 232 } 233 234 next.run(request).await 235} 236 237#[cfg(test)] 238mod tests { 239 use super::*; 240 241 #[test] 242 fn test_rate_limiters_creation() { 243 let limiters = RateLimiters::new(); 244 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 245 } 246 247 #[test] 248 fn test_rate_limiter_exhaustion() { 249 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(2).unwrap())); 250 let key = "test_ip".to_string(); 251 252 assert!(limiter.check_key(&key).is_ok()); 253 assert!(limiter.check_key(&key).is_ok()); 254 assert!(limiter.check_key(&key).is_err()); 255 } 256 257 #[test] 258 fn test_different_keys_have_separate_limits() { 259 let limiter = RateLimiter::keyed(Quota::per_minute(NonZeroU32::new(1).unwrap())); 260 261 assert!(limiter.check_key(&"ip1".to_string()).is_ok()); 262 assert!(limiter.check_key(&"ip1".to_string()).is_err()); 263 assert!(limiter.check_key(&"ip2".to_string()).is_ok()); 264 } 265 266 #[test] 267 fn test_builder_pattern() { 268 let limiters = RateLimiters::new() 269 .with_login_limit(20) 270 .with_oauth_token_limit(60) 271 .with_password_reset_limit(3) 272 .with_account_creation_limit(5); 273 274 assert!(limiters.login.check_key(&"test".to_string()).is_ok()); 275 } 276}