Alternative ATProto PDS implementation
1use anyhow::{Context as _, anyhow};
2use atrium_crypto::{
3 keypair::{Did as _, Secp256k1Keypair},
4 verify::Verifier,
5};
6use axum::{extract::FromRequestParts, http::StatusCode};
7use base64::Engine as _;
8use diesel::prelude::*;
9use sha2::{Digest as _, Sha256};
10
11use crate::{
12 error::{Error, ErrorMessage},
13 serve::AppState,
14};
15
/// Request extractor representing a caller whose credentials were verified.
///
/// Taking this type as a handler argument means the endpoint can only be
/// reached by an authenticated user; unauthenticated requests are rejected
/// before the handler runs.
pub(crate) struct AuthenticatedUser {
    /// DID of the account the verified credentials belong to.
    did: String,
}

impl AuthenticatedUser {
    /// Return an owned copy of the authenticated account's DID.
    pub(crate) fn did(&self) -> String {
        String::clone(&self.did)
    }
}
30
31impl FromRequestParts<AppState> for AuthenticatedUser {
32 type Rejection = Error;
33
34 async fn from_request_parts(
35 parts: &mut axum::http::request::Parts,
36 state: &AppState,
37 ) -> Result<Self, Self::Rejection> {
38 // Check for authorization header (either for Bearer or DPoP tokens)
39 let auth_header = parts
40 .headers
41 .get(axum::http::header::AUTHORIZATION)
42 .ok_or_else(|| {
43 Error::with_status(StatusCode::UNAUTHORIZED, anyhow!("no authorization header"))
44 })?;
45
46 let auth_str = auth_header.to_str().map_err(|e| {
47 Error::with_status(
48 StatusCode::UNAUTHORIZED,
49 anyhow!("authorization header should be valid utf-8").context(e),
50 )
51 })?;
52
53 // Check for DPoP header
54 let dpop_header = parts.headers.get("dpop");
55 let has_dpop = dpop_header.is_some();
56
57 // Handle different token types
58 if auth_str.starts_with("Bearer ") || auth_str.starts_with("DPoP ") {
59 let token = auth_str
60 .split_once(' ')
61 .expect("Auth string should have a space")
62 .1;
63
64 if has_dpop {
65 // Process DPoP token - the Authorization header contains the access token
66 // and the DPoP header contains the proof
67 let dpop_token = dpop_header
68 .expect("DPoP header should exist")
69 .to_str()
70 .map_err(|e| {
71 Error::with_status(
72 StatusCode::UNAUTHORIZED,
73 anyhow!("DPoP header should be valid utf-8").context(e),
74 )
75 })?;
76
77 return validate_dpop_token(token, dpop_token, parts, state).await;
78 }
79
80 // Standard Bearer token
81 return validate_bearer_token(token, state).await;
82 }
83
84 // If we reach here, no valid authorization method was found
85 Err(Error::with_status(
86 StatusCode::UNAUTHORIZED,
87 anyhow!("unsupported authorization method"),
88 ))
89 }
90}
91
92/// Validate a standard Bearer token
93async fn validate_bearer_token(token: &str, state: &AppState) -> Result<AuthenticatedUser, Error> {
94 // Validate JWT token
95 let (typ, claims) = verify(&state.signing_key.did(), token)
96 .map_err(|e| {
97 Error::with_status(
98 StatusCode::UNAUTHORIZED,
99 e.context("failed to verify bearer token"),
100 )
101 })
102 .context("token auth should verify")?;
103
104 // Ensure this is an authentication token.
105 if typ != "at+jwt" {
106 return Err(Error::with_status(
107 StatusCode::UNAUTHORIZED,
108 anyhow!("invalid token type: {typ}"),
109 ));
110 }
111
112 // Ensure we are in the audience field.
113 if let Some(aud) = claims.get("aud").and_then(serde_json::Value::as_str) {
114 if aud != format!("did:web:{}", state.config.host_name) {
115 return Err(Error::with_status(
116 StatusCode::UNAUTHORIZED,
117 anyhow!("invalid audience: {aud}"),
118 ));
119 }
120 }
121
122 // Check token expiration
123 if let Some(exp) = claims.get("exp").and_then(serde_json::Value::as_i64) {
124 let now = chrono::Utc::now().timestamp();
125 if now >= exp {
126 return Err(Error::with_message(
127 StatusCode::BAD_REQUEST,
128 anyhow!("token has expired"),
129 ErrorMessage::new("ExpiredToken", "Token has expired"),
130 ));
131 }
132 }
133
134 // Extract subject (DID)
135 if let Some(did) = claims.get("sub").and_then(serde_json::Value::as_str) {
136 use crate::schema::pds::account::dsl as AccountSchema;
137 let did_clone = did.to_owned();
138
139 let _did = state
140 .db
141 .get()
142 .await
143 .expect("failed to get db connection")
144 .interact(move |conn| {
145 AccountSchema::account
146 .filter(AccountSchema::did.eq(did_clone))
147 .select(AccountSchema::did)
148 .first::<String>(conn)
149 })
150 .await
151 .expect("failed to query account");
152
153 Ok(AuthenticatedUser {
154 did: did.to_owned(),
155 })
156 } else {
157 Err(Error::with_status(
158 StatusCode::UNAUTHORIZED,
159 anyhow!("invalid authorization token: missing subject"),
160 ))
161 }
162}
163
164#[expect(clippy::too_many_lines, reason = "validating dpop has many loc")]
165/// Validate a DPoP token and proof
166async fn validate_dpop_token(
167 access_token: &str,
168 dpop_token: &str,
169 parts: &axum::http::request::Parts,
170 state: &AppState,
171) -> Result<AuthenticatedUser, Error> {
172 // Step 1: Parse and validate the access token
173 let (typ, claims) = verify(&state.signing_key.did(), access_token)
174 .map_err(|e| {
175 Error::with_status(
176 StatusCode::UNAUTHORIZED,
177 e.context("failed to verify access token"),
178 )
179 })
180 .context("access token auth should verify")?;
181
182 // Ensure this is an access token JWT
183 if typ != "at+jwt" {
184 return Err(Error::with_status(
185 StatusCode::UNAUTHORIZED,
186 anyhow!("invalid token type: {typ}"),
187 ));
188 }
189
190 // Check token expiration
191 if let Some(exp) = claims.get("exp").and_then(serde_json::Value::as_i64) {
192 let now = chrono::Utc::now().timestamp();
193 if now >= exp {
194 return Err(Error::with_message(
195 StatusCode::BAD_REQUEST,
196 anyhow!("token has expired"),
197 ErrorMessage::new("ExpiredToken", "Token has expired"),
198 ));
199 }
200 }
201
202 // Step 2: Parse and validate the DPoP proof
203 let dpop_parts: Vec<&str> = dpop_token.split('.').collect();
204 if dpop_parts.len() != 3 {
205 return Err(Error::with_status(
206 StatusCode::UNAUTHORIZED,
207 anyhow!("invalid DPoP token format"),
208 ));
209 }
210
211 // Decode header
212 let dpop_header_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
213 .decode(dpop_parts.first().context("header part missing")?)
214 .context("failed to decode DPoP header")?;
215
216 let dpop_header: serde_json::Value =
217 serde_json::from_slice(&dpop_header_bytes).context("failed to parse DPoP header")?;
218
219 // Check typ is "dpop+jwt"
220 if dpop_header.get("typ").and_then(|v| v.as_str()) != Some("dpop+jwt") {
221 return Err(Error::with_status(
222 StatusCode::UNAUTHORIZED,
223 anyhow!("invalid DPoP token type"),
224 ));
225 }
226
227 // Decode claims
228 let dpop_claims_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
229 .decode(dpop_parts.get(1).context("claims part missing")?)
230 .context("failed to decode DPoP claims")?;
231
232 let dpop_claims: serde_json::Value =
233 serde_json::from_slice(&dpop_claims_bytes).context("failed to parse DPoP claims")?;
234
235 // Check HTTP method
236 if dpop_claims.get("htm").and_then(|v| v.as_str()) != Some(parts.method.as_str()) {
237 return Err(Error::with_status(
238 StatusCode::UNAUTHORIZED,
239 anyhow!("DPoP token HTTP method mismatch"),
240 ));
241 }
242
243 // Check HTTP URI
244 let expected_uri = format!(
245 "https://{}/xrpc{}",
246 state.config.host_name,
247 parts.uri.path_and_query().expect("path and query to exist")
248 );
249 if dpop_claims.get("htu").and_then(|v| v.as_str()) != Some(&expected_uri) {
250 return Err(Error::with_status(
251 StatusCode::UNAUTHORIZED,
252 anyhow!(
253 "DPoP token HTTP URI mismatch: expected {}, got {}",
254 expected_uri,
255 dpop_claims
256 .get("htu")
257 .and_then(|v| v.as_str())
258 .unwrap_or("None")
259 ),
260 ));
261 }
262
263 // Verify access token hash (ath)
264 if let Some(ath) = dpop_claims.get("ath").and_then(|v| v.as_str()) {
265 let mut hasher = Sha256::new();
266 hasher.update(access_token.as_bytes());
267 let token_hash = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize());
268
269 if ath != token_hash {
270 return Err(Error::with_status(
271 StatusCode::UNAUTHORIZED,
272 anyhow!("DPoP access token hash mismatch"),
273 ));
274 }
275 }
276
277 // Extract JWK from DPoP header
278 let jwk = dpop_header
279 .get("jwk")
280 .context("missing jwk in DPoP header")?;
281
282 // Calculate JWK thumbprint - RFC7638 compliant
283 let key_type = jwk
284 .get("kty")
285 .and_then(|v| v.as_str())
286 .context("JWK missing kty property")?;
287
288 // Define required properties for each key type
289 let required_props = match key_type {
290 "EC" => &["crv", "kty", "x", "y"][..],
291 "RSA" => &["e", "kty", "n"][..],
292 _ => {
293 return Err(Error::with_status(
294 StatusCode::UNAUTHORIZED,
295 anyhow!("Unsupported JWK key type: {}", key_type),
296 ));
297 }
298 };
299
300 // Build a new JWK with only the required properties
301 let mut canonical_jwk = serde_json::Map::new();
302
303 for prop in required_props {
304 let value = jwk
305 .get(prop)
306 .context(format!("JWK missing required property: {prop}"))?;
307 drop(canonical_jwk.insert((*prop).to_owned(), value.clone()));
308 }
309
310 // Serialize with no whitespace
311 let canonical_json = serde_json::to_string(&serde_json::Value::Object(canonical_jwk))
312 .context("Failed to serialize canonical JWK")?;
313
314 // Hash the canonical representation
315 let mut hasher = Sha256::new();
316 hasher.update(canonical_json.as_bytes());
317 let calculated_thumbprint =
318 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize());
319
320 // Get JWK thumbprint from the access token "cnf" claim
321 let jkt = claims
322 .get("cnf")
323 .and_then(|v| v.as_object())
324 .and_then(|o| o.get("jkt"))
325 .and_then(|v| v.as_str())
326 .context("missing or invalid 'cnf.jkt' in access token")?;
327
328 // Verify JWK thumbprint
329 if calculated_thumbprint != jkt {
330 return Err(Error::with_status(
331 StatusCode::UNAUTHORIZED,
332 anyhow!("JWK thumbprint mismatch"),
333 ));
334 }
335
336 // Check JTI replay using the database
337 let jti = dpop_claims
338 .get("jti")
339 .and_then(|j| j.as_str())
340 .context("Missing jti claim in DPoP token")?;
341
342 let timestamp = chrono::Utc::now().timestamp();
343
344 use crate::schema::pds::oauth_used_jtis::dsl as JtiSchema;
345
346 // Check if JTI has been used before
347 let jti_string = jti.to_owned();
348 let jti_used = state
349 .db
350 .get()
351 .await
352 .expect("failed to get db connection")
353 .interact(move |conn| {
354 JtiSchema::oauth_used_jtis
355 .filter(JtiSchema::jti.eq(jti_string))
356 .count()
357 .get_result::<i64>(conn)
358 })
359 .await
360 .expect("failed to query JTI")
361 .expect("failed to get JTI count");
362
363 if jti_used > 0 {
364 return Err(Error::with_status(
365 StatusCode::UNAUTHORIZED,
366 anyhow!("DPoP token has been replayed"),
367 ));
368 }
369
370 // Store the JTI to prevent replay attacks
371 // Get expiry from token or default to 60 seconds
372 let exp = dpop_claims
373 .get("exp")
374 .and_then(serde_json::Value::as_i64)
375 .unwrap_or_else(|| timestamp.checked_add(60).unwrap_or(timestamp));
376
377 // Convert SQLx INSERT to Diesel
378 let jti_str = jti.to_owned();
379 let thumbprint_str = calculated_thumbprint.to_string();
380 let _ = state
381 .db
382 .get()
383 .await
384 .expect("failed to get db connection")
385 .interact(move |conn| {
386 diesel::insert_into(JtiSchema::oauth_used_jtis)
387 .values((
388 JtiSchema::jti.eq(jti_str),
389 JtiSchema::issuer.eq(thumbprint_str),
390 JtiSchema::created_at.eq(timestamp),
391 JtiSchema::expires_at.eq(exp),
392 ))
393 .execute(conn)
394 })
395 .await
396 .expect("failed to insert JTI")
397 .expect("failed to insert JTI");
398
399 // Extract subject (DID) from access token
400 if let Some(did) = claims.get("sub").and_then(|v| v.as_str()) {
401 use crate::schema::pds::account::dsl as AccountSchema;
402
403 let did_clone = did.to_owned();
404
405 let _did = state
406 .db
407 .get()
408 .await
409 .expect("failed to get db connection")
410 .interact(move |conn| {
411 AccountSchema::account
412 .filter(AccountSchema::did.eq(did_clone))
413 .select(AccountSchema::did)
414 .first::<String>(conn)
415 })
416 .await
417 .expect("failed to query account")
418 .expect("failed to get account");
419
420 Ok(AuthenticatedUser {
421 did: did.to_owned(),
422 })
423 } else {
424 Err(Error::with_status(
425 StatusCode::UNAUTHORIZED,
426 anyhow!("invalid access token: missing subject"),
427 ))
428 }
429}
430
431/// Cryptographically sign a JSON web token with the specified key.
432pub(crate) fn sign(
433 key: &Secp256k1Keypair,
434 typ: &str,
435 claims: &serde_json::Value,
436) -> anyhow::Result<String> {
437 // RFC 9068
438 let hdr = serde_json::json!({
439 "typ": typ,
440 "alg": "ES256K", // Secp256k1Keypair
441 });
442 let hdr = base64::prelude::BASE64_URL_SAFE_NO_PAD
443 .encode(serde_json::to_vec(&hdr).context("failed to encode claims")?);
444 let claims = base64::prelude::BASE64_URL_SAFE_NO_PAD
445 .encode(serde_json::to_vec(&claims).context("failed to encode claims")?);
446 let sig = key
447 .sign(format!("{hdr}.{claims}").as_bytes())
448 .context("failed to sign jwt")?;
449 let sig = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sig);
450
451 Ok(format!("{hdr}.{claims}.{sig}"))
452}
453
454/// Cryptographically verify a JSON web token's validity using the specified public key.
455pub(crate) fn verify(key: &str, token: &str) -> anyhow::Result<(String, serde_json::Value)> {
456 let mut parts = token.splitn(3, '.');
457 let hdr = parts.next().context("no header")?;
458 let claims = parts.next().context("no claims")?;
459 let sig = base64::prelude::BASE64_URL_SAFE_NO_PAD
460 .decode(parts.next().context("no sig")?)
461 .context("failed to decode signature")?;
462
463 let (alg, key) = atrium_crypto::did::parse_did_key(key).context("failed to decode key")?;
464 Verifier::default()
465 .verify(alg, &key, format!("{hdr}.{claims}").as_bytes(), &sig)
466 .context("failed to verify jwt")?;
467
468 let hdr = base64::prelude::BASE64_URL_SAFE_NO_PAD
469 .decode(hdr)
470 .context("failed to decode hdr")?;
471 let hdr =
472 serde_json::from_slice::<serde_json::Value>(&hdr).context("failed to parse hdr as json")?;
473 let typ = hdr
474 .get("typ")
475 .and_then(serde_json::Value::as_str)
476 .context("hdr is invalid")?;
477
478 let claims = base64::prelude::BASE64_URL_SAFE_NO_PAD
479 .decode(claims)
480 .context("failed to decode claims")?;
481 let claims = serde_json::from_slice(&claims).context("failed to parse claims as json")?;
482
483 Ok((typ.to_owned(), claims))
484}