Microservice to bring 2FA to self-hosted PDSes
1use super::{AuthRules, HandleCache, SessionData};
2use crate::helpers::json_error_response;
3use crate::AppState;
4use axum::extract::{Request, State};
5use axum::http::{HeaderMap, StatusCode};
6use axum::middleware::Next;
7use axum::response::{IntoResponse, Response};
8use jacquard_identity::resolver::IdentityResolver;
9use jacquard_identity::PublicResolver;
10use jwt_compact::alg::{Hs256, Hs256Key};
11use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError};
12use serde::{Deserialize, Serialize};
13use std::env;
14use std::sync::Arc;
15use tracing::log;
16
/// Authorization scheme recognised in the `Authorization` request header.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum AuthScheme {
    /// `Bearer <token>` — the only scheme whose token is validated in this file.
    Bearer,
    /// `DPoP <token>` — recognised by the header parser, but its token is not
    /// validated in this file.
    DPoP,
}
22
/// Custom (non-registered) claims read from the JWT payload.
#[derive(Serialize, Deserialize)]
pub struct TokenClaims {
    /// Subject of the token; the auth middleware uses this value as the caller's DID.
    pub sub: String,
    /// OAuth scopes as space-separated string (per OAuth 2.0 spec).
    /// Absent in the payload deserializes to `None`.
    #[serde(default)]
    pub scope: Option<String>,
}
30
/// State passed to the auth middleware containing both AppState and auth rules.
///
/// One value of this type is built per protected route (see the helper
/// constructors further down in this file) and handed to
/// `axum::middleware::from_fn_with_state`.
#[derive(Clone)]
pub struct AuthMiddlewareState {
    /// Shared application state; the middleware reads its `resolver` and
    /// `handle_cache` fields to map a DID to a handle.
    pub app_state: AppState,
    /// Rules the authenticated session must satisfy for the request to proceed.
    pub rules: AuthRules,
}
37
38/// Core middleware function that validates authentication and applies auth rules.
39///
40/// Use this with `axum::middleware::from_fn_with_state`:
41/// ```ignore
42/// use axum::middleware::from_fn_with_state;
43///
44/// let mw_state = AuthMiddlewareState {
45/// app_state: state.clone(),
46/// rules: AuthRules::HandleEndsWith(".blacksky.team".into()),
47/// };
48///
49/// .route("/protected", get(handler).layer(from_fn_with_state(mw_state, auth_middleware)))
50/// ```
51pub async fn auth_middleware(
52 State(mw_state): State<AuthMiddlewareState>,
53 req: Request,
54 next: Next,
55) -> Response {
56 let AuthMiddlewareState { app_state, rules } = mw_state;
57
58 // 1. Extract DID and scopes from JWT (Bearer token)
59 let extracted = match extract_auth_from_request(req.headers()) {
60 Ok(Some(auth)) => auth,
61 Ok(None) => {
62 return json_error_response(StatusCode::UNAUTHORIZED, "AuthRequired", "Authentication required")
63 .unwrap_or_else(|_| StatusCode::UNAUTHORIZED.into_response());
64 }
65 Err(e) => {
66 log::error!("Token extraction error: {}", e);
67 return json_error_response(StatusCode::UNAUTHORIZED, "InvalidToken", &e)
68 .unwrap_or_else(|_| StatusCode::UNAUTHORIZED.into_response());
69 }
70 };
71
72 // 2. Resolve DID to handle (check cache first)
73 let handle = match resolve_did_to_handle(&app_state.resolver, &app_state.handle_cache, &extracted.did).await {
74 Ok(handle) => handle,
75 Err(e) => {
76 log::error!("Failed to resolve DID {} to handle: {}", extracted.did, e);
77 return json_error_response(
78 StatusCode::INTERNAL_SERVER_ERROR,
79 "ResolutionError",
80 "Failed to resolve identity",
81 )
82 .unwrap_or_else(|_| StatusCode::INTERNAL_SERVER_ERROR.into_response());
83 }
84 };
85
86 // 3. Build session data and validate rules
87 let session = SessionData {
88 did: extracted.did,
89 handle,
90 scopes: extracted.scopes,
91 };
92
93 if !rules.validate(&session) {
94 return json_error_response(StatusCode::FORBIDDEN, "AccessDenied", "Access denied by authorization rules")
95 .unwrap_or_else(|_| StatusCode::FORBIDDEN.into_response());
96 }
97
98 // 4. Pass through on success
99 next.run(req).await
100}
101
/// Extracted authentication data from JWT
///
/// Produced by `extract_auth_from_request` and consumed by `auth_middleware`
/// to build the session that the auth rules are evaluated against.
struct ExtractedAuth {
    /// Value of the token's `sub` claim, treated as the caller's DID.
    did: String,
    /// Individual scopes split out of the token's space-separated `scope` claim;
    /// empty when the claim is absent.
    scopes: Vec<String>,
}
107
108/// Extracts the DID and scopes from the Authorization header (Bearer JWT).
109fn extract_auth_from_request(headers: &HeaderMap) -> Result<Option<ExtractedAuth>, String> {
110 let auth = extract_auth(headers)?;
111
112 match auth {
113 None => Ok(None),
114 Some((scheme, token_str)) => {
115 match scheme {
116 AuthScheme::Bearer => {
117 let token = UntrustedToken::new(&token_str)
118 .map_err(|_| "Invalid token format".to_string())?;
119
120 let _claims: Claims<TokenClaims> = token
121 .deserialize_claims_unchecked()
122 .map_err(|_| "Failed to parse token claims".to_string())?;
123
124 let key = Hs256Key::new(
125 env::var("PDS_JWT_SECRET")
126 .map_err(|_| "PDS_JWT_SECRET not configured".to_string())?,
127 );
128
129 let validated: Token<TokenClaims> = Hs256
130 .validator(&key)
131 .validate(&token)
132 .map_err(|e: ValidationError| format!("Token validation failed: {:?}", e))?;
133
134 let custom = &validated.claims().custom;
135
136 // Parse scopes from space-separated string (OAuth 2.0 spec)
137 let scopes: Vec<String> = custom.scope
138 .as_ref()
139 .map(|s| s.split_whitespace().map(|s| s.to_string()).collect())
140 .unwrap_or_default();
141
142 Ok(Some(ExtractedAuth {
143 did: custom.sub.clone(),
144 scopes,
145 }))
146 }
147 AuthScheme::DPoP => {
148 // DPoP tokens are not validated here; pass through without auth data
149 Ok(None)
150 }
151 }
152 }
153 }
154}
155
156/// Extracts the authentication scheme and token from the Authorization header.
157fn extract_auth(headers: &HeaderMap) -> Result<Option<(AuthScheme, String)>, String> {
158 match headers.get(axum::http::header::AUTHORIZATION) {
159 None => Ok(None),
160 Some(hv) => {
161 let s = hv
162 .to_str()
163 .map_err(|_| "Authorization header is not valid UTF-8".to_string())?;
164
165 let mut parts = s.splitn(2, ' ');
166 match (parts.next(), parts.next()) {
167 (Some("Bearer"), Some(tok)) if !tok.is_empty() => {
168 Ok(Some((AuthScheme::Bearer, tok.to_string())))
169 }
170 (Some("DPoP"), Some(tok)) if !tok.is_empty() => {
171 Ok(Some((AuthScheme::DPoP, tok.to_string())))
172 }
173 _ => Err(
174 "Authorization header must be in format 'Bearer <token>' or 'DPoP <token>'"
175 .to_string(),
176 ),
177 }
178 }
179 }
180}
181
182/// Resolves a DID to its handle using the PublicResolver, with caching.
183async fn resolve_did_to_handle(
184 resolver: &Arc<PublicResolver>,
185 cache: &HandleCache,
186 did: &str,
187) -> Result<String, String> {
188 // Check cache first
189 if let Some(handle) = cache.get(did) {
190 return Ok(handle);
191 }
192
193 // Parse the DID
194 let did_parsed = jacquard_common::types::did::Did::new(did)
195 .map_err(|e| format!("Invalid DID: {:?}", e))?;
196
197 // Resolve the DID document
198 let did_doc_response = resolver
199 .resolve_did_doc(&did_parsed)
200 .await
201 .map_err(|e| format!("DID resolution failed: {:?}", e))?;
202
203 let doc = did_doc_response
204 .parse()
205 .map_err(|e| format!("Failed to parse DID document: {:?}", e))?;
206
207 // Extract handle from alsoKnownAs field
208 // Format is typically: ["at://handle.example.com"]
209 let handle: String = doc
210 .also_known_as
211 .as_ref()
212 .and_then(|aka| {
213 aka.iter()
214 .find(|uri| uri.starts_with("at://"))
215 .map(|uri| uri.strip_prefix("at://").unwrap_or(uri.as_ref()).to_string())
216 })
217 .ok_or_else(|| "No ATProto handle found in DID document".to_string())?;
218
219 // Cache the result
220 cache.insert(did.to_string(), handle.clone());
221
222 Ok(handle)
223}
224
225// ============================================================================
226// Helper Functions for Creating Middleware State
227// ============================================================================
228
229/// Creates an `AuthMiddlewareState` for requiring the handle to end with a specific suffix.
230///
231/// # Example
232/// ```ignore
233/// use axum::middleware::from_fn_with_state;
234/// use crate::auth::{auth_middleware, handle_ends_with};
235///
236/// .route("/protected", get(handler).layer(
237/// from_fn_with_state(handle_ends_with(".blacksky.team", &state), auth_middleware)
238/// ))
239/// ```
240pub fn handle_ends_with(suffix: impl Into<String>, state: &AppState) -> AuthMiddlewareState {
241 AuthMiddlewareState {
242 app_state: state.clone(),
243 rules: AuthRules::HandleEndsWith(suffix.into()),
244 }
245}
246
247/// Creates an `AuthMiddlewareState` for requiring the handle to end with any of the specified suffixes.
248pub fn handle_ends_with_any<I, T>(suffixes: I, state: &AppState) -> AuthMiddlewareState
249where
250 I: IntoIterator<Item = T>,
251 T: Into<String>,
252{
253 AuthMiddlewareState {
254 app_state: state.clone(),
255 rules: AuthRules::HandleEndsWithAny(suffixes.into_iter().map(|s| s.into()).collect()),
256 }
257}
258
259/// Creates an `AuthMiddlewareState` for requiring the DID to equal a specific value.
260pub fn did_equals(did: impl Into<String>, state: &AppState) -> AuthMiddlewareState {
261 AuthMiddlewareState {
262 app_state: state.clone(),
263 rules: AuthRules::DidEquals(did.into()),
264 }
265}
266
267/// Creates an `AuthMiddlewareState` for requiring the DID to be one of the specified values.
268pub fn did_equals_any<I, T>(dids: I, state: &AppState) -> AuthMiddlewareState
269where
270 I: IntoIterator<Item = T>,
271 T: Into<String>,
272{
273 AuthMiddlewareState {
274 app_state: state.clone(),
275 rules: AuthRules::DidEqualsAny(dids.into_iter().map(|d| d.into()).collect()),
276 }
277}
278
279/// Creates an `AuthMiddlewareState` with custom auth rules.
280pub fn with_rules(rules: AuthRules, state: &AppState) -> AuthMiddlewareState {
281 AuthMiddlewareState {
282 app_state: state.clone(),
283 rules,
284 }
285}
286
287// ============================================================================
288// Scope Helper Functions
289// ============================================================================
290
291/// Creates an `AuthMiddlewareState` requiring a specific OAuth scope.
292///
293/// # Example
294/// ```ignore
295/// .route("/xrpc/com.atproto.repo.createRecord",
296/// post(handler).layer(from_fn_with_state(
297/// scope_equals("repo:app.bsky.feed.post", &state),
298/// auth_middleware
299/// )))
300/// ```
301pub fn scope_equals(scope: impl Into<String>, state: &AppState) -> AuthMiddlewareState {
302 AuthMiddlewareState {
303 app_state: state.clone(),
304 rules: AuthRules::ScopeEquals(scope.into()),
305 }
306}
307
308/// Creates an `AuthMiddlewareState` requiring ANY of the specified scopes (OR logic).
309///
310/// # Example
311/// ```ignore
312/// .route("/xrpc/com.atproto.repo.putRecord",
313/// post(handler).layer(from_fn_with_state(
314/// scope_any(["repo:app.bsky.feed.post", "transition:generic"], &state),
315/// auth_middleware
316/// )))
317/// ```
318pub fn scope_any<I, T>(scopes: I, state: &AppState) -> AuthMiddlewareState
319where
320 I: IntoIterator<Item = T>,
321 T: Into<String>,
322{
323 AuthMiddlewareState {
324 app_state: state.clone(),
325 rules: AuthRules::ScopeEqualsAny(scopes.into_iter().map(|s| s.into()).collect()),
326 }
327}
328
329/// Creates an `AuthMiddlewareState` requiring ALL of the specified scopes (AND logic).
330///
331/// # Example
332/// ```ignore
333/// .route("/xrpc/com.atproto.admin.updateAccount",
334/// post(handler).layer(from_fn_with_state(
335/// scope_all(["account:email", "account:repo?action=manage"], &state),
336/// auth_middleware
337/// )))
338/// ```
339pub fn scope_all<I, T>(scopes: I, state: &AppState) -> AuthMiddlewareState
340where
341 I: IntoIterator<Item = T>,
342 T: Into<String>,
343{
344 AuthMiddlewareState {
345 app_state: state.clone(),
346 rules: AuthRules::ScopeEqualsAll(scopes.into_iter().map(|s| s.into()).collect()),
347 }
348}
349
350// ============================================================================
351// Combined Rule Helpers (Identity + Scope)
352// ============================================================================
353
354/// Creates an `AuthMiddlewareState` requiring handle to end with suffix AND have a specific scope.
355///
356/// # Example
357/// ```ignore
358/// .route("/xrpc/community.blacksky.feed.generator",
359/// post(handler).layer(from_fn_with_state(
360/// handle_ends_with_and_scope(".blacksky.team", "transition:generic", &state),
361/// auth_middleware
362/// )))
363/// ```
364pub fn handle_ends_with_and_scope(
365 suffix: impl Into<String>,
366 scope: impl Into<String>,
367 state: &AppState,
368) -> AuthMiddlewareState {
369 AuthMiddlewareState {
370 app_state: state.clone(),
371 rules: AuthRules::All(vec![
372 AuthRules::HandleEndsWith(suffix.into()),
373 AuthRules::ScopeEquals(scope.into()),
374 ]),
375 }
376}
377
378/// Creates an `AuthMiddlewareState` requiring handle to end with suffix AND have ALL specified scopes.
379///
380/// # Example
381/// ```ignore
382/// .route("/xrpc/community.blacksky.admin.manage",
383/// post(handler).layer(from_fn_with_state(
384/// handle_ends_with_and_scopes(".blacksky.team", ["transition:generic", "identity:*"], &state),
385/// auth_middleware
386/// )))
387/// ```
388pub fn handle_ends_with_and_scopes<I, T>(
389 suffix: impl Into<String>,
390 scopes: I,
391 state: &AppState,
392) -> AuthMiddlewareState
393where
394 I: IntoIterator<Item = T>,
395 T: Into<String>,
396{
397 AuthMiddlewareState {
398 app_state: state.clone(),
399 rules: AuthRules::All(vec![
400 AuthRules::HandleEndsWith(suffix.into()),
401 AuthRules::ScopeEqualsAll(scopes.into_iter().map(|s| s.into()).collect()),
402 ]),
403 }
404}
405
406/// Creates an `AuthMiddlewareState` requiring DID to equal value AND have a specific scope.
407///
408/// # Example
409/// ```ignore
410/// .route("/xrpc/com.atproto.admin.deleteAccount",
411/// post(handler).layer(from_fn_with_state(
412/// did_with_scope("did:plc:rnpkyqnmsw4ipey6eotbdnnf", "transition:generic", &state),
413/// auth_middleware
414/// )))
415/// ```
416pub fn did_with_scope(
417 did: impl Into<String>,
418 scope: impl Into<String>,
419 state: &AppState,
420) -> AuthMiddlewareState {
421 AuthMiddlewareState {
422 app_state: state.clone(),
423 rules: AuthRules::All(vec![
424 AuthRules::DidEquals(did.into()),
425 AuthRules::ScopeEquals(scope.into()),
426 ]),
427 }
428}
429
430/// Creates an `AuthMiddlewareState` requiring DID to equal value AND have ALL specified scopes.
431///
432/// # Example
433/// ```ignore
434/// .route("/xrpc/com.atproto.admin.fullAccess",
435/// post(handler).layer(from_fn_with_state(
436/// did_with_scopes("did:plc:rnpkyqnmsw4ipey6eotbdnnf", ["transition:generic", "identity:*"], &state),
437/// auth_middleware
438/// )))
439/// ```
440pub fn did_with_scopes<I, T>(
441 did: impl Into<String>,
442 scopes: I,
443 state: &AppState,
444) -> AuthMiddlewareState
445where
446 I: IntoIterator<Item = T>,
447 T: Into<String>,
448{
449 AuthMiddlewareState {
450 app_state: state.clone(),
451 rules: AuthRules::All(vec![
452 AuthRules::DidEquals(did.into()),
453 AuthRules::ScopeEqualsAll(scopes.into_iter().map(|s| s.into()).collect()),
454 ]),
455 }
456}