Forked from baileytownsend.dev/pds-gatekeeper.
Microservice to bring 2FA to self-hosted PDSes.
use crate::helpers::json_error_response;
use axum::extract::Request;
use axum::http::{HeaderMap, StatusCode};
use axum::middleware::Next;
use axum::response::IntoResponse;
use jwt_compact::alg::{Hs256, Hs256Key};
use jwt_compact::{AlgorithmExt, Claims, TimeOptions, Token, UntrustedToken, ValidationError};
use serde::{Deserialize, Serialize};
use std::env;
use tracing::log;

/// DID of the caller, stored as a request extension by the `extract_did` middleware.
///
/// `Some(did)` holds the `sub` claim of a verified bearer token; `None` means the
/// request was forwarded without the gatekeeper resolving a DID (the `DPoP` path).
#[derive(Debug, Clone)]
pub struct Did(pub Option<String>);

/// Authentication scheme named in the request's `Authorization` header.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum AuthScheme {
    /// `Authorization: Bearer <token>`; verified as an HS256 JWT by `extract_did`.
    Bearer,
    /// `Authorization: DPoP <token>`; used by OAuth clients and currently forwarded
    /// to the PDS without validation by the gatekeeper.
    DPoP,
}

21#[derive(Serialize, Deserialize)]
22pub struct TokenClaims {
23 pub sub: String,
24}
25
26pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse {
27 let auth = extract_auth(req.headers());
28
29 match auth {
30 Ok(auth_opt) => {
31 match auth_opt {
32 None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
33 .expect("Error creating an error response"),
34 Some((scheme, token_str)) => {
35 // For Bearer, validate JWT and extract DID from `sub`.
36 // For DPoP, we currently only pass through and do not validate here; insert None DID.
37 match scheme {
38 AuthScheme::Bearer => {
39 let token = UntrustedToken::new(&token_str);
40 if token.is_err() {
41 return json_error_response(
42 StatusCode::BAD_REQUEST,
43 "TokenRequired",
44 "",
45 )
46 .expect("Error creating an error response");
47 }
48 let parsed_token = token.expect("Already checked for error");
49 let claims: Result<Claims<TokenClaims>, ValidationError> =
50 parsed_token.deserialize_claims_unchecked();
51 if claims.is_err() {
52 return json_error_response(
53 StatusCode::BAD_REQUEST,
54 "TokenRequired",
55 "",
56 )
57 .expect("Error creating an error response");
58 }
59
60 let key = Hs256Key::new(
61 env::var("PDS_JWT_SECRET")
62 .expect("PDS_JWT_SECRET not set in the pds.env"),
63 );
64 let token: Result<Token<TokenClaims>, ValidationError> =
65 Hs256.validator(&key).validate(&parsed_token);
66 if token.is_err() {
67 return json_error_response(
68 StatusCode::BAD_REQUEST,
69 "InvalidToken",
70 "",
71 )
72 .expect("Error creating an error response");
73 }
74 let token = token.expect("Already checked for error,");
75 req.extensions_mut()
76 .insert(Did(Some(token.claims().custom.sub.clone())));
77 }
78 AuthScheme::DPoP => {
79 //Not going to worry about oauth email update for now, just always forward to the PDS
80 req.extensions_mut().insert(Did(None));
81 }
82 }
83
84 next.run(req).await
85 }
86 }
87 }
88 Err(err) => {
89 log::error!("Error extracting token: {err}");
90 json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
91 .expect("Error creating an error response")
92 }
93 }
94}
95
96fn extract_auth(headers: &HeaderMap) -> Result<Option<(AuthScheme, String)>, String> {
97 match headers.get(axum::http::header::AUTHORIZATION) {
98 None => Ok(None),
99 Some(hv) => {
100 match hv.to_str() {
101 Err(_) => Err("Authorization header is not valid".into()),
102 Ok(s) => {
103 // Accept forms like: "Bearer <token>" or "DPoP <token>" (case-sensitive for the scheme here)
104 let mut parts = s.splitn(2, ' ');
105 match (parts.next(), parts.next()) {
106 (Some("Bearer"), Some(tok)) if !tok.is_empty() =>
107 Ok(Some((AuthScheme::Bearer, tok.to_string()))),
108 (Some("DPoP"), Some(tok)) if !tok.is_empty() =>
109 Ok(Some((AuthScheme::DPoP, tok.to_string()))),
110 _ => Err("Authorization header must be in format 'Bearer <token>' or 'DPoP <token>'".into()),
111 }
112 }
113 }
114 }
115 }
116}