forked from
baileytownsend.dev/pds-gatekeeper
Microservice to bring 2FA to self-hosted PDSes
1use crate::helpers::json_error_response;
2use axum::extract::Request;
3use axum::http::{HeaderMap, StatusCode};
4use axum::middleware::Next;
5use axum::response::IntoResponse;
6use jwt_compact::alg::{Hs256, Hs256Key};
7use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError};
8use serde::{Deserialize, Serialize};
9use std::env;
10use tracing::log;
11
/// DID of the authenticated caller, attached to the request extensions by `extract_did`.
///
/// `extract_did` always stores `Some(sub)` taken from a signature-verified token; nothing in
/// this module produces `None` (presumably used elsewhere for "no known caller" — confirm
/// against the handlers that read this extension).
#[derive(Debug, Clone)]
pub struct Did(pub Option<String>);
14
15#[derive(Serialize, Deserialize)]
16pub struct TokenClaims {
17 pub sub: String,
18}
19
20pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse {
21 let token = extract_bearer(req.headers());
22
23 match token {
24 Ok(token) => {
25 match token {
26 None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
27 .expect("Error creating an error response"),
28 Some(token) => {
29 let token = UntrustedToken::new(&token);
30 if token.is_err() {
31 return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
32 .expect("Error creating an error response");
33 }
34 let parsed_token = token.expect("Already checked for error");
35 let claims: Result<Claims<TokenClaims>, ValidationError> =
36 parsed_token.deserialize_claims_unchecked();
37 if claims.is_err() {
38 return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
39 .expect("Error creating an error response");
40 }
41
42 let key = Hs256Key::new(
43 env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"),
44 );
45 let token: Result<Token<TokenClaims>, ValidationError> =
46 Hs256.validator(&key).validate(&parsed_token);
47 if token.is_err() {
48 return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
49 .expect("Error creating an error response");
50 }
51 let token = token.expect("Already checked for error,");
52 //Not going to worry about expiration since it still goes to the PDS
53 req.extensions_mut()
54 .insert(Did(Some(token.claims().custom.sub.clone())));
55 next.run(req).await
56 }
57 }
58 }
59 Err(err) => {
60 log::error!("Error extracting token: {err}");
61 json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
62 .expect("Error creating an error response")
63 }
64 }
65}
66
67fn extract_bearer(headers: &HeaderMap) -> Result<Option<String>, String> {
68 match headers.get(axum::http::header::AUTHORIZATION) {
69 None => Ok(None),
70 Some(hv) => match hv.to_str() {
71 Err(_) => Err("Authorization header is not valid".into()),
72 Ok(s) => {
73 // Accept forms like: "Bearer <token>" (case-sensitive for the scheme here)
74 let mut parts = s.splitn(2, ' ');
75 match (parts.next(), parts.next()) {
76 (Some("Bearer"), Some(tok)) if !tok.is_empty() => Ok(Some(tok.to_string())),
77 _ => Err("Authorization header must be in format 'Bearer <token>'".into()),
78 }
79 }
80 },
81 }
82}