// Alternative ATProto PDS implementation
1use anyhow::{Context as _, anyhow};
2use atrium_crypto::{
3 keypair::{Did as _, Secp256k1Keypair},
4 verify::Verifier,
5};
6use axum::{extract::FromRequestParts, http::StatusCode};
7use base64::Engine as _;
8use diesel::prelude::*;
9use sha2::{Digest as _, Sha256};
10
11use crate::{AppState, Error, error::ErrorMessage};
12
/// Axum request extractor representing a caller that passed authentication.
///
/// Listing this type as a handler argument means the handler body only runs
/// when the extractor succeeded; otherwise the request is rejected first.
pub(crate) struct AuthenticatedUser {
    /// DID taken from the `sub` claim of the validated access token.
    did: String,
}
20
21impl AuthenticatedUser {
22 /// Get the DID of the authenticated user.
23 pub(crate) fn did(&self) -> String {
24 self.did.clone()
25 }
26}
27
28impl FromRequestParts<AppState> for AuthenticatedUser {
29 type Rejection = Error;
30
31 async fn from_request_parts(
32 parts: &mut axum::http::request::Parts,
33 state: &AppState,
34 ) -> Result<Self, Self::Rejection> {
35 // Check for authorization header (either for Bearer or DPoP tokens)
36 let auth_header = parts
37 .headers
38 .get(axum::http::header::AUTHORIZATION)
39 .ok_or_else(|| {
40 Error::with_status(StatusCode::UNAUTHORIZED, anyhow!("no authorization header"))
41 })?;
42
43 let auth_str = auth_header.to_str().map_err(|e| {
44 Error::with_status(
45 StatusCode::UNAUTHORIZED,
46 anyhow!("authorization header should be valid utf-8").context(e),
47 )
48 })?;
49
50 // Check for DPoP header
51 let dpop_header = parts.headers.get("dpop");
52 let has_dpop = dpop_header.is_some();
53
54 // Handle different token types
55 if auth_str.starts_with("Bearer ") || auth_str.starts_with("DPoP ") {
56 let token = auth_str
57 .split_once(' ')
58 .expect("Auth string should have a space")
59 .1;
60
61 if has_dpop {
62 // Process DPoP token - the Authorization header contains the access token
63 // and the DPoP header contains the proof
64 let dpop_token = dpop_header
65 .expect("DPoP header should exist")
66 .to_str()
67 .map_err(|e| {
68 Error::with_status(
69 StatusCode::UNAUTHORIZED,
70 anyhow!("DPoP header should be valid utf-8").context(e),
71 )
72 })?;
73
74 return validate_dpop_token(token, dpop_token, parts, state).await;
75 }
76
77 // Standard Bearer token
78 return validate_bearer_token(token, state).await;
79 }
80
81 // If we reach here, no valid authorization method was found
82 Err(Error::with_status(
83 StatusCode::UNAUTHORIZED,
84 anyhow!("unsupported authorization method"),
85 ))
86 }
87}
88
89/// Validate a standard Bearer token
90async fn validate_bearer_token(token: &str, state: &AppState) -> Result<AuthenticatedUser, Error> {
91 // Validate JWT token
92 let (typ, claims) = verify(&state.signing_key.did(), token)
93 .map_err(|e| {
94 Error::with_status(
95 StatusCode::UNAUTHORIZED,
96 e.context("failed to verify bearer token"),
97 )
98 })
99 .context("token auth should verify")?;
100
101 // Ensure this is an authentication token.
102 if typ != "at+jwt" {
103 return Err(Error::with_status(
104 StatusCode::UNAUTHORIZED,
105 anyhow!("invalid token type: {typ}"),
106 ));
107 }
108
109 // Ensure we are in the audience field.
110 if let Some(aud) = claims.get("aud").and_then(serde_json::Value::as_str) {
111 if aud != format!("did:web:{}", state.config.host_name) {
112 return Err(Error::with_status(
113 StatusCode::UNAUTHORIZED,
114 anyhow!("invalid audience: {aud}"),
115 ));
116 }
117 }
118
119 // Check token expiration
120 if let Some(exp) = claims.get("exp").and_then(serde_json::Value::as_i64) {
121 let now = chrono::Utc::now().timestamp();
122 if now >= exp {
123 return Err(Error::with_message(
124 StatusCode::BAD_REQUEST,
125 anyhow!("token has expired"),
126 ErrorMessage::new("ExpiredToken", "Token has expired"),
127 ));
128 }
129 }
130
131 // Extract subject (DID)
132 if let Some(did) = claims.get("sub").and_then(serde_json::Value::as_str) {
133 use crate::schema::pds::account::dsl as AccountSchema;
134 let did_clone = did.to_owned();
135
136 let _did = state
137 .db
138 .get()
139 .await
140 .expect("failed to get db connection")
141 .interact(move |conn| {
142 AccountSchema::account
143 .filter(AccountSchema::did.eq(did_clone))
144 .select(AccountSchema::did)
145 .first::<String>(conn)
146 })
147 .await
148 .expect("failed to query account");
149
150 Ok(AuthenticatedUser {
151 did: did.to_owned(),
152 })
153 } else {
154 Err(Error::with_status(
155 StatusCode::UNAUTHORIZED,
156 anyhow!("invalid authorization token: missing subject"),
157 ))
158 }
159}
160
161#[expect(clippy::too_many_lines, reason = "validating dpop has many loc")]
162/// Validate a DPoP token and proof
163async fn validate_dpop_token(
164 access_token: &str,
165 dpop_token: &str,
166 parts: &axum::http::request::Parts,
167 state: &AppState,
168) -> Result<AuthenticatedUser, Error> {
169 // Step 1: Parse and validate the access token
170 let (typ, claims) = verify(&state.signing_key.did(), access_token)
171 .map_err(|e| {
172 Error::with_status(
173 StatusCode::UNAUTHORIZED,
174 e.context("failed to verify access token"),
175 )
176 })
177 .context("access token auth should verify")?;
178
179 // Ensure this is an access token JWT
180 if typ != "at+jwt" {
181 return Err(Error::with_status(
182 StatusCode::UNAUTHORIZED,
183 anyhow!("invalid token type: {typ}"),
184 ));
185 }
186
187 // Check token expiration
188 if let Some(exp) = claims.get("exp").and_then(serde_json::Value::as_i64) {
189 let now = chrono::Utc::now().timestamp();
190 if now >= exp {
191 return Err(Error::with_message(
192 StatusCode::BAD_REQUEST,
193 anyhow!("token has expired"),
194 ErrorMessage::new("ExpiredToken", "Token has expired"),
195 ));
196 }
197 }
198
199 // Step 2: Parse and validate the DPoP proof
200 let dpop_parts: Vec<&str> = dpop_token.split('.').collect();
201 if dpop_parts.len() != 3 {
202 return Err(Error::with_status(
203 StatusCode::UNAUTHORIZED,
204 anyhow!("invalid DPoP token format"),
205 ));
206 }
207
208 // Decode header
209 let dpop_header_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
210 .decode(dpop_parts.first().context("header part missing")?)
211 .context("failed to decode DPoP header")?;
212
213 let dpop_header: serde_json::Value =
214 serde_json::from_slice(&dpop_header_bytes).context("failed to parse DPoP header")?;
215
216 // Check typ is "dpop+jwt"
217 if dpop_header.get("typ").and_then(|v| v.as_str()) != Some("dpop+jwt") {
218 return Err(Error::with_status(
219 StatusCode::UNAUTHORIZED,
220 anyhow!("invalid DPoP token type"),
221 ));
222 }
223
224 // Decode claims
225 let dpop_claims_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
226 .decode(dpop_parts.get(1).context("claims part missing")?)
227 .context("failed to decode DPoP claims")?;
228
229 let dpop_claims: serde_json::Value =
230 serde_json::from_slice(&dpop_claims_bytes).context("failed to parse DPoP claims")?;
231
232 // Check HTTP method
233 if dpop_claims.get("htm").and_then(|v| v.as_str()) != Some(parts.method.as_str()) {
234 return Err(Error::with_status(
235 StatusCode::UNAUTHORIZED,
236 anyhow!("DPoP token HTTP method mismatch"),
237 ));
238 }
239
240 // Check HTTP URI
241 let expected_uri = format!(
242 "https://{}/xrpc{}",
243 state.config.host_name,
244 parts.uri.path_and_query().expect("path and query to exist")
245 );
246 if dpop_claims.get("htu").and_then(|v| v.as_str()) != Some(&expected_uri) {
247 return Err(Error::with_status(
248 StatusCode::UNAUTHORIZED,
249 anyhow!(
250 "DPoP token HTTP URI mismatch: expected {}, got {}",
251 expected_uri,
252 dpop_claims
253 .get("htu")
254 .and_then(|v| v.as_str())
255 .unwrap_or("None")
256 ),
257 ));
258 }
259
260 // Verify access token hash (ath)
261 if let Some(ath) = dpop_claims.get("ath").and_then(|v| v.as_str()) {
262 let mut hasher = Sha256::new();
263 hasher.update(access_token.as_bytes());
264 let token_hash = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize());
265
266 if ath != token_hash {
267 return Err(Error::with_status(
268 StatusCode::UNAUTHORIZED,
269 anyhow!("DPoP access token hash mismatch"),
270 ));
271 }
272 }
273
274 // Extract JWK from DPoP header
275 let jwk = dpop_header
276 .get("jwk")
277 .context("missing jwk in DPoP header")?;
278
279 // Calculate JWK thumbprint - RFC7638 compliant
280 let key_type = jwk
281 .get("kty")
282 .and_then(|v| v.as_str())
283 .context("JWK missing kty property")?;
284
285 // Define required properties for each key type
286 let required_props = match key_type {
287 "EC" => &["crv", "kty", "x", "y"][..],
288 "RSA" => &["e", "kty", "n"][..],
289 _ => {
290 return Err(Error::with_status(
291 StatusCode::UNAUTHORIZED,
292 anyhow!("Unsupported JWK key type: {}", key_type),
293 ));
294 }
295 };
296
297 // Build a new JWK with only the required properties
298 let mut canonical_jwk = serde_json::Map::new();
299
300 for prop in required_props {
301 let value = jwk
302 .get(prop)
303 .context(format!("JWK missing required property: {prop}"))?;
304 drop(canonical_jwk.insert((*prop).to_owned(), value.clone()));
305 }
306
307 // Serialize with no whitespace
308 let canonical_json = serde_json::to_string(&serde_json::Value::Object(canonical_jwk))
309 .context("Failed to serialize canonical JWK")?;
310
311 // Hash the canonical representation
312 let mut hasher = Sha256::new();
313 hasher.update(canonical_json.as_bytes());
314 let calculated_thumbprint =
315 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize());
316
317 // Get JWK thumbprint from the access token "cnf" claim
318 let jkt = claims
319 .get("cnf")
320 .and_then(|v| v.as_object())
321 .and_then(|o| o.get("jkt"))
322 .and_then(|v| v.as_str())
323 .context("missing or invalid 'cnf.jkt' in access token")?;
324
325 // Verify JWK thumbprint
326 if calculated_thumbprint != jkt {
327 return Err(Error::with_status(
328 StatusCode::UNAUTHORIZED,
329 anyhow!("JWK thumbprint mismatch"),
330 ));
331 }
332
333 // Check JTI replay using the database
334 let jti = dpop_claims
335 .get("jti")
336 .and_then(|j| j.as_str())
337 .context("Missing jti claim in DPoP token")?;
338
339 let timestamp = chrono::Utc::now().timestamp();
340
341 use crate::schema::pds::oauth_used_jtis::dsl as JtiSchema;
342
343 // Check if JTI has been used before
344 let jti_string = jti.to_owned();
345 let jti_used = state
346 .db
347 .get()
348 .await
349 .expect("failed to get db connection")
350 .interact(move |conn| {
351 JtiSchema::oauth_used_jtis
352 .filter(JtiSchema::jti.eq(jti_string))
353 .count()
354 .get_result::<i64>(conn)
355 })
356 .await
357 .expect("failed to query JTI")
358 .expect("failed to get JTI count");
359
360 if jti_used > 0 {
361 return Err(Error::with_status(
362 StatusCode::UNAUTHORIZED,
363 anyhow!("DPoP token has been replayed"),
364 ));
365 }
366
367 // Store the JTI to prevent replay attacks
368 // Get expiry from token or default to 60 seconds
369 let exp = dpop_claims
370 .get("exp")
371 .and_then(serde_json::Value::as_i64)
372 .unwrap_or_else(|| timestamp.checked_add(60).unwrap_or(timestamp));
373
374 // Convert SQLx INSERT to Diesel
375 let jti_str = jti.to_owned();
376 let thumbprint_str = calculated_thumbprint.to_string();
377 let _ = state
378 .db
379 .get()
380 .await
381 .expect("failed to get db connection")
382 .interact(move |conn| {
383 diesel::insert_into(JtiSchema::oauth_used_jtis)
384 .values((
385 JtiSchema::jti.eq(jti_str),
386 JtiSchema::issuer.eq(thumbprint_str),
387 JtiSchema::created_at.eq(timestamp),
388 JtiSchema::expires_at.eq(exp),
389 ))
390 .execute(conn)
391 })
392 .await
393 .expect("failed to insert JTI")
394 .expect("failed to insert JTI");
395
396 // Extract subject (DID) from access token
397 if let Some(did) = claims.get("sub").and_then(|v| v.as_str()) {
398 use crate::schema::pds::account::dsl as AccountSchema;
399
400 let did_clone = did.to_owned();
401
402 let _did = state
403 .db
404 .get()
405 .await
406 .expect("failed to get db connection")
407 .interact(move |conn| {
408 AccountSchema::account
409 .filter(AccountSchema::did.eq(did_clone))
410 .select(AccountSchema::did)
411 .first::<String>(conn)
412 })
413 .await
414 .expect("failed to query account")
415 .expect("failed to get account");
416
417 Ok(AuthenticatedUser {
418 did: did.to_owned(),
419 })
420 } else {
421 Err(Error::with_status(
422 StatusCode::UNAUTHORIZED,
423 anyhow!("invalid access token: missing subject"),
424 ))
425 }
426}
427
428/// Cryptographically sign a JSON web token with the specified key.
429pub(crate) fn sign(
430 key: &Secp256k1Keypair,
431 typ: &str,
432 claims: &serde_json::Value,
433) -> anyhow::Result<String> {
434 // RFC 9068
435 let hdr = serde_json::json!({
436 "typ": typ,
437 "alg": "ES256K", // Secp256k1Keypair
438 });
439 let hdr = base64::prelude::BASE64_URL_SAFE_NO_PAD
440 .encode(serde_json::to_vec(&hdr).context("failed to encode claims")?);
441 let claims = base64::prelude::BASE64_URL_SAFE_NO_PAD
442 .encode(serde_json::to_vec(&claims).context("failed to encode claims")?);
443 let sig = key
444 .sign(format!("{hdr}.{claims}").as_bytes())
445 .context("failed to sign jwt")?;
446 let sig = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sig);
447
448 Ok(format!("{hdr}.{claims}.{sig}"))
449}
450
451/// Cryptographically verify a JSON web token's validity using the specified public key.
452pub(crate) fn verify(key: &str, token: &str) -> anyhow::Result<(String, serde_json::Value)> {
453 let mut parts = token.splitn(3, '.');
454 let hdr = parts.next().context("no header")?;
455 let claims = parts.next().context("no claims")?;
456 let sig = base64::prelude::BASE64_URL_SAFE_NO_PAD
457 .decode(parts.next().context("no sig")?)
458 .context("failed to decode signature")?;
459
460 let (alg, key) = atrium_crypto::did::parse_did_key(key).context("failed to decode key")?;
461 Verifier::default()
462 .verify(alg, &key, format!("{hdr}.{claims}").as_bytes(), &sig)
463 .context("failed to verify jwt")?;
464
465 let hdr = base64::prelude::BASE64_URL_SAFE_NO_PAD
466 .decode(hdr)
467 .context("failed to decode hdr")?;
468 let hdr =
469 serde_json::from_slice::<serde_json::Value>(&hdr).context("failed to parse hdr as json")?;
470 let typ = hdr
471 .get("typ")
472 .and_then(serde_json::Value::as_str)
473 .context("hdr is invalid")?;
474
475 let claims = base64::prelude::BASE64_URL_SAFE_NO_PAD
476 .decode(claims)
477 .context("failed to decode claims")?;
478 let claims = serde_json::from_slice(&claims).context("failed to parse claims as json")?;
479
480 Ok((typ.to_owned(), claims))
481}