Microservice to bring 2FA to self-hosted PDSes
at main 15 kB view raw
1use super::{AuthRules, HandleCache, SessionData}; 2use crate::helpers::json_error_response; 3use crate::AppState; 4use axum::extract::{Request, State}; 5use axum::http::{HeaderMap, StatusCode}; 6use axum::middleware::Next; 7use axum::response::{IntoResponse, Response}; 8use jacquard_identity::resolver::IdentityResolver; 9use jacquard_identity::PublicResolver; 10use jwt_compact::alg::{Hs256, Hs256Key}; 11use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError}; 12use serde::{Deserialize, Serialize}; 13use std::env; 14use std::sync::Arc; 15use tracing::log; 16 17#[derive(Clone, Copy, Debug, PartialEq, Eq)] 18pub enum AuthScheme { 19 Bearer, 20 DPoP, 21} 22 23#[derive(Serialize, Deserialize)] 24pub struct TokenClaims { 25 pub sub: String, 26 /// OAuth scopes as space-separated string (per OAuth 2.0 spec) 27 #[serde(default)] 28 pub scope: Option<String>, 29} 30 31/// State passed to the auth middleware containing both AppState and auth rules. 32#[derive(Clone)] 33pub struct AuthMiddlewareState { 34 pub app_state: AppState, 35 pub rules: AuthRules, 36} 37 38/// Core middleware function that validates authentication and applies auth rules. 39/// 40/// Use this with `axum::middleware::from_fn_with_state`: 41/// ```ignore 42/// use axum::middleware::from_fn_with_state; 43/// 44/// let mw_state = AuthMiddlewareState { 45/// app_state: state.clone(), 46/// rules: AuthRules::HandleEndsWith(".blacksky.team".into()), 47/// }; 48/// 49/// .route("/protected", get(handler).layer(from_fn_with_state(mw_state, auth_middleware))) 50/// ``` 51pub async fn auth_middleware( 52 State(mw_state): State<AuthMiddlewareState>, 53 req: Request, 54 next: Next, 55) -> Response { 56 let AuthMiddlewareState { app_state, rules } = mw_state; 57 58 // 1. 
Extract DID and scopes from JWT (Bearer token) 59 let extracted = match extract_auth_from_request(req.headers()) { 60 Ok(Some(auth)) => auth, 61 Ok(None) => { 62 return json_error_response(StatusCode::UNAUTHORIZED, "AuthRequired", "Authentication required") 63 .unwrap_or_else(|_| StatusCode::UNAUTHORIZED.into_response()); 64 } 65 Err(e) => { 66 log::error!("Token extraction error: {}", e); 67 return json_error_response(StatusCode::UNAUTHORIZED, "InvalidToken", &e) 68 .unwrap_or_else(|_| StatusCode::UNAUTHORIZED.into_response()); 69 } 70 }; 71 72 // 2. Resolve DID to handle (check cache first) 73 let handle = match resolve_did_to_handle(&app_state.resolver, &app_state.handle_cache, &extracted.did).await { 74 Ok(handle) => handle, 75 Err(e) => { 76 log::error!("Failed to resolve DID {} to handle: {}", extracted.did, e); 77 return json_error_response( 78 StatusCode::INTERNAL_SERVER_ERROR, 79 "ResolutionError", 80 "Failed to resolve identity", 81 ) 82 .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response()); 83 } 84 }; 85 86 // 3. Build session data and validate rules 87 let session = SessionData { 88 did: extracted.did, 89 handle, 90 scopes: extracted.scopes, 91 }; 92 93 if !rules.validate(&session) { 94 return json_error_response(StatusCode::FORBIDDEN, "AccessDenied", "Access denied by authorization rules") 95 .unwrap_or_else(|_| StatusCode::FORBIDDEN.into_response()); 96 } 97 98 // 4. Pass through on success 99 next.run(req).await 100} 101 102/// Extracted authentication data from JWT 103struct ExtractedAuth { 104 did: String, 105 scopes: Vec<String>, 106} 107 108/// Extracts the DID and scopes from the Authorization header (Bearer JWT). 
109fn extract_auth_from_request(headers: &HeaderMap) -> Result<Option<ExtractedAuth>, String> { 110 let auth = extract_auth(headers)?; 111 112 match auth { 113 None => Ok(None), 114 Some((scheme, token_str)) => { 115 match scheme { 116 AuthScheme::Bearer => { 117 let token = UntrustedToken::new(&token_str) 118 .map_err(|_| "Invalid token format".to_string())?; 119 120 let _claims: Claims<TokenClaims> = token 121 .deserialize_claims_unchecked() 122 .map_err(|_| "Failed to parse token claims".to_string())?; 123 124 let key = Hs256Key::new( 125 env::var("PDS_JWT_SECRET") 126 .map_err(|_| "PDS_JWT_SECRET not configured".to_string())?, 127 ); 128 129 let validated: Token<TokenClaims> = Hs256 130 .validator(&key) 131 .validate(&token) 132 .map_err(|e: ValidationError| format!("Token validation failed: {:?}", e))?; 133 134 let custom = &validated.claims().custom; 135 136 // Parse scopes from space-separated string (OAuth 2.0 spec) 137 let scopes: Vec<String> = custom.scope 138 .as_ref() 139 .map(|s| s.split_whitespace().map(|s| s.to_string()).collect()) 140 .unwrap_or_default(); 141 142 Ok(Some(ExtractedAuth { 143 did: custom.sub.clone(), 144 scopes, 145 })) 146 } 147 AuthScheme::DPoP => { 148 // DPoP tokens are not validated here; pass through without auth data 149 Ok(None) 150 } 151 } 152 } 153 } 154} 155 156/// Extracts the authentication scheme and token from the Authorization header. 
157fn extract_auth(headers: &HeaderMap) -> Result<Option<(AuthScheme, String)>, String> { 158 match headers.get(axum::http::header::AUTHORIZATION) { 159 None => Ok(None), 160 Some(hv) => { 161 let s = hv 162 .to_str() 163 .map_err(|_| "Authorization header is not valid UTF-8".to_string())?; 164 165 let mut parts = s.splitn(2, ' '); 166 match (parts.next(), parts.next()) { 167 (Some("Bearer"), Some(tok)) if !tok.is_empty() => { 168 Ok(Some((AuthScheme::Bearer, tok.to_string()))) 169 } 170 (Some("DPoP"), Some(tok)) if !tok.is_empty() => { 171 Ok(Some((AuthScheme::DPoP, tok.to_string()))) 172 } 173 _ => Err( 174 "Authorization header must be in format 'Bearer <token>' or 'DPoP <token>'" 175 .to_string(), 176 ), 177 } 178 } 179 } 180} 181 182/// Resolves a DID to its handle using the PublicResolver, with caching. 183async fn resolve_did_to_handle( 184 resolver: &Arc<PublicResolver>, 185 cache: &HandleCache, 186 did: &str, 187) -> Result<String, String> { 188 // Check cache first 189 if let Some(handle) = cache.get(did) { 190 return Ok(handle); 191 } 192 193 // Parse the DID 194 let did_parsed = jacquard_common::types::did::Did::new(did) 195 .map_err(|e| format!("Invalid DID: {:?}", e))?; 196 197 // Resolve the DID document 198 let did_doc_response = resolver 199 .resolve_did_doc(&did_parsed) 200 .await 201 .map_err(|e| format!("DID resolution failed: {:?}", e))?; 202 203 let doc = did_doc_response 204 .parse() 205 .map_err(|e| format!("Failed to parse DID document: {:?}", e))?; 206 207 // Extract handle from alsoKnownAs field 208 // Format is typically: ["at://handle.example.com"] 209 let handle: String = doc 210 .also_known_as 211 .as_ref() 212 .and_then(|aka| { 213 aka.iter() 214 .find(|uri| uri.starts_with("at://")) 215 .map(|uri| uri.strip_prefix("at://").unwrap_or(uri.as_ref()).to_string()) 216 }) 217 .ok_or_else(|| "No ATProto handle found in DID document".to_string())?; 218 219 // Cache the result 220 cache.insert(did.to_string(), handle.clone()); 221 222 
Ok(handle) 223} 224 225// ============================================================================ 226// Helper Functions for Creating Middleware State 227// ============================================================================ 228 229/// Creates an `AuthMiddlewareState` for requiring the handle to end with a specific suffix. 230/// 231/// # Example 232/// ```ignore 233/// use axum::middleware::from_fn_with_state; 234/// use crate::auth::{auth_middleware, handle_ends_with}; 235/// 236/// .route("/protected", get(handler).layer( 237/// from_fn_with_state(handle_ends_with(".blacksky.team", &state), auth_middleware) 238/// )) 239/// ``` 240pub fn handle_ends_with(suffix: impl Into<String>, state: &AppState) -> AuthMiddlewareState { 241 AuthMiddlewareState { 242 app_state: state.clone(), 243 rules: AuthRules::HandleEndsWith(suffix.into()), 244 } 245} 246 247/// Creates an `AuthMiddlewareState` for requiring the handle to end with any of the specified suffixes. 248pub fn handle_ends_with_any<I, T>(suffixes: I, state: &AppState) -> AuthMiddlewareState 249where 250 I: IntoIterator<Item = T>, 251 T: Into<String>, 252{ 253 AuthMiddlewareState { 254 app_state: state.clone(), 255 rules: AuthRules::HandleEndsWithAny(suffixes.into_iter().map(|s| s.into()).collect()), 256 } 257} 258 259/// Creates an `AuthMiddlewareState` for requiring the DID to equal a specific value. 260pub fn did_equals(did: impl Into<String>, state: &AppState) -> AuthMiddlewareState { 261 AuthMiddlewareState { 262 app_state: state.clone(), 263 rules: AuthRules::DidEquals(did.into()), 264 } 265} 266 267/// Creates an `AuthMiddlewareState` for requiring the DID to be one of the specified values. 
268pub fn did_equals_any<I, T>(dids: I, state: &AppState) -> AuthMiddlewareState 269where 270 I: IntoIterator<Item = T>, 271 T: Into<String>, 272{ 273 AuthMiddlewareState { 274 app_state: state.clone(), 275 rules: AuthRules::DidEqualsAny(dids.into_iter().map(|d| d.into()).collect()), 276 } 277} 278 279/// Creates an `AuthMiddlewareState` with custom auth rules. 280pub fn with_rules(rules: AuthRules, state: &AppState) -> AuthMiddlewareState { 281 AuthMiddlewareState { 282 app_state: state.clone(), 283 rules, 284 } 285} 286 287// ============================================================================ 288// Scope Helper Functions 289// ============================================================================ 290 291/// Creates an `AuthMiddlewareState` requiring a specific OAuth scope. 292/// 293/// # Example 294/// ```ignore 295/// .route("/xrpc/com.atproto.repo.createRecord", 296/// post(handler).layer(from_fn_with_state( 297/// scope_equals("repo:app.bsky.feed.post", &state), 298/// auth_middleware 299/// ))) 300/// ``` 301pub fn scope_equals(scope: impl Into<String>, state: &AppState) -> AuthMiddlewareState { 302 AuthMiddlewareState { 303 app_state: state.clone(), 304 rules: AuthRules::ScopeEquals(scope.into()), 305 } 306} 307 308/// Creates an `AuthMiddlewareState` requiring ANY of the specified scopes (OR logic). 309/// 310/// # Example 311/// ```ignore 312/// .route("/xrpc/com.atproto.repo.putRecord", 313/// post(handler).layer(from_fn_with_state( 314/// scope_any(["repo:app.bsky.feed.post", "transition:generic"], &state), 315/// auth_middleware 316/// ))) 317/// ``` 318pub fn scope_any<I, T>(scopes: I, state: &AppState) -> AuthMiddlewareState 319where 320 I: IntoIterator<Item = T>, 321 T: Into<String>, 322{ 323 AuthMiddlewareState { 324 app_state: state.clone(), 325 rules: AuthRules::ScopeEqualsAny(scopes.into_iter().map(|s| s.into()).collect()), 326 } 327} 328 329/// Creates an `AuthMiddlewareState` requiring ALL of the specified scopes (AND logic). 
330/// 331/// # Example 332/// ```ignore 333/// .route("/xrpc/com.atproto.admin.updateAccount", 334/// post(handler).layer(from_fn_with_state( 335/// scope_all(["account:email", "account:repo?action=manage"], &state), 336/// auth_middleware 337/// ))) 338/// ``` 339pub fn scope_all<I, T>(scopes: I, state: &AppState) -> AuthMiddlewareState 340where 341 I: IntoIterator<Item = T>, 342 T: Into<String>, 343{ 344 AuthMiddlewareState { 345 app_state: state.clone(), 346 rules: AuthRules::ScopeEqualsAll(scopes.into_iter().map(|s| s.into()).collect()), 347 } 348} 349 350// ============================================================================ 351// Combined Rule Helpers (Identity + Scope) 352// ============================================================================ 353 354/// Creates an `AuthMiddlewareState` requiring handle to end with suffix AND have a specific scope. 355/// 356/// # Example 357/// ```ignore 358/// .route("/xrpc/community.blacksky.feed.generator", 359/// post(handler).layer(from_fn_with_state( 360/// handle_ends_with_and_scope(".blacksky.team", "transition:generic", &state), 361/// auth_middleware 362/// ))) 363/// ``` 364pub fn handle_ends_with_and_scope( 365 suffix: impl Into<String>, 366 scope: impl Into<String>, 367 state: &AppState, 368) -> AuthMiddlewareState { 369 AuthMiddlewareState { 370 app_state: state.clone(), 371 rules: AuthRules::All(vec![ 372 AuthRules::HandleEndsWith(suffix.into()), 373 AuthRules::ScopeEquals(scope.into()), 374 ]), 375 } 376} 377 378/// Creates an `AuthMiddlewareState` requiring handle to end with suffix AND have ALL specified scopes. 
379/// 380/// # Example 381/// ```ignore 382/// .route("/xrpc/community.blacksky.admin.manage", 383/// post(handler).layer(from_fn_with_state( 384/// handle_ends_with_and_scopes(".blacksky.team", ["transition:generic", "identity:*"], &state), 385/// auth_middleware 386/// ))) 387/// ``` 388pub fn handle_ends_with_and_scopes<I, T>( 389 suffix: impl Into<String>, 390 scopes: I, 391 state: &AppState, 392) -> AuthMiddlewareState 393where 394 I: IntoIterator<Item = T>, 395 T: Into<String>, 396{ 397 AuthMiddlewareState { 398 app_state: state.clone(), 399 rules: AuthRules::All(vec![ 400 AuthRules::HandleEndsWith(suffix.into()), 401 AuthRules::ScopeEqualsAll(scopes.into_iter().map(|s| s.into()).collect()), 402 ]), 403 } 404} 405 406/// Creates an `AuthMiddlewareState` requiring DID to equal value AND have a specific scope. 407/// 408/// # Example 409/// ```ignore 410/// .route("/xrpc/com.atproto.admin.deleteAccount", 411/// post(handler).layer(from_fn_with_state( 412/// did_with_scope("did:plc:rnpkyqnmsw4ipey6eotbdnnf", "transition:generic", &state), 413/// auth_middleware 414/// ))) 415/// ``` 416pub fn did_with_scope( 417 did: impl Into<String>, 418 scope: impl Into<String>, 419 state: &AppState, 420) -> AuthMiddlewareState { 421 AuthMiddlewareState { 422 app_state: state.clone(), 423 rules: AuthRules::All(vec![ 424 AuthRules::DidEquals(did.into()), 425 AuthRules::ScopeEquals(scope.into()), 426 ]), 427 } 428} 429 430/// Creates an `AuthMiddlewareState` requiring DID to equal value AND have ALL specified scopes. 
431/// 432/// # Example 433/// ```ignore 434/// .route("/xrpc/com.atproto.admin.fullAccess", 435/// post(handler).layer(from_fn_with_state( 436/// did_with_scopes("did:plc:rnpkyqnmsw4ipey6eotbdnnf", ["transition:generic", "identity:*"], &state), 437/// auth_middleware 438/// ))) 439/// ``` 440pub fn did_with_scopes<I, T>( 441 did: impl Into<String>, 442 scopes: I, 443 state: &AppState, 444) -> AuthMiddlewareState 445where 446 I: IntoIterator<Item = T>, 447 T: Into<String>, 448{ 449 AuthMiddlewareState { 450 app_state: state.clone(), 451 rules: AuthRules::All(vec![ 452 AuthRules::DidEquals(did.into()), 453 AuthRules::ScopeEqualsAll(scopes.into_iter().map(|s| s.into()).collect()), 454 ]), 455 } 456}