Alternative ATProto PDS implementation
1//! OAuth endpoints
2#![allow(unnameable_types, unused_qualifications)]
3use crate::config::AppConfig;
4use crate::error::Error;
5use crate::metrics::AUTH_FAILED;
6use crate::serve::{AppState, Client, Result, SigningKey};
7use anyhow::{Context as _, anyhow};
8use argon2::{Argon2, PasswordHash, PasswordVerifier as _};
9use atrium_crypto::keypair::Did as _;
10use axum::response::Redirect;
11use axum::{
12 Json, Router, extract,
13 extract::{Query, State},
14 http::{HeaderMap, HeaderValue, StatusCode, header},
15 response::IntoResponse,
16 routing::{get, post},
17};
18use base64::Engine as _;
19use deadpool_diesel::sqlite::Pool;
20use diesel::*;
21use metrics::counter;
22use rand::distributions::Alphanumeric;
23use rand::{Rng as _, thread_rng};
24use serde::{Deserialize, Serialize};
25use serde_json::{Value, json};
26use sha2::{Digest as _, Sha256};
27use std::collections::{HashMap, HashSet};
28
/// Required JWK members used to compute an RFC 7638 thumbprint, keyed by `kty`.
///
/// Each member list is already in lexicographic order, which is the order RFC 7638
/// requires in the canonical JSON that gets hashed.
///
/// - `EC`: any elliptic-curve key (e.g. P-256 for ES256, secp256k1 for ES256K).
/// - `OKP`: octet key pairs such as Ed25519 (RFC 8037, section 2).
/// - `RSA`: RSA keys (RS256 / PS256).
const JWK_REQUIRED_PROPS: &[(&str, &[&str])] = &[
    ("EC", &["crv", "kty", "x", "y"]),
    ("OKP", &["crv", "kty", "x"]),
    ("RSA", &["e", "kty", "n"]),
];
37
38/// JWT ID used record for tracking used JTIs to prevent replay attacks
39#[derive(Debug, Serialize, Deserialize)]
40struct JtiRecord {
41 expires_at: i64,
42 issuer: String,
43 jti: String,
44}
45
46/// Parses a JWT without validation and returns header and claims
47fn parse_jwt(token: &str) -> Result<(Value, Value)> {
48 let parts: Vec<&str> = token.split('.').collect();
49 if parts.len() != 3 {
50 return Err(Error::with_status(
51 StatusCode::BAD_REQUEST,
52 anyhow!("Invalid JWT format"),
53 ));
54 }
55
56 let header_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
57 .decode(parts.first().expect("should have JWT header"))
58 .context("Failed to decode JWT header")?;
59
60 let header: Value =
61 serde_json::from_slice(&header_bytes).context("Failed to parse JWT header as JSON")?;
62
63 let claims_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
64 .decode(parts.get(1).expect("should have JWT claims"))
65 .context("Failed to decode JWT claims")?;
66
67 let claims: Value =
68 serde_json::from_slice(&claims_bytes).context("Failed to parse JWT claims as JSON")?;
69
70 Ok((header, claims))
71}
72
73/// Calculate RFC7638 compliant JWK thumbprint
74///
75/// This follows the standard:
76/// 1. Extract only the required members for the key type
77/// 2. Sort members alphabetically
78/// 3. Remove whitespace in the serialization
79/// 4. SHA-256 hash and base64url encode
80fn calculate_jwk_thumbprint(jwk: &Value) -> Result<String> {
81 // Determine the key type
82 let key_type = jwk
83 .get("kty")
84 .and_then(Value::as_str)
85 .context("JWK missing kty property")?;
86
87 // Find required properties for this key type
88 let required_props = JWK_REQUIRED_PROPS
89 .iter()
90 .find(|&&(kty, _)| kty == key_type)
91 .map(|&(_, props)| props)
92 .context(anyhow!("Unsupported key type: {key_type}"))?;
93
94 // Build a new JWK with only the required properties
95 let mut canonical_jwk = serde_json::Map::new();
96
97 for &prop in required_props {
98 let value = jwk
99 .get(prop)
100 .context(anyhow!("JWK missing required property: {prop}"))?;
101 drop(canonical_jwk.insert((*prop).to_string(), value.clone()));
102 }
103
104 // Serialize with no whitespace
105 let canonical_json = serde_json::to_string(&Value::Object(canonical_jwk))
106 .context("Failed to serialize canonical JWK")?;
107
108 // Hash the canonical representation
109 let mut hasher = Sha256::new();
110 hasher.update(canonical_json.as_bytes());
111 let thumbprint = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize());
112
113 Ok(thumbprint)
114}
115
116/// Protected Resource Metadata
117/// - GET `/.well-known/oauth-protected-resource`
118async fn protected_resource(State(config): State<AppConfig>) -> Result<Json<Value>> {
119 Ok(Json(json!({
120 "resource": format!("https://{}", config.host_name),
121 "authorization_servers": [format!("https://{}", config.host_name)],
122 "scopes_supported": [],
123 "bearer_methods_supported": ["header"],
124 "resource_documentation": "https://atproto.com",
125 })))
126}
127
128/// Authorization Server Metadata
129/// - GET `/.well-known/oauth-authorization-server`
130async fn authorization_server(State(config): State<AppConfig>) -> Result<Json<Value>> {
131 let base_url = format!("https://{}", config.host_name);
132
133 Ok(Json(serde_json::json!({
134 "issuer": base_url,
135 "request_parameter_supported": true,
136 "request_uri_parameter_supported": true,
137 "require_request_uri_registration": true,
138 "scopes_supported": ["atproto", "transition:generic", "transition:chat.bsky"],
139 "subject_types_supported": ["public"],
140 "response_types_supported": ["code"],
141 "response_modes_supported": ["query", "fragment", "form_post"],
142 "grant_types_supported": ["authorization_code", "refresh_token"],
143 "code_challenge_methods_supported": ["S256"],
144 "ui_locales_supported": ["en-US"],
145 "display_values_supported": ["page", "popup", "touch"],
146 "request_object_signing_alg_values_supported": ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "ES256", "ES256K", "ES384", "ES512"],
147 "authorization_response_iss_parameter_supported": true,
148 "request_object_encryption_alg_values_supported": [],
149 "request_object_encryption_enc_values_supported": [],
150 "jwks_uri": format!("{}/oauth/jwks", base_url),
151 "authorization_endpoint": format!("{}/oauth/authorize", base_url),
152 "token_endpoint": format!("{}/oauth/token", base_url),
153 "token_endpoint_auth_methods_supported": ["none", "private_key_jwt"],
154 "token_endpoint_auth_signing_alg_values_supported": ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "ES256", "ES256K", "ES384", "ES512"],
155 "revocation_endpoint": format!("{}/oauth/revoke", base_url),
156 "introspection_endpoint": format!("{}/oauth/introspect", base_url),
157 "pushed_authorization_request_endpoint": format!("{}/oauth/par", base_url),
158 "require_pushed_authorization_requests": true,
159 "dpop_signing_alg_values_supported": ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "ES256", "ES256K", "ES384", "ES512"],
160 "client_id_metadata_document_supported": true
161 })))
162}
163
164/// Fetch and validate client metadata from `client_id` URL
165async fn fetch_client_metadata(client: &Client, client_id: &str) -> Result<Value> {
166 // Handle localhost development
167 if client_id.starts_with("http://localhost") {
168 let client_url = url::Url::parse(client_id).context("client_id must be a valid URL")?;
169
170 let mut metadata = json!({
171 "client_id": client_id,
172 "client_name": "Loopback client",
173 "response_types": ["code"],
174 "grant_types": ["authorization_code", "refresh_token"],
175 "scope": "atproto transition:generic",
176 "token_endpoint_auth_method": "none",
177 "application_type": "native",
178 "dpop_bound_access_tokens": true,
179 });
180
181 // Extract redirect_uri from query params if available
182 let redirect_uris = client_url.query().map_or_else(
183 || {
184 vec![
185 json!("http://127.0.0.1/callback"),
186 json!("http://[::1]/callback"),
187 ]
188 },
189 |query| {
190 let pairs: HashMap<_, _> = url::form_urlencoded::parse(query.as_bytes()).collect();
191 pairs.get("redirect_uri").map_or_else(
192 || {
193 vec![
194 json!("http://127.0.0.1/callback"),
195 json!("http://[::1]/callback"),
196 ]
197 },
198 |uri| vec![json!(uri)],
199 )
200 },
201 );
202
203 if let Some(redirect_uris_value) = metadata.as_object_mut() {
204 drop(redirect_uris_value.insert("redirect_uris".to_owned(), json!(redirect_uris)));
205 }
206
207 return Ok(metadata);
208 }
209
210 // If not in dev environment, fetch metadata
211 let response = client
212 .get(client_id)
213 .send()
214 .await
215 .context("Failed to fetch client metadata")?;
216
217 if !response.status().is_success() {
218 return Err(Error::with_status(
219 StatusCode::BAD_REQUEST,
220 anyhow!(
221 "Failed to fetch client metadata: HTTP {}",
222 response.status()
223 ),
224 ));
225 }
226
227 let metadata: Value = response
228 .json()
229 .await
230 .context("Failed to parse client metadata as JSON")?;
231
232 // Validate client_id in metadata matches requested client_id
233 if metadata.get("client_id").and_then(|id| id.as_str()) != Some(client_id) {
234 return Err(Error::with_status(
235 StatusCode::BAD_REQUEST,
236 anyhow!("client_id in metadata doesn't match requested client_id"),
237 ));
238 }
239
240 // Validate DPoP tokens requirement
241 if !metadata
242 .get("dpop_bound_access_tokens")
243 .and_then(Value::as_bool)
244 .unwrap_or(false)
245 {
246 return Err(Error::with_status(
247 StatusCode::BAD_REQUEST,
248 anyhow!("Client metadata must set dpop_bound_access_tokens to true"),
249 ));
250 }
251
252 Ok(metadata)
253}
254
255/// Pushed Authorization Request endpoint
256/// POST `/oauth/par`
257#[expect(clippy::too_many_lines)]
258async fn par(
259 State(db): State<Pool>,
260 State(client): State<Client>,
261 Json(form_data): Json<HashMap<String, String>>,
262) -> Result<Json<Value>> {
263 // Required parameters
264 let client_id = form_data
265 .get("client_id")
266 .context("client_id parameter is required")?;
267 let response_type = form_data
268 .get("response_type")
269 .context("response_type parameter is required")?;
270 let code_challenge = form_data
271 .get("code_challenge")
272 .context("code_challenge parameter is required")?;
273 let code_challenge_method = form_data
274 .get("code_challenge_method")
275 .context("code_challenge_method parameter is required")?;
276
277 // Ensure code_challenge_method is S256 (required by spec)
278 if code_challenge_method != "S256" {
279 return Err(Error::with_status(
280 StatusCode::BAD_REQUEST,
281 anyhow!("code_challenge_method must be S256"),
282 ));
283 }
284
285 // Validate response_type is "code" (our spec only supports this)
286 if response_type != "code" {
287 return Err(Error::with_status(
288 StatusCode::BAD_REQUEST,
289 anyhow!("response_type must be code"),
290 ));
291 }
292
293 // Optional parameters
294 let state = form_data.get("state").cloned();
295 let login_hint = form_data.get("login_hint").cloned();
296 let scope = form_data.get("scope").cloned();
297 let redirect_uri = form_data.get("redirect_uri").cloned();
298 let response_mode = form_data.get("response_mode").cloned();
299 let display = form_data.get("display").cloned();
300
301 // Validate client metadata
302 let client_metadata = fetch_client_metadata(&client, client_id).await?;
303
304 // If redirect_uri is provided, validate it's in the client's allowed list
305 if let Some(ref provided_uri) = redirect_uri {
306 let allowed_uris = client_metadata
307 .get("redirect_uris")
308 .and_then(|uris| uris.as_array())
309 .context("client metadata missing redirect_uris")?;
310
311 let uri_valid = allowed_uris
312 .iter()
313 .any(|uri| uri.as_str().is_some_and(|uri_str| uri_str == provided_uri));
314
315 if !uri_valid {
316 return Err(Error::with_status(
317 StatusCode::BAD_REQUEST,
318 anyhow!("redirect_uri not allowed for this client"),
319 ));
320 }
321 } else if client_metadata
322 .get("redirect_uris")
323 .and_then(|uris| uris.as_array())
324 .map_or(0, Vec::len)
325 != 1
326 {
327 return Err(Error::with_status(
328 StatusCode::BAD_REQUEST,
329 anyhow!("redirect_uri required when client has multiple registered URIs"),
330 ));
331 }
332
333 // Validate scope is in allowed scope for client
334 if let Some(ref requested_scope) = scope {
335 if let Some(allowed_scope) = client_metadata.get("scope").and_then(|s| s.as_str()) {
336 let requested_scopes: HashSet<&str> = requested_scope.split_whitespace().collect();
337 let allowed_scopes: HashSet<&str> = allowed_scope.split_whitespace().collect();
338
339 if !requested_scopes.is_subset(&allowed_scopes) {
340 return Err(Error::with_status(
341 StatusCode::BAD_REQUEST,
342 anyhow!("requested scope exceeds allowed scope"),
343 ));
344 }
345 }
346 }
347
348 // Generate a unique request_uri
349 let request_id = thread_rng()
350 .sample_iter(Alphanumeric)
351 .take(32)
352 .map(char::from)
353 .collect::<String>();
354 let request_uri = format!("urn:ietf:params:oauth:request_uri:req-{request_id}");
355
356 // Store request data in the database
357 let now = chrono::Utc::now();
358 let created_at = now.timestamp();
359 let expires_at = now
360 .checked_add_signed(chrono::Duration::minutes(5))
361 .context("failed to compute expiration time")?
362 .timestamp();
363
364 use crate::schema::pds::oauth_par_requests::dsl as ParRequestSchema;
365 let client_id = client_id.to_owned();
366 let request_uri_cloned = request_uri.to_owned();
367 let response_type = response_type.to_owned();
368 let code_challenge = code_challenge.to_owned();
369 let code_challenge_method = code_challenge_method.to_owned();
370 _ = db
371 .get()
372 .await
373 .expect("Failed to get database connection")
374 .interact(move |conn| {
375 insert_into(ParRequestSchema::oauth_par_requests)
376 .values((
377 ParRequestSchema::request_uri.eq(&request_uri_cloned),
378 ParRequestSchema::client_id.eq(client_id),
379 ParRequestSchema::response_type.eq(response_type),
380 ParRequestSchema::code_challenge.eq(code_challenge),
381 ParRequestSchema::code_challenge_method.eq(code_challenge_method),
382 ParRequestSchema::state.eq(state),
383 ParRequestSchema::login_hint.eq(login_hint),
384 ParRequestSchema::scope.eq(scope),
385 ParRequestSchema::redirect_uri.eq(redirect_uri),
386 ParRequestSchema::response_mode.eq(response_mode),
387 ParRequestSchema::display.eq(display),
388 ParRequestSchema::created_at.eq(created_at),
389 ParRequestSchema::expires_at.eq(expires_at),
390 ))
391 .execute(conn)
392 })
393 .await
394 .expect("Failed to store PAR request")
395 .expect("Failed to store PAR request");
396
397 Ok(Json(json!({
398 "request_uri": request_uri,
399 "expires_in": 300_i32 // 5 minutes
400 })))
401}
402
403/// OAuth Authorization endpoint
404/// GET `/oauth/authorize`
405async fn authorize(
406 State(db): State<Pool>,
407 State(client): State<Client>,
408 Query(params): Query<HashMap<String, String>>,
409) -> Result<impl IntoResponse> {
410 // Required parameters
411 let client_id = params
412 .get("client_id")
413 .context("client_id parameter is required")?;
414 let request_uri = params
415 .get("request_uri")
416 .context("request_uri parameter is required")?;
417
418 let timestamp = chrono::Utc::now().timestamp();
419
420 // Retrieve the PAR request from the database
421 use crate::schema::pds::oauth_par_requests::dsl as ParRequestSchema;
422
423 let request_uri_clone = request_uri.to_owned();
424 let client_id_clone = client_id.to_owned();
425 let timestamp_clone = timestamp.clone();
426 let login_hint = db
427 .get()
428 .await
429 .expect("Failed to get database connection")
430 .interact(move |conn| {
431 ParRequestSchema::oauth_par_requests
432 .select(ParRequestSchema::login_hint)
433 .filter(ParRequestSchema::request_uri.eq(request_uri_clone))
434 .filter(ParRequestSchema::client_id.eq(client_id_clone))
435 .filter(ParRequestSchema::expires_at.gt(timestamp_clone))
436 .first::<Option<String>>(conn)
437 .optional()
438 })
439 .await
440 .expect("Failed to query PAR request")
441 .expect("Failed to query PAR request")
442 .expect("Failed to query PAR request");
443
444 // Validate client metadata
445 let client_metadata = fetch_client_metadata(&client, client_id).await?;
446
447 // Authorization page with login form
448 let login_hint = login_hint.unwrap_or_default();
449 let html = format!(
450 r#"<!DOCTYPE html>
451 <html>
452 <head>
453 <title>Authentication Required</title>
454 <meta name="viewport" content="width=device-width, initial-scale=1">
455 <style>
456 body {{ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; max-width: 500px; margin: 0 auto; padding: 20px; }}
457 .container {{ border: 1px solid #e0e0e0; border-radius: 8px; padding: 20px; }}
458 h1 {{ font-size: 24px; }}
459 label {{ display: block; margin-top: 12px; }}
460 input[type="text"], input[type="password"] {{ width: 100%; padding: 8px; margin-top: 4px; border: 1px solid #ddd; border-radius: 4px; }}
461 button {{ margin-top: 20px; padding: 8px 16px; background-color: #0070f3; color: white; border: none; border-radius: 4px; cursor: pointer; }}
462 .client {{ margin-top: 12px; font-size: 14px; color: #666; }}
463 </style>
464 </head>
465 <body>
466 <div class="container">
467 <h1>Sign in to continue</h1>
468 <p>An application is requesting access to your account.</p>
469
470 <div class="client">
471 <strong>Client:</strong> {client_name}
472 </div>
473
474 <form action="/oauth/authorize/sign-in" method="post">
475 <input type="hidden" name="request_uri" value="{request_uri}">
476 <input type="hidden" name="client_id" value="{client_id}">
477
478 <label for="username">Username</label>
479 <input type="text" id="username" name="username" value="{login_hint}" required>
480
481 <label for="password">Password</label>
482 <input type="password" id="password" name="password" required>
483
484 <label>
485 <input type="checkbox" name="remember" value="true"> Remember me
486 </label>
487
488 <button type="submit">Sign in</button>
489 </form>
490 </div>
491 </body>
492 </html>
493 "#,
494 client_name = client_metadata
495 .get("client_name")
496 .and_then(|n| n.as_str())
497 .unwrap_or(client_id),
498 request_uri = request_uri,
499 client_id = client_id,
500 login_hint = login_hint
501 );
502
503 Ok((
504 StatusCode::OK,
505 [(header::CONTENT_TYPE, HeaderValue::from_static("text/html"))],
506 html,
507 ))
508}
509
510/// OAuth Authorization Sign-in endpoint
511/// POST `/oauth/authorize/sign-in`
512#[expect(clippy::too_many_lines)]
513async fn authorize_signin(
514 State(db): State<Pool>,
515 State(config): State<AppConfig>,
516 State(client): State<Client>,
517 extract::Form(form_data): extract::Form<HashMap<String, String>>,
518) -> Result<impl IntoResponse> {
519 use std::fmt::Write as _;
520
521 // Extract form data
522 let username = form_data.get("username").context("username is required")?;
523 let password = form_data.get("password").context("password is required")?;
524 let request_uri = form_data
525 .get("request_uri")
526 .context("request_uri is required")?;
527 let client_id = form_data
528 .get("client_id")
529 .context("client_id is required")?;
530
531 let timestamp = chrono::Utc::now().timestamp();
532
533 // Retrieve the PAR request
534 use crate::schema::pds::oauth_par_requests::dsl as ParRequestSchema;
535 #[derive(Queryable, Selectable)]
536 #[diesel(table_name = crate::schema::pds::oauth_par_requests)]
537 #[diesel(check_for_backend(sqlite::Sqlite))]
538 struct ParRequest {
539 request_uri: String,
540 client_id: String,
541 response_type: String,
542 code_challenge: String,
543 code_challenge_method: String,
544 state: Option<String>,
545 login_hint: Option<String>,
546 scope: Option<String>,
547 redirect_uri: Option<String>,
548 response_mode: Option<String>,
549 display: Option<String>,
550 created_at: i64,
551 expires_at: i64,
552 }
553 let request_uri_clone = request_uri.to_owned();
554 let client_id_clone = client_id.to_owned();
555 let timestamp_clone = timestamp.clone();
556 let par_request = db
557 .get()
558 .await
559 .expect("Failed to get database connection")
560 .interact(move |conn| {
561 ParRequestSchema::oauth_par_requests
562 .filter(ParRequestSchema::request_uri.eq(request_uri_clone))
563 .filter(ParRequestSchema::client_id.eq(client_id_clone))
564 .filter(ParRequestSchema::expires_at.gt(timestamp_clone))
565 .first::<ParRequest>(conn)
566 .optional()
567 })
568 .await
569 .expect("Failed to query PAR request")
570 .expect("Failed to query PAR request")
571 .expect("Failed to query PAR request");
572
573 // Authenticate the user
574 use crate::schema::pds::account::dsl as AccountSchema;
575 use crate::schema::pds::actor::dsl as ActorSchema;
576 let username_clone = username.to_owned();
577 let account = db
578 .get()
579 .await
580 .expect("Failed to get database connection")
581 .interact(move |conn| {
582 AccountSchema::account
583 .filter(AccountSchema::email.eq(username_clone))
584 .first::<crate::models::pds::Account>(conn)
585 .optional()
586 })
587 .await
588 .expect("Failed to query account")
589 .expect("Failed to query account")
590 .expect("Failed to query account");
591 // let actor = db
592 // .get()
593 // .await
594 // .expect("Failed to get database connection")
595 // .interact(move |conn| {
596 // ActorSchema::actor
597 // .filter(ActorSchema::did.eq(did))
598 // .first::<rsky_pds::models::Actor>(conn)
599 // .optional()
600 // })
601 // .await
602 // .expect("Failed to query actor")
603 // .expect("Failed to query actor")
604 // .expect("Failed to query actor");
605
606 // Verify password - fixed to use equality check instead of pattern matching
607 if Argon2::default().verify_password(
608 password.as_bytes(),
609 &PasswordHash::new(account.password.as_str()).context("invalid password hash in db")?,
610 ) == Ok(())
611 {
612 } else {
613 counter!(AUTH_FAILED).increment(1);
614 return Err(Error::with_status(
615 StatusCode::UNAUTHORIZED,
616 anyhow!("invalid credentials"),
617 ));
618 }
619
620 // Generate authorization code
621 let code = thread_rng()
622 .sample_iter(Alphanumeric)
623 .take(40)
624 .map(char::from)
625 .collect::<String>();
626
627 // Determine redirect URI to use
628 let redirect_uri = if let Some(uri) = par_request.redirect_uri.as_ref() {
629 uri.clone()
630 } else {
631 let client_metadata = fetch_client_metadata(&client, client_id).await?;
632 client_metadata
633 .get("redirect_uris")
634 .and_then(|uris| uris.as_array())
635 .and_then(|uris| uris.first())
636 .and_then(|uri| uri.as_str())
637 .context("No redirect_uri available")?
638 .to_owned()
639 };
640
641 // Store the authorization code
642 let now = chrono::Utc::now();
643 let created_at = now.timestamp();
644 let expires_at = now
645 .checked_add_signed(chrono::Duration::minutes(10))
646 .context("failed to compute expiration time")?
647 .timestamp();
648
649 use crate::schema::pds::oauth_authorization_codes::dsl as AuthCodeSchema;
650 let code_cloned = code.to_owned();
651 let client_id = client_id.to_owned();
652 let subject = account.did.to_owned();
653 let code_challenge = par_request.code_challenge.to_owned();
654 let code_challenge_method = par_request.code_challenge_method.to_owned();
655 let redirect_uri_cloned = redirect_uri.to_owned();
656 let scope = par_request.scope.to_owned();
657 let used = false;
658 _ = db
659 .get()
660 .await
661 .expect("Failed to get database connection")
662 .interact(move |conn| {
663 insert_into(AuthCodeSchema::oauth_authorization_codes)
664 .values((
665 AuthCodeSchema::code.eq(code_cloned),
666 AuthCodeSchema::client_id.eq(client_id),
667 AuthCodeSchema::subject.eq(subject),
668 AuthCodeSchema::code_challenge.eq(code_challenge),
669 AuthCodeSchema::code_challenge_method.eq(code_challenge_method),
670 AuthCodeSchema::redirect_uri.eq(redirect_uri_cloned),
671 AuthCodeSchema::scope.eq(scope),
672 AuthCodeSchema::created_at.eq(created_at),
673 AuthCodeSchema::expires_at.eq(expires_at),
674 AuthCodeSchema::used.eq(used),
675 ))
676 .execute(conn)
677 })
678 .await
679 .expect("Failed to store authorization code")
680 .expect("Failed to store authorization code");
681
682 // Use state from the PAR request or generate one
683 let state = par_request.state.unwrap_or_else(|| {
684 thread_rng()
685 .sample_iter(Alphanumeric)
686 .take(16)
687 .map(char::from)
688 .collect::<String>()
689 });
690
691 // Build redirect URL
692 let mut redirect_target = redirect_uri;
693 match par_request.response_mode.as_deref() {
694 Some("fragment") => redirect_target.push('#'),
695 _ => redirect_target.push('?'),
696 }
697
698 write!(redirect_target, "state={}", urlencoding::encode(&state)).unwrap();
699 let host_name = format!("https://{}", &config.host_name);
700 write!(redirect_target, "&iss={}", urlencoding::encode(&host_name)).unwrap();
701 write!(redirect_target, "&code={}", urlencoding::encode(&code)).unwrap();
702 Ok(Redirect::to(&redirect_target))
703}
704
705/// Verify a DPoP proof and return the JWK thumbprint
706/// RFC 7519 JSON Web Token (JWT) - 4.3. Checking DPoP Proofs
707/// To validate a DPoP proof, the receiving server MUST ensure the
708/// following:
709/// 1. There is not more than one DPoP HTTP request header field.
710/// 2. The DPoP HTTP request header field value is a single and well-
711/// formed JWT.
712/// 3. All required claims per Section 4.2 are contained in the JWT.
713/// 4. The typ JOSE Header Parameter has the value dpop+jwt.
714/// 5. The alg JOSE Header Parameter indicates a registered asymmetric
715/// digital signature algorithm [IANA.JOSE.ALGS], is not none, is
716/// supported by the application, and is acceptable per local
717/// policy.
718/// 6. The JWT signature verifies with the public key contained in the
719/// jwk JOSE Header Parameter.
720/// 7. The jwk JOSE Header Parameter does not contain a private key.
721/// 8. The htm claim matches the HTTP method of the current request.
722/// 9. The htu claim matches the HTTP URI value for the HTTP request in
723/// which the JWT was received, ignoring any query and fragment
724/// parts.
725/// 10. If the server provided a nonce value to the client, the nonce
726/// claim matches the server-provided nonce value.
727/// 11. The creation time of the JWT, as determined by either the iat
728/// claim or a server managed timestamp via the nonce claim, is
729/// within an acceptable window (see Section 11.1).
730/// 12. If presented to a protected resource in conjunction with an
731/// access token,
732/// * ensure that the value of the ath claim equals the hash of
733/// that access token, and
734/// * confirm that the public key to which the access token is
735/// bound matches the public key from the DPoP proof.
736#[expect(clippy::too_many_lines)]
737async fn verify_dpop_proof(
738 dpop_token: &str,
739 http_method: &str,
740 http_uri: &str,
741 db: &Pool,
742 access_token: Option<&str>,
743 bound_key_thumbprint: Option<&str>,
744) -> Result<String> {
745 // Parse the DPoP token
746 let (header, claims) = parse_jwt(dpop_token)?;
747
748 // 1. Verify "typ" is "dpop+jwt" (requirement #4)
749 if header.get("typ").and_then(Value::as_str) != Some("dpop+jwt") {
750 return Err(Error::with_status(
751 StatusCode::BAD_REQUEST,
752 anyhow!("Invalid DPoP token type"),
753 ));
754 }
755
756 // 2. Check alg header (requirement #5)
757 let alg = header
758 .get("alg")
759 .and_then(Value::as_str)
760 .context("Missing alg in DPoP header")?;
761 if alg == "none" || !["RS256", "ES256", "ES256K", "PS256"].contains(&alg) {
762 return Err(Error::with_status(
763 StatusCode::BAD_REQUEST,
764 anyhow!("Unsupported or insecure signature algorithm"),
765 ));
766 }
767
768 // 3. Extract JWK and verify no private key components (requirement #7)
769 let jwk = header.get("jwk").context("missing jwk in DPoP header")?;
770 if jwk.get("d").is_some() || jwk.get("p").is_some() || jwk.get("q").is_some() {
771 return Err(Error::with_status(
772 StatusCode::BAD_REQUEST,
773 anyhow!("JWK contains private key components"),
774 ));
775 }
776
777 // 4. Calculate JWK thumbprint
778 let thumbprint = calculate_jwk_thumbprint(jwk)?;
779
780 // 5. Verify JWT signature (requirement #6)
781 // TODO: Implement signature verification with the JWK
782
783 // 6. Verify required claims (requirement #3)
784 let jti = claims
785 .get("jti")
786 .and_then(Value::as_str)
787 .context("Missing jti claim in DPoP token")?;
788
789 // 7. Check HTTP method matches htm claim (requirement #8)
790 if claims.get("htm").and_then(Value::as_str) != Some(http_method) {
791 return Err(Error::with_status(
792 StatusCode::BAD_REQUEST,
793 anyhow!("DPoP token HTTP method mismatch"),
794 ));
795 }
796
797 // 8. Check HTTP URI matches htu claim (requirement #9)
798 // Should perform proper URI normalization for production use
799 if claims.get("htu").and_then(Value::as_str) != Some(http_uri) {
800 return Err(Error::with_status(
801 StatusCode::BAD_REQUEST,
802 anyhow!(
803 "DPoP token HTTP URI mismatch: expected {}, got {}",
804 http_uri,
805 claims.get("htu").and_then(Value::as_str).unwrap_or("None")
806 ),
807 ));
808 }
809
810 // 9. Verify token timestamps (requirement #11)
811 let now = chrono::Utc::now().timestamp();
812
813 // Check creation time (iat)
814 if let Some(iat) = claims.get("iat").and_then(Value::as_i64) {
815 // Token not too old (5 minute max age)
816 if iat < now.saturating_sub(300) {
817 return Err(Error::with_status(
818 StatusCode::BAD_REQUEST,
819 anyhow!("DPoP token too old"),
820 ));
821 }
822
823 // Token not in the future (with clock skew allowance)
824 if iat > now.saturating_add(60) {
825 return Err(Error::with_status(
826 StatusCode::BAD_REQUEST,
827 anyhow!("DPoP token creation time is in the future"),
828 ));
829 }
830 }
831
832 // Check expiration (exp) if present
833 let exp = claims
834 .get("exp")
835 .and_then(Value::as_i64)
836 .unwrap_or_else(|| now.saturating_add(60)); // Default 60s if not present
837
838 if now >= exp {
839 return Err(Error::with_status(
840 StatusCode::BAD_REQUEST,
841 anyhow!("DPoP token has expired"),
842 ));
843 }
844
845 // 10. Verify access token binding (requirement #12)
846 if let Some(token) = access_token {
847 // Verify ath claim matches token hash
848 if let Some(ath) = claims.get("ath").and_then(Value::as_str) {
849 let mut hasher = Sha256::new();
850 hasher.update(token.as_bytes());
851 let token_hash =
852 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize());
853
854 if ath != token_hash {
855 return Err(Error::with_status(
856 StatusCode::BAD_REQUEST,
857 anyhow!("DPoP access token hash mismatch"),
858 ));
859 }
860 } else {
861 return Err(Error::with_status(
862 StatusCode::BAD_REQUEST,
863 anyhow!("Missing ath claim for DPoP with access token"),
864 ));
865 }
866
867 // Verify key binding matches
868 if let Some(expected_thumbprint) = bound_key_thumbprint {
869 if thumbprint != expected_thumbprint {
870 return Err(Error::with_status(
871 StatusCode::BAD_REQUEST,
872 anyhow!("DPoP key doesn't match token-bound key"),
873 ));
874 }
875 }
876 }
877
878 // 11. Check for replay attacks via JTI tracking
879 use crate::schema::pds::oauth_used_jtis::dsl as JtiSchema;
880 let jti_clone = jti.to_owned();
881 let jti_used = db
882 .get()
883 .await
884 .expect("Failed to get database connection")
885 .interact(move |conn| {
886 JtiSchema::oauth_used_jtis
887 .filter(JtiSchema::jti.eq(jti_clone))
888 .count()
889 .get_result::<i64>(conn)
890 .optional()
891 })
892 .await
893 .expect("Failed to check JTI")
894 .expect("Failed to check JTI")
895 .unwrap_or(0);
896
897 if jti_used > 0 {
898 return Err(Error::with_status(
899 StatusCode::BAD_REQUEST,
900 anyhow!("DPoP token has been replayed"),
901 ));
902 }
903
904 // 12. Store the JTI to prevent replay attacks
905 let jti_cloned = jti.to_owned();
906 let issuer = thumbprint.to_owned();
907 let created_at = now;
908 let expires_at = exp;
909 _ = db
910 .get()
911 .await
912 .expect("Failed to get database connection")
913 .interact(move |conn| {
914 insert_into(JtiSchema::oauth_used_jtis)
915 .values((
916 JtiSchema::jti.eq(jti_cloned),
917 JtiSchema::issuer.eq(issuer),
918 JtiSchema::created_at.eq(created_at),
919 JtiSchema::expires_at.eq(expires_at),
920 ))
921 .execute(conn)
922 })
923 .await
924 .expect("Failed to store JTI")
925 .expect("Failed to store JTI");
926
927 // 13. Cleanup expired JTIs periodically (1% chance on each request)
928 if thread_rng().gen_range(0_i32..100_i32) == 0_i32 {
929 let now_clone = now.to_owned();
930 _ = db
931 .get()
932 .await
933 .expect("Failed to get database connection")
934 .interact(move |conn| {
935 delete(JtiSchema::oauth_used_jtis)
936 .filter(JtiSchema::expires_at.lt(now_clone))
937 .execute(conn)
938 })
939 .await
940 .expect("Failed to clean up expired JTIs")
941 .expect("Failed to clean up expired JTIs");
942 }
943
944 Ok(thumbprint)
945}
946
947/// Verify a `code_verifier` against stored `code_challenge`
948fn verify_pkce(code_verifier: &str, stored_challenge: &str, method: &str) -> Result<()> {
949 // Only S256 is supported currently
950 if method != "S256" {
951 return Err(Error::with_status(
952 StatusCode::BAD_REQUEST,
953 anyhow!("Unsupported code_challenge_method: {}", method),
954 ));
955 }
956
957 // Calculate the code challenge from verifier
958 let mut hasher = Sha256::new();
959 hasher.update(code_verifier.as_bytes());
960 let calculated = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize());
961
962 // Compare with stored challenge
963 if calculated != stored_challenge {
964 return Err(Error::with_status(
965 StatusCode::BAD_REQUEST,
966 anyhow!("Code verifier doesn't match challenge"),
967 ));
968 }
969
970 Ok(())
971}
972
973/// OAuth token endpoint
974/// - POST `/oauth/token`
975///
976/// Handles both `authorization_code` and `refresh_token` grants
977#[expect(clippy::too_many_lines)]
978async fn token(
979 State(db): State<Pool>,
980 State(skey): State<SigningKey>,
981 State(config): State<AppConfig>,
982 State(client): State<Client>,
983 headers: HeaderMap,
984 Json(form_data): Json<HashMap<String, String>>,
985) -> Result<Json<Value>> {
986 // Extract form parameters
987 let grant_type = form_data
988 .get("grant_type")
989 .context("grant_type is required")?;
990 let client_id = form_data
991 .get("client_id")
992 .context("client_id is required")?;
993
994 // Validate DPoP header (Rule 1: Ensure DPoP is used)
995 let dpop_token = headers
996 .get("DPoP")
997 .context("DPoP header is required")?
998 .to_str()
999 .context("Invalid DPoP header")?;
1000
1001 // Get client metadata to determine client type (public vs confidential)
1002 let client_metadata = fetch_client_metadata(&client, client_id).await?;
1003 let is_confidential_client = client_metadata
1004 .get("token_endpoint_auth_method")
1005 .and_then(Value::as_str)
1006 .unwrap_or("none")
1007 == "private_key_jwt";
1008
1009 // Verify DPoP proof
1010 let dpop_thumbprint_res = verify_dpop_proof(
1011 dpop_token,
1012 "POST",
1013 &format!("https://{}/oauth/token", config.host_name),
1014 &db,
1015 None,
1016 None,
1017 )
1018 .await?;
1019
1020 if is_confidential_client {
1021 // Rule 3: Check client authentication consistency
1022 // For confidential clients, verify client_assertion
1023 let client_assertion_type = form_data
1024 .get("client_assertion_type")
1025 .context("client_assertion_type required for private_key_jwt auth")?;
1026 let _client_assertion = form_data
1027 .get("client_assertion")
1028 .context("client_assertion required for private_key_jwt auth")?;
1029
1030 if client_assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" {
1031 return Err(Error::with_status(
1032 StatusCode::BAD_REQUEST,
1033 anyhow!("Invalid client_assertion_type"),
1034 ));
1035 }
1036
1037 // Verify client assertion JWT
1038 // This would involve similar logic to verify_dpop_proof but for client auth
1039 //
1040 // WIP: Practically unimplemented
1041 //
1042 // TODO: Figure out how this actually works
1043
1044 // verify_client_assertion(&client, client_id, client_assertion).await?;
1045
1046 // Rule 4: Ensure DPoP and client_assertion use different keys
1047 // let client_assertion_thumbprint = calculate_client_assertion_thumbprint(client_assertion)?;
1048 // if client_assertion_thumbprint == dpop_thumbprint {
1049 // return Err(Error::with_status(
1050 // StatusCode::BAD_REQUEST,
1051 // anyhow!("DPoP proof and client assertion must use different keypairs"),
1052 // ));
1053 // }
1054 } else {
1055 // Rule 2: For public clients, check if this DPoP key has been used before
1056 use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema;
1057 let dpop_thumbprint_clone = dpop_thumbprint_res.to_owned();
1058 let client_id_clone = client_id.to_owned();
1059 let is_key_reused = db
1060 .get()
1061 .await
1062 .expect("Failed to get database connection")
1063 .interact(move |conn| {
1064 RefreshTokenSchema::oauth_refresh_tokens
1065 .filter(RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_clone))
1066 .filter(RefreshTokenSchema::client_id.eq(client_id_clone))
1067 .count()
1068 .get_result::<i64>(conn)
1069 .optional()
1070 })
1071 .await
1072 .expect("Failed to check key usage history")
1073 .expect("Failed to check key usage history")
1074 .unwrap_or(0)
1075 > 0;
1076
1077 if is_key_reused && grant_type == "authorization_code" {
1078 return Err(Error::with_status(
1079 StatusCode::BAD_REQUEST,
1080 anyhow!("Public clients must use a new key for each token request"),
1081 ));
1082 }
1083 }
1084
1085 match grant_type.as_str() {
1086 "authorization_code" => {
1087 // Process authorization code grant
1088 let code = form_data.get("code").context("code is required")?;
1089 let code_verifier = form_data
1090 .get("code_verifier")
1091 .context("code_verifier is required")?;
1092 let redirect_uri = form_data
1093 .get("redirect_uri")
1094 .context("redirect_uri is required")?;
1095
1096 let timestamp = chrono::Utc::now().timestamp();
1097
1098 // Retrieve and validate the authorization code
1099 use crate::schema::pds::oauth_authorization_codes::dsl as AuthCodeSchema;
1100 #[derive(Queryable, Selectable, Serialize)]
1101 #[diesel(table_name = crate::schema::pds::oauth_authorization_codes)]
1102 #[diesel(check_for_backend(sqlite::Sqlite))]
1103 struct AuthCode {
1104 code: String,
1105 client_id: String,
1106 subject: String,
1107 code_challenge: String,
1108 code_challenge_method: String,
1109 redirect_uri: String,
1110 scope: Option<String>,
1111 created_at: i64,
1112 expires_at: i64,
1113 used: bool,
1114 }
1115 let code_clone = code.to_owned();
1116 let client_id_clone = client_id.to_owned();
1117 let redirect_uri_clone = redirect_uri.to_owned();
1118 let auth_code = db
1119 .get()
1120 .await
1121 .expect("Failed to get database connection")
1122 .interact(move |conn| {
1123 AuthCodeSchema::oauth_authorization_codes
1124 .filter(AuthCodeSchema::code.eq(code_clone))
1125 .filter(AuthCodeSchema::client_id.eq(client_id_clone))
1126 .filter(AuthCodeSchema::redirect_uri.eq(redirect_uri_clone))
1127 .filter(AuthCodeSchema::expires_at.gt(timestamp))
1128 .filter(AuthCodeSchema::used.eq(false))
1129 .first::<AuthCode>(conn)
1130 .optional()
1131 })
1132 .await
1133 .expect("Failed to query authorization code")
1134 .expect("Failed to query authorization code")
1135 .expect("Failed to query authorization code");
1136
1137 // Verify PKCE code challenge
1138 verify_pkce(
1139 code_verifier,
1140 &auth_code.code_challenge,
1141 &auth_code.code_challenge_method,
1142 )?;
1143
1144 // Mark the code as used
1145 let code_cloned = code.to_owned();
1146 _ = db
1147 .get()
1148 .await
1149 .expect("Failed to get database connection")
1150 .interact(move |conn| {
1151 update(AuthCodeSchema::oauth_authorization_codes)
1152 .filter(AuthCodeSchema::code.eq(code_cloned))
1153 .set(AuthCodeSchema::used.eq(true))
1154 .execute(conn)
1155 })
1156 .await
1157 .expect("Failed to mark code as used")
1158 .expect("Failed to mark code as used");
1159
1160 // Generate tokens with appropriate lifetimes
1161 let now = chrono::Utc::now().timestamp();
1162
1163 // Rule 5: Access token valid for short period
1164 let access_token_expires_in = 3600_i64; // 1 hour (maximum allowed)
1165 let access_token_expires_at = now.saturating_add(access_token_expires_in);
1166
1167 // Rule 11 & 12: Different refresh token lifetimes based on client type
1168 let refresh_token_expires_at = if is_confidential_client {
1169 now.saturating_add(15_552_000_i64) // 6 months for confidential clients
1170 } else {
1171 now.saturating_add(604_800_i64) // 1 week maximum for public clients
1172 };
1173
1174 // Rule 5: Include subject claim with user DID
1175 let access_token_claims = json!({
1176 "iss": format!("https://{}", config.host_name),
1177 "sub": auth_code.subject, // User's DID
1178 "aud": client_id,
1179 "exp": access_token_expires_at,
1180 "iat": now,
1181 "cnf": {
1182 "jkt": dpop_thumbprint_res // Rule 1: Bind to DPoP key
1183 },
1184 "scope": auth_code.scope
1185 });
1186
1187 let access_token = crate::auth::sign(&skey, "at+jwt", &access_token_claims)
1188 .context("failed to sign access token")?;
1189
1190 // Create refresh token with similar binding
1191 let refresh_token_claims = json!({
1192 "iss": format!("https://{}", config.host_name),
1193 "sub": auth_code.subject,
1194 "aud": client_id,
1195 "exp": refresh_token_expires_at,
1196 "iat": now,
1197 "cnf": {
1198 "jkt": dpop_thumbprint_res // Rule 1: Bind to DPoP key
1199 },
1200 "scope": auth_code.scope
1201 });
1202
1203 let refresh_token = crate::auth::sign(&skey, "rt+jwt", &refresh_token_claims)
1204 .context("failed to sign refresh token")?;
1205
1206 // Store the refresh token with DPoP binding
1207 use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema;
1208 let refresh_token_cloned = refresh_token.to_owned();
1209 let client_id_cloned = client_id.to_owned();
1210 let subject = auth_code.subject.to_owned();
1211 let dpop_thumbprint_cloned = dpop_thumbprint_res.to_owned();
1212 let scope = auth_code.scope.to_owned();
1213 let created_at = now;
1214 let expires_at = refresh_token_expires_at;
1215 _ = db
1216 .get()
1217 .await
1218 .expect("Failed to get database connection")
1219 .interact(move |conn| {
1220 insert_into(RefreshTokenSchema::oauth_refresh_tokens)
1221 .values((
1222 RefreshTokenSchema::token.eq(refresh_token_cloned),
1223 RefreshTokenSchema::client_id.eq(client_id_cloned),
1224 RefreshTokenSchema::subject.eq(subject),
1225 RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_cloned),
1226 RefreshTokenSchema::scope.eq(scope),
1227 RefreshTokenSchema::created_at.eq(created_at),
1228 RefreshTokenSchema::expires_at.eq(expires_at),
1229 RefreshTokenSchema::revoked.eq(false),
1230 ))
1231 .execute(conn)
1232 })
1233 .await
1234 .expect("Failed to store refresh token")
1235 .expect("Failed to store refresh token");
1236
1237 // Return token response with the subject claim
1238 Ok(Json(json!({
1239 "access_token": access_token,
1240 "token_type": "DPoP",
1241 "expires_in": access_token_expires_in,
1242 "refresh_token": refresh_token,
1243 "scope": auth_code.scope,
1244 "sub": auth_code.subject // Rule 5: Include subject claim
1245 })))
1246 }
1247 "refresh_token" => {
1248 // Process refresh token grant
1249 let refresh_token = form_data
1250 .get("refresh_token")
1251 .context("refresh_token is required")?;
1252
1253 let timestamp = chrono::Utc::now().timestamp();
1254
1255 // Rules 7 & 8: Verify refresh token and DPoP consistency
1256 // Retrieve the refresh token
1257 use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema;
1258 #[derive(Queryable, Selectable, Serialize)]
1259 #[diesel(table_name = crate::schema::pds::oauth_refresh_tokens)]
1260 #[diesel(check_for_backend(sqlite::Sqlite))]
1261 struct TokenData {
1262 token: String,
1263 client_id: String,
1264 subject: String,
1265 dpop_thumbprint: String,
1266 scope: Option<String>,
1267 created_at: i64,
1268 expires_at: i64,
1269 revoked: bool,
1270 }
1271 let dpop_thumbprint_clone = dpop_thumbprint_res.to_owned();
1272 let refresh_token_clone = refresh_token.to_owned();
1273 let client_id_clone = client_id.to_owned();
1274 let token_data = db
1275 .get()
1276 .await
1277 .expect("Failed to get database connection")
1278 .interact(move |conn| {
1279 RefreshTokenSchema::oauth_refresh_tokens
1280 .filter(RefreshTokenSchema::token.eq(refresh_token_clone))
1281 .filter(RefreshTokenSchema::client_id.eq(client_id_clone))
1282 .filter(RefreshTokenSchema::expires_at.gt(timestamp))
1283 .filter(RefreshTokenSchema::revoked.eq(false))
1284 .filter(RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_clone))
1285 .first::<TokenData>(conn)
1286 .optional()
1287 })
1288 .await
1289 .expect("Failed to query refresh token")
1290 .expect("Failed to query refresh token")
1291 .expect("Failed to query refresh token");
1292
1293 // Rule 10: For confidential clients, verify key is still advertised in their jwks
1294 if is_confidential_client {
1295 let client_still_advertises_key = true; // Implement actual check against client jwks
1296 if !client_still_advertises_key {
1297 // Revoke all tokens bound to this key
1298 let client_id_cloned = client_id.to_owned();
1299 let dpop_thumbprint_cloned = dpop_thumbprint_res.to_owned();
1300 _ = db
1301 .get()
1302 .await
1303 .expect("Failed to get database connection")
1304 .interact(move |conn| {
1305 update(RefreshTokenSchema::oauth_refresh_tokens)
1306 .filter(RefreshTokenSchema::client_id.eq(client_id_cloned))
1307 .filter(
1308 RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_cloned),
1309 )
1310 .set(RefreshTokenSchema::revoked.eq(true))
1311 .execute(conn)
1312 })
1313 .await
1314 .expect("Failed to revoke tokens")
1315 .expect("Failed to revoke tokens");
1316
1317 return Err(Error::with_status(
1318 StatusCode::BAD_REQUEST,
1319 anyhow!("Key no longer advertised by client"),
1320 ));
1321 }
1322 }
1323
1324 // Rotate the refresh token
1325 let refresh_token_cloned = refresh_token.to_owned();
1326 _ = db
1327 .get()
1328 .await
1329 .expect("Failed to get database connection")
1330 .interact(move |conn| {
1331 update(RefreshTokenSchema::oauth_refresh_tokens)
1332 .filter(RefreshTokenSchema::token.eq(refresh_token_cloned))
1333 .set(RefreshTokenSchema::revoked.eq(true))
1334 .execute(conn)
1335 })
1336 .await
1337 .expect("Failed to revoke old refresh token")
1338 .expect("Failed to revoke old refresh token");
1339
1340 // Generate new tokens
1341 let now = chrono::Utc::now().timestamp();
1342 let access_token_expires_in = 3600_i64;
1343 let access_token_expires_at = now.saturating_add(access_token_expires_in);
1344
1345 // Maintain the original expiry policy for refresh tokens
1346 let original_duration = token_data.expires_at.saturating_sub(token_data.created_at);
1347 let refresh_token_expires_at = now.saturating_add(original_duration);
1348
1349 // Create access token
1350 let access_token_claims = json!({
1351 "iss": format!("https://{}", config.host_name),
1352 "sub": token_data.subject,
1353 "aud": client_id,
1354 "exp": access_token_expires_at,
1355 "iat": now,
1356 "cnf": {
1357 "jkt": dpop_thumbprint_res
1358 },
1359 "scope": token_data.scope
1360 });
1361
1362 let access_token = crate::auth::sign(&skey, "at+jwt", &access_token_claims)
1363 .context("failed to sign access token")?;
1364
1365 // Create new refresh token
1366 let new_refresh_token_claims = json!({
1367 "iss": format!("https://{}", config.host_name),
1368 "sub": token_data.subject,
1369 "aud": client_id,
1370 "exp": refresh_token_expires_at,
1371 "iat": now,
1372 "cnf": {
1373 "jkt": dpop_thumbprint_res
1374 },
1375 "scope": token_data.scope
1376 });
1377
1378 let new_refresh_token = crate::auth::sign(&skey, "rt+jwt", &new_refresh_token_claims)
1379 .context("failed to sign refresh token")?;
1380
1381 // Store the new refresh token
1382 let new_refresh_token_cloned = new_refresh_token.to_owned();
1383 let client_id_cloned = client_id.to_owned();
1384 let subject = token_data.subject.to_owned();
1385 let dpop_thumbprint_cloned = dpop_thumbprint_res.to_owned();
1386 let scope = token_data.scope.to_owned();
1387 let created_at = now;
1388 let expires_at = refresh_token_expires_at;
1389 _ = db
1390 .get()
1391 .await
1392 .expect("Failed to get database connection")
1393 .interact(move |conn| {
1394 insert_into(RefreshTokenSchema::oauth_refresh_tokens)
1395 .values((
1396 RefreshTokenSchema::token.eq(new_refresh_token_cloned),
1397 RefreshTokenSchema::client_id.eq(client_id_cloned),
1398 RefreshTokenSchema::subject.eq(subject),
1399 RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_cloned),
1400 RefreshTokenSchema::scope.eq(scope),
1401 RefreshTokenSchema::created_at.eq(created_at),
1402 RefreshTokenSchema::expires_at.eq(expires_at),
1403 RefreshTokenSchema::revoked.eq(false),
1404 ))
1405 .execute(conn)
1406 })
1407 .await
1408 .expect("Failed to store refresh token")
1409 .expect("Failed to store refresh token");
1410
1411 // Return token response
1412 Ok(Json(json!({
1413 "access_token": access_token,
1414 "token_type": "DPoP",
1415 "expires_in": access_token_expires_in,
1416 "refresh_token": new_refresh_token,
1417 "scope": token_data.scope,
1418 "sub": token_data.subject
1419 })))
1420 }
1421 _ => Err(Error::with_status(
1422 StatusCode::BAD_REQUEST,
1423 anyhow!("unsupported grant_type: {}", grant_type),
1424 )),
1425 }
1426}
1427
1428/// JWKS (JSON Web Key Set) endpoint
1429/// - GET `/oauth/jwks`
1430///
1431/// Returns the server's public keys as a JWKS document
1432///
1433/// WIP: Practically unimplemented
1434///
1435/// TODO: Figure out if/how this actually works
1436async fn jwks(State(skey): State<SigningKey>) -> Result<Json<Value>> {
1437 let did_string = skey.did();
1438
1439 // Extract the public key data from the DID string
1440 let (_, public_key_bytes) =
1441 atrium_crypto::did::parse_did_key(&did_string).context("failed to parse did key")?;
1442
1443 // Secp256k1 uncompressed public keys should be 65 bytes: 0x04 + x(32 bytes) + y(32 bytes)
1444 if public_key_bytes.len() != 65 || public_key_bytes.first().copied() != Some(0x04) {
1445 return Err(Error::with_status(
1446 StatusCode::INTERNAL_SERVER_ERROR,
1447 anyhow!("unexpected public key format"),
1448 ));
1449 }
1450
1451 // Extract and encode the X and Y coordinates
1452 let x_coord = public_key_bytes
1453 .get(1..33)
1454 .map(|slice| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(slice))
1455 .context("failed to extract X coordinate")?;
1456
1457 let y_coord = public_key_bytes
1458 .get(33..65)
1459 .map(|slice| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(slice))
1460 .context("failed to extract Y coordinate")?;
1461
1462 // Create a stable key ID based on the DID
1463 let key_id = did_string.strip_prefix("did:key:").unwrap_or(&did_string);
1464
1465 let jwk = json!({
1466 "kty": "EC",
1467 "crv": "secp256k1",
1468 "kid": key_id,
1469 "use": "sig",
1470 "alg": "ES256K",
1471 "x": x_coord,
1472 "y": y_coord
1473 });
1474
1475 // Return the JWKS document
1476 Ok(Json(json!({
1477 "keys": [jwk]
1478 })))
1479}
1480
1481/// Token Revocation endpoint
1482/// - POST `/oauth/revoke`
1483///
1484/// Implements RFC7009 for revoking refresh tokens
1485async fn revoke(
1486 State(db): State<Pool>,
1487 Json(form_data): Json<HashMap<String, String>>,
1488) -> Result<Json<Value>> {
1489 // Extract required parameters
1490 let token = form_data.get("token").context("token is required")?;
1491 let refresh_token_string = "refresh_token".to_owned();
1492 let token_type_hint = form_data
1493 .get("token_type_hint")
1494 .unwrap_or(&refresh_token_string);
1495
1496 // We only support revoking refresh tokens
1497 if token_type_hint != "refresh_token" {
1498 return Err(Error::with_status(
1499 StatusCode::BAD_REQUEST,
1500 anyhow!("unsupported token_type_hint: {}", token_type_hint),
1501 ));
1502 }
1503
1504 // Revoke the token
1505 use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema;
1506 let token_cloned = token.to_owned();
1507 _ = db
1508 .get()
1509 .await
1510 .expect("Failed to get database connection")
1511 .interact(move |conn| {
1512 update(RefreshTokenSchema::oauth_refresh_tokens)
1513 .filter(RefreshTokenSchema::token.eq(token_cloned))
1514 .set(RefreshTokenSchema::revoked.eq(true))
1515 .execute(conn)
1516 })
1517 .await
1518 .expect("Failed to revoke token")
1519 .expect("Failed to revoke token");
1520
1521 // RFC7009 requires a 200 OK with an empty response
1522 Ok(Json(json!({})))
1523}
1524
1525/// Token Introspection endpoint
1526/// - POST `/oauth/introspect`
1527///
1528/// Implements RFC7662 for introspecting tokens
1529async fn introspect(
1530 State(db): State<Pool>,
1531 State(skey): State<SigningKey>,
1532 Json(form_data): Json<HashMap<String, String>>,
1533) -> Result<Json<Value>> {
1534 // Extract required parameters
1535 let token = form_data.get("token").context("token is required")?;
1536 let token_type_hint = form_data.get("token_type_hint");
1537
1538 // Parse the token
1539 let Ok((typ, claims)) = crate::auth::verify(&skey.did(), token) else {
1540 // Per RFC7662, invalid tokens return { "active": false }
1541 return Ok(Json(json!({"active": false})));
1542 };
1543
1544 // Check token type
1545 let is_refresh_token = typ == "rt+jwt";
1546 let is_access_token = typ == "at+jwt";
1547
1548 if !is_access_token && !is_refresh_token {
1549 return Ok(Json(json!({"active": false})));
1550 }
1551
1552 // If token_type_hint is provided, verify it matches
1553 if let Some(hint) = token_type_hint {
1554 if (hint == "refresh_token" && !is_refresh_token)
1555 || (hint == "access_token" && !is_access_token)
1556 {
1557 return Ok(Json(json!({"active": false})));
1558 }
1559 }
1560
1561 // Check expiration
1562 if let Some(exp) = claims.get("exp").and_then(Value::as_i64) {
1563 let now = chrono::Utc::now().timestamp();
1564 if now >= exp {
1565 return Ok(Json(json!({"active": false})));
1566 }
1567 } else {
1568 return Ok(Json(json!({"active": false})));
1569 }
1570
1571 // For refresh tokens, check if it's been revoked
1572 if is_refresh_token {
1573 use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema;
1574 let token_cloned = token.to_owned();
1575 let is_revoked = db
1576 .get()
1577 .await
1578 .expect("Failed to get database connection")
1579 .interact(move |conn| {
1580 RefreshTokenSchema::oauth_refresh_tokens
1581 .filter(RefreshTokenSchema::token.eq(token_cloned))
1582 .select(RefreshTokenSchema::revoked)
1583 .first::<bool>(conn)
1584 .optional()
1585 })
1586 .await
1587 .expect("Failed to query token")
1588 .expect("Failed to query token")
1589 .unwrap_or(true);
1590
1591 if is_revoked {
1592 return Ok(Json(json!({"active": false})));
1593 }
1594 }
1595
1596 // Token is valid, return introspection info
1597 let subject = claims.get("sub").and_then(Value::as_str);
1598 let client_id = claims.get("aud").and_then(Value::as_str);
1599 let scope = claims.get("scope").and_then(Value::as_str);
1600 let expiration = claims.get("exp").and_then(Value::as_i64);
1601 let issued_at = claims.get("iat").and_then(Value::as_i64);
1602
1603 Ok(Json(json!({
1604 "active": true,
1605 "sub": subject,
1606 "client_id": client_id,
1607 "scope": scope,
1608 "exp": expiration,
1609 "iat": issued_at,
1610 "token_type": if is_access_token { "access_token" } else { "refresh_token" }
1611 })))
1612}
1613
1614/// Register OAuth routes
1615pub(crate) fn routes() -> Router<AppState> {
1616 Router::new()
1617 .route(
1618 "/.well-known/oauth-protected-resource",
1619 get(protected_resource),
1620 )
1621 .route(
1622 "/.well-known/oauth-authorization-server",
1623 get(authorization_server),
1624 )
1625 .route("/oauth/par", post(par))
1626 .route("/oauth/authorize", get(authorize))
1627 .route("/oauth/authorize/sign-in", post(authorize_signin))
1628 .route("/oauth/token", post(token))
1629 .route("/oauth/jwks", get(jwks))
1630 .route("/oauth/revoke", post(revoke))
1631 .route("/oauth/introspect", post(introspect))
1632}