Microservice to bring 2FA to self-hosted PDSes
1use crate::helpers::json_error_response; 2use axum::extract::Request; 3use axum::http::{HeaderMap, StatusCode}; 4use axum::middleware::Next; 5use axum::response::IntoResponse; 6use jwt_compact::alg::{Hs256, Hs256Key}; 7use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError}; 8use serde::{Deserialize, Serialize}; 9use std::env; 10use tracing::log; 11 12#[derive(Clone, Debug)] 13pub struct Did(pub Option<String>); 14 15#[derive(Serialize, Deserialize)] 16pub struct TokenClaims { 17 pub sub: String, 18} 19 20pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse { 21 let token = extract_bearer(req.headers()); 22 23 match token { 24 Ok(token) => { 25 match token { 26 None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 27 .expect("Error creating an error response"), 28 Some(token) => { 29 let token = UntrustedToken::new(&token); 30 if token.is_err() { 31 return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 32 .expect("Error creating an error response"); 33 } 34 let parsed_token = token.expect("Already checked for error"); 35 let claims: Result<Claims<TokenClaims>, ValidationError> = 36 parsed_token.deserialize_claims_unchecked(); 37 if claims.is_err() { 38 return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 39 .expect("Error creating an error response"); 40 } 41 42 let key = Hs256Key::new( 43 env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"), 44 ); 45 let token: Result<Token<TokenClaims>, ValidationError> = 46 Hs256.validator(&key).validate(&parsed_token); 47 if token.is_err() { 48 return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "") 49 .expect("Error creating an error response"); 50 } 51 let token = token.expect("Already checked for error,"); 52 //Not going to worry about expiration since it still goes to the PDS 53 req.extensions_mut() 54 .insert(Did(Some(token.claims().custom.sub.clone()))); 55 next.run(req).await 56 
} 57 } 58 } 59 Err(err) => { 60 log::error!("Error extracting token: {err}"); 61 json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "") 62 .expect("Error creating an error response") 63 } 64 } 65} 66 67fn extract_bearer(headers: &HeaderMap) -> Result<Option<String>, String> { 68 match headers.get(axum::http::header::AUTHORIZATION) { 69 None => Ok(None), 70 Some(hv) => match hv.to_str() { 71 Err(_) => Err("Authorization header is not valid".into()), 72 Ok(s) => { 73 // Accept forms like: "Bearer <token>" (case-sensitive for the scheme here) 74 let mut parts = s.splitn(2, ' '); 75 match (parts.next(), parts.next()) { 76 (Some("Bearer"), Some(tok)) if !tok.is_empty() => Ok(Some(tok.to_string())), 77 _ => Err("Authorization header must be in format 'Bearer <token>'".into()), 78 } 79 } 80 }, 81 } 82}