Alternative ATProto PDS implementation
at oauth 63 kB view raw
1//! OAuth endpoints 2#![allow(unnameable_types, unused_qualifications)] 3use crate::config::AppConfig; 4use crate::error::Error; 5use crate::metrics::AUTH_FAILED; 6use crate::serve::{AppState, Client, Result, SigningKey}; 7use anyhow::{Context as _, anyhow}; 8use argon2::{Argon2, PasswordHash, PasswordVerifier as _}; 9use atrium_crypto::keypair::Did as _; 10use axum::response::Redirect; 11use axum::{ 12 Json, Router, extract, 13 extract::{Query, State}, 14 http::{HeaderMap, HeaderValue, StatusCode, header}, 15 response::IntoResponse, 16 routing::{get, post}, 17}; 18use base64::Engine as _; 19use deadpool_diesel::sqlite::Pool; 20use diesel::*; 21use metrics::counter; 22use rand::distributions::Alphanumeric; 23use rand::{Rng as _, thread_rng}; 24use serde::{Deserialize, Serialize}; 25use serde_json::{Value, json}; 26use sha2::{Digest as _, Sha256}; 27use std::collections::{HashMap, HashSet}; 28 29/// JWK thumbprint required properties for each key type (RFC7638) 30/// 31/// Currently only supporting ES256K (Secp256k1) and RSA as those are 32/// commonly used in DPoP proofs with ATProto 33const JWK_REQUIRED_PROPS: &[(&str, &[&str])] = &[ 34 ("EC", &["crv", "kty", "x", "y"]), 35 ("RSA", &["e", "kty", "n"]), 36]; 37 38/// JWT ID used record for tracking used JTIs to prevent replay attacks 39#[derive(Debug, Serialize, Deserialize)] 40struct JtiRecord { 41 expires_at: i64, 42 issuer: String, 43 jti: String, 44} 45 46/// Parses a JWT without validation and returns header and claims 47fn parse_jwt(token: &str) -> Result<(Value, Value)> { 48 let parts: Vec<&str> = token.split('.').collect(); 49 if parts.len() != 3 { 50 return Err(Error::with_status( 51 StatusCode::BAD_REQUEST, 52 anyhow!("Invalid JWT format"), 53 )); 54 } 55 56 let header_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD 57 .decode(parts.first().expect("should have JWT header")) 58 .context("Failed to decode JWT header")?; 59 60 let header: Value = 61 serde_json::from_slice(&header_bytes).context("Failed to parse JWT header as JSON")?; 62 63 let claims_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD 64 .decode(parts.get(1).expect("should have JWT claims")) 65 .context("Failed to decode JWT claims")?; 66 67 let claims: Value = 68 serde_json::from_slice(&claims_bytes).context("Failed to parse JWT claims as JSON")?; 69 70 Ok((header, claims)) 71} 72 73/// Calculate RFC7638 compliant JWK thumbprint 74/// 75/// This follows the standard: 76/// 1. Extract only the required members for the key type 77/// 2. Sort members alphabetically 78/// 3. Remove whitespace in the serialization 79/// 4. SHA-256 hash and base64url encode 80fn calculate_jwk_thumbprint(jwk: &Value) -> Result<String> { 81 // Determine the key type 82 let key_type = jwk 83 .get("kty") 84 .and_then(Value::as_str) 85 .context("JWK missing kty property")?; 86 87 // Find required properties for this key type 88 let required_props = JWK_REQUIRED_PROPS 89 .iter() 90 .find(|&&(kty, _)| kty == key_type) 91 .map(|&(_, props)| props) 92 .context(anyhow!("Unsupported key type: {key_type}"))?; 93 94 // Build a new JWK with only the required properties 95 let mut canonical_jwk = serde_json::Map::new(); 96 97 for &prop in required_props { 98 let value = jwk 99 .get(prop) 100 .context(anyhow!("JWK missing required property: {prop}"))?; 101 drop(canonical_jwk.insert((*prop).to_string(), value.clone())); 102 } 103 104 // Serialize with no whitespace 105 let canonical_json = serde_json::to_string(&Value::Object(canonical_jwk)) 106 .context("Failed to serialize canonical JWK")?; 107 108 // Hash the canonical representation 109 let mut hasher = Sha256::new(); 110 hasher.update(canonical_json.as_bytes()); 111 let thumbprint = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize()); 112 113 Ok(thumbprint) 114} 115 116/// Protected Resource Metadata 117/// - GET `/.well-known/oauth-protected-resource` 118async fn protected_resource(State(config): State<AppConfig>) -> Result<Json<Value>> { 119 Ok(Json(json!({ 120 "resource": format!("https://{}", config.host_name), 121 "authorization_servers": [format!("https://{}", config.host_name)], 122 "scopes_supported": [], 123 "bearer_methods_supported": ["header"], 124 "resource_documentation": "https://atproto.com", 125 }))) 126} 127 128/// Authorization Server Metadata 129/// - GET `/.well-known/oauth-authorization-server` 130async fn authorization_server(State(config): State<AppConfig>) -> Result<Json<Value>> { 131 let base_url = format!("https://{}", config.host_name); 132 133 Ok(Json(serde_json::json!({ 134 "issuer": base_url, 135 "request_parameter_supported": true, 136 "request_uri_parameter_supported": true, 137 "require_request_uri_registration": true, 138 "scopes_supported": ["atproto", "transition:generic", "transition:chat.bsky"], 139 "subject_types_supported": ["public"], 140 "response_types_supported": ["code"], 141 "response_modes_supported": ["query", "fragment", "form_post"], 142 "grant_types_supported": ["authorization_code", "refresh_token"], 143 "code_challenge_methods_supported": ["S256"], 144 "ui_locales_supported": ["en-US"], 145 "display_values_supported": ["page", "popup", "touch"], 146 "request_object_signing_alg_values_supported": ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "ES256", "ES256K", "ES384", "ES512"], 147 "authorization_response_iss_parameter_supported": true, 148 "request_object_encryption_alg_values_supported": [], 149 "request_object_encryption_enc_values_supported": [], 150 "jwks_uri": format!("{}/oauth/jwks", base_url), 151 "authorization_endpoint": format!("{}/oauth/authorize", base_url), 152 "token_endpoint": format!("{}/oauth/token", base_url), 153 "token_endpoint_auth_methods_supported": ["none", "private_key_jwt"], 154 "token_endpoint_auth_signing_alg_values_supported": ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "ES256", "ES256K", "ES384", "ES512"], 155 "revocation_endpoint": format!("{}/oauth/revoke", base_url), 156 "introspection_endpoint": format!("{}/oauth/introspect", base_url), 157 "pushed_authorization_request_endpoint": format!("{}/oauth/par", base_url), 158 "require_pushed_authorization_requests": true, 159 "dpop_signing_alg_values_supported": ["RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "ES256", "ES256K", "ES384", "ES512"], 160 "client_id_metadata_document_supported": true 161 }))) 162} 163 164/// Fetch and validate client metadata from `client_id` URL 165async fn fetch_client_metadata(client: &Client, client_id: &str) -> Result<Value> { 166 // Handle localhost development 167 if client_id.starts_with("http://localhost") { 168 let client_url = url::Url::parse(client_id).context("client_id must be a valid URL")?; 169 170 let mut metadata = json!({ 171 "client_id": client_id, 172 "client_name": "Loopback client", 173 "response_types": ["code"], 174 "grant_types": ["authorization_code", "refresh_token"], 175 "scope": "atproto transition:generic", 176 "token_endpoint_auth_method": "none", 177 "application_type": "native", 178 "dpop_bound_access_tokens": true, 179 }); 180 181 // Extract redirect_uri from query params if available 182 let redirect_uris = client_url.query().map_or_else( 183 || { 184 vec![ 185 json!("http://127.0.0.1/callback"), 186 json!("http://[::1]/callback"), 187 ] 188 }, 189 |query| { 190 let pairs: HashMap<_, _> = url::form_urlencoded::parse(query.as_bytes()).collect(); 191 pairs.get("redirect_uri").map_or_else( 192 || { 193 vec![ 194 json!("http://127.0.0.1/callback"), 195 json!("http://[::1]/callback"), 196 ] 197 }, 198 |uri| vec![json!(uri)], 199 ) 200 }, 201 ); 202 203 if let Some(redirect_uris_value) = metadata.as_object_mut() { 204 drop(redirect_uris_value.insert("redirect_uris".to_owned(), json!(redirect_uris))); 205 } 206 207 return Ok(metadata); 208 } 209 210 // If not in dev environment, fetch metadata 211 let response = client 212 .get(client_id) 213 .send() 214 .await 215 .context("Failed to fetch client metadata")?; 216 217 if !response.status().is_success() { 218 return Err(Error::with_status( 219 StatusCode::BAD_REQUEST, 220 anyhow!( 221 "Failed to fetch client metadata: HTTP {}", 222 response.status() 223 ), 224 )); 225 } 226 227 let metadata: Value = response 228 .json() 229 .await 230 .context("Failed to parse client metadata as JSON")?; 231 232 // Validate client_id in metadata matches requested client_id 233 if metadata.get("client_id").and_then(|id| id.as_str()) != Some(client_id) { 234 return Err(Error::with_status( 235 StatusCode::BAD_REQUEST, 236 anyhow!("client_id in metadata doesn't match requested client_id"), 237 )); 238 } 239 240 // Validate DPoP tokens requirement 241 if !metadata 242 .get("dpop_bound_access_tokens") 243 .and_then(Value::as_bool) 244 .unwrap_or(false) 245 { 246 return Err(Error::with_status( 247 StatusCode::BAD_REQUEST, 248 anyhow!("Client metadata must set dpop_bound_access_tokens to true"), 249 )); 250 } 251 252 Ok(metadata) 253} 254 255/// Pushed Authorization Request endpoint 256/// POST `/oauth/par` 257#[expect(clippy::too_many_lines)] 258async fn par( 259 State(db): State<Pool>, 260 State(client): State<Client>, 261 Json(form_data): Json<HashMap<String, String>>, 262) -> Result<Json<Value>> { 263 // Required parameters 264 let client_id = form_data 265 .get("client_id") 266 .context("client_id parameter is required")?; 267 let response_type = form_data 268 .get("response_type") 269 .context("response_type parameter is required")?; 270 let code_challenge = form_data 271 .get("code_challenge") 272 .context("code_challenge parameter is required")?; 273 let code_challenge_method = form_data 274 .get("code_challenge_method") 275 .context("code_challenge_method parameter is required")?; 276 277 // Ensure code_challenge_method is S256 (required by spec) 278 if code_challenge_method != "S256" { 279 return Err(Error::with_status( 280 StatusCode::BAD_REQUEST, 281 anyhow!("code_challenge_method must be S256"), 282 )); 283 } 284 285 // Validate response_type is "code" (our spec only supports this) 286 if response_type != "code" { 287 return Err(Error::with_status( 288 StatusCode::BAD_REQUEST, 289 anyhow!("response_type must be code"), 290 )); 291 } 292 293 // Optional parameters 294 let state = form_data.get("state").cloned(); 295 let login_hint = form_data.get("login_hint").cloned(); 296 let scope = form_data.get("scope").cloned(); 297 let redirect_uri = form_data.get("redirect_uri").cloned(); 298 let response_mode = form_data.get("response_mode").cloned(); 299 let display = form_data.get("display").cloned(); 300 301 // Validate client metadata 302 let client_metadata = fetch_client_metadata(&client, client_id).await?; 303 304 // If redirect_uri is provided, validate it's in the client's allowed list 305 if let Some(ref provided_uri) = redirect_uri { 306 let allowed_uris = client_metadata 307 .get("redirect_uris") 308 .and_then(|uris| uris.as_array()) 309 .context("client metadata missing redirect_uris")?; 310 311 let uri_valid = allowed_uris 312 .iter() 313 .any(|uri| uri.as_str().is_some_and(|uri_str| uri_str == provided_uri)); 314 315 if !uri_valid { 316 return Err(Error::with_status( 317 StatusCode::BAD_REQUEST, 318 anyhow!("redirect_uri not allowed for this client"), 319 )); 320 } 321 } else if client_metadata 322 .get("redirect_uris") 323 .and_then(|uris| uris.as_array()) 324 .map_or(0, Vec::len) 325 != 1 326 { 327 return Err(Error::with_status( 328 StatusCode::BAD_REQUEST, 329 anyhow!("redirect_uri required when client has multiple registered URIs"), 330 )); 331 } 332 333 // Validate scope is in allowed scope for client 334 if let Some(ref requested_scope) = scope { 335 if let Some(allowed_scope) = client_metadata.get("scope").and_then(|s| s.as_str()) { 336 let requested_scopes: HashSet<&str> = requested_scope.split_whitespace().collect(); 337 let allowed_scopes: HashSet<&str> = allowed_scope.split_whitespace().collect(); 338 339 if !requested_scopes.is_subset(&allowed_scopes) { 340 return Err(Error::with_status( 341 StatusCode::BAD_REQUEST, 342 anyhow!("requested scope exceeds allowed scope"), 343 )); 344 } 345 } 346 } 347 348 // Generate a unique request_uri 349 let request_id = thread_rng() 350 .sample_iter(Alphanumeric) 351 .take(32) 352 .map(char::from) 353 .collect::<String>(); 354 let request_uri = format!("urn:ietf:params:oauth:request_uri:req-{request_id}"); 355 356 // Store request data in the database 357 let now = chrono::Utc::now(); 358 let created_at = now.timestamp(); 359 let expires_at = now 360 .checked_add_signed(chrono::Duration::minutes(5)) 361 .context("failed to compute expiration time")? 362 .timestamp(); 363 364 use crate::schema::pds::oauth_par_requests::dsl as ParRequestSchema; 365 let client_id = client_id.to_owned(); 366 let request_uri_cloned = request_uri.to_owned(); 367 let response_type = response_type.to_owned(); 368 let code_challenge = code_challenge.to_owned(); 369 let code_challenge_method = code_challenge_method.to_owned(); 370 _ = db 371 .get() 372 .await 373 .expect("Failed to get database connection") 374 .interact(move |conn| { 375 insert_into(ParRequestSchema::oauth_par_requests) 376 .values(( 377 ParRequestSchema::request_uri.eq(&request_uri_cloned), 378 ParRequestSchema::client_id.eq(client_id), 379 ParRequestSchema::response_type.eq(response_type), 380 ParRequestSchema::code_challenge.eq(code_challenge), 381 ParRequestSchema::code_challenge_method.eq(code_challenge_method), 382 ParRequestSchema::state.eq(state), 383 ParRequestSchema::login_hint.eq(login_hint), 384 ParRequestSchema::scope.eq(scope), 385 ParRequestSchema::redirect_uri.eq(redirect_uri), 386 ParRequestSchema::response_mode.eq(response_mode), 387 ParRequestSchema::display.eq(display), 388 ParRequestSchema::created_at.eq(created_at), 389 ParRequestSchema::expires_at.eq(expires_at), 390 )) 391 .execute(conn) 392 }) 393 .await 394 .expect("Failed to store PAR request") 395 .expect("Failed to store PAR request"); 396 397 Ok(Json(json!({ 398 "request_uri": request_uri, 399 "expires_in": 300_i32 // 5 minutes 400 }))) 401} 402 403/// OAuth Authorization endpoint 404/// GET `/oauth/authorize` 405async fn authorize( 406 State(db): State<Pool>, 407 State(client): State<Client>, 408 Query(params): Query<HashMap<String, String>>, 409) -> Result<impl IntoResponse> { 410 // Required parameters 411 let client_id = params 412 .get("client_id") 413 .context("client_id parameter is required")?; 414 let request_uri = params 415 .get("request_uri") 416 .context("request_uri parameter is required")?; 417 418 let timestamp = chrono::Utc::now().timestamp(); 419 420 // Retrieve the PAR request from the database 421 use crate::schema::pds::oauth_par_requests::dsl as ParRequestSchema; 422 423 let request_uri_clone = request_uri.to_owned(); 424 let client_id_clone = client_id.to_owned(); 425 let timestamp_clone = timestamp.clone(); 426 let login_hint = db 427 .get() 428 .await 429 .expect("Failed to get database connection") 430 .interact(move |conn| { 431 ParRequestSchema::oauth_par_requests 432 .select(ParRequestSchema::login_hint) 433 .filter(ParRequestSchema::request_uri.eq(request_uri_clone)) 434 .filter(ParRequestSchema::client_id.eq(client_id_clone)) 435 .filter(ParRequestSchema::expires_at.gt(timestamp_clone)) 436 .first::<Option<String>>(conn) 437 .optional() 438 }) 439 .await 440 .expect("Failed to query PAR request") 441 .expect("Failed to query PAR request") 442 .expect("Failed to query PAR request"); 443 444 // Validate client metadata 445 let client_metadata = fetch_client_metadata(&client, client_id).await?; 446 447 // Authorization page with login form 448 let login_hint = login_hint.unwrap_or_default(); 449 let html = format!( 450 r#"<!DOCTYPE html> 451 <html> 452 <head> 453 <title>Authentication Required</title> 454 <meta name="viewport" content="width=device-width, initial-scale=1"> 455 <style> 456 body {{ font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, Helvetica, Arial, sans-serif; max-width: 500px; margin: 0 auto; padding: 20px; }} 457 .container {{ border: 1px solid #e0e0e0; border-radius: 8px; padding: 20px; }} 458 h1 {{ font-size: 24px; }} 459 label {{ display: block; margin-top: 12px; }} 460 input[type="text"], input[type="password"] {{ width: 100%; padding: 8px; margin-top: 4px; border: 1px solid #ddd; border-radius: 4px; }} 461 button {{ margin-top: 20px; padding: 8px 16px; background-color: #0070f3; color: white; border: none; border-radius: 4px; cursor: pointer; }} 462 .client {{ margin-top: 12px; font-size: 14px; color: #666; }} 463 </style> 464 </head> 465 <body> 466 <div class="container"> 467 <h1>Sign in to continue</h1> 468 <p>An application is requesting access to your account.</p> 469 470 <div class="client"> 471 <strong>Client:</strong> {client_name} 472 </div> 473 474 <form action="/oauth/authorize/sign-in" method="post"> 475 <input type="hidden" name="request_uri" value="{request_uri}"> 476 <input type="hidden" name="client_id" value="{client_id}"> 477 478 <label for="username">Username</label> 479 <input type="text" id="username" name="username" value="{login_hint}" required> 480 481 <label for="password">Password</label> 482 <input type="password" id="password" name="password" required> 483 484 <label> 485 <input type="checkbox" name="remember" value="true"> Remember me 486 </label> 487 488 <button type="submit">Sign in</button> 489 </form> 490 </div> 491 </body> 492 </html> 493 "#, 494 client_name = client_metadata 495 .get("client_name") 496 .and_then(|n| n.as_str()) 497 .unwrap_or(client_id), 498 request_uri = request_uri, 499 client_id = client_id, 500 login_hint = login_hint 501 ); 502 503 Ok(( 504 StatusCode::OK, 505 [(header::CONTENT_TYPE, HeaderValue::from_static("text/html"))], 506 html, 507 )) 508} 509 510/// OAuth Authorization Sign-in endpoint 511/// POST `/oauth/authorize/sign-in` 512#[expect(clippy::too_many_lines)] 513async fn authorize_signin( 514 State(db): State<Pool>, 515 State(config): State<AppConfig>, 516 State(client): State<Client>, 517 extract::Form(form_data): extract::Form<HashMap<String, String>>, 518) -> Result<impl IntoResponse> { 519 use std::fmt::Write as _; 520 521 // Extract form data 522 let username = form_data.get("username").context("username is required")?; 523 let password = form_data.get("password").context("password is required")?; 524 let request_uri = form_data 525 .get("request_uri") 526 .context("request_uri is required")?; 527 let client_id = form_data 528 .get("client_id") 529 .context("client_id is required")?; 530 531 let timestamp = chrono::Utc::now().timestamp(); 532 533 // Retrieve the PAR request 534 use crate::schema::pds::oauth_par_requests::dsl as ParRequestSchema; 535 #[derive(Queryable, Selectable)] 536 #[diesel(table_name = crate::schema::pds::oauth_par_requests)] 537 #[diesel(check_for_backend(sqlite::Sqlite))] 538 struct ParRequest { 539 request_uri: String, 540 client_id: String, 541 response_type: String, 542 code_challenge: String, 543 code_challenge_method: String, 544 state: Option<String>, 545 login_hint: Option<String>, 546 scope: Option<String>, 547 redirect_uri: Option<String>, 548 response_mode: Option<String>, 549 display: Option<String>, 550 created_at: i64, 551 expires_at: i64, 552 } 553 let request_uri_clone = request_uri.to_owned(); 554 let client_id_clone = client_id.to_owned(); 555 let timestamp_clone = timestamp.clone(); 556 let par_request = db 557 .get() 558 .await 559 .expect("Failed to get database connection") 560 .interact(move |conn| { 561 ParRequestSchema::oauth_par_requests 562 .filter(ParRequestSchema::request_uri.eq(request_uri_clone)) 563 .filter(ParRequestSchema::client_id.eq(client_id_clone)) 564 .filter(ParRequestSchema::expires_at.gt(timestamp_clone)) 565 .first::<ParRequest>(conn) 566 .optional() 567 }) 568 .await 569 .expect("Failed to query PAR request") 570 .expect("Failed to query PAR request") 571 .expect("Failed to query PAR request"); 572 573 // Authenticate the user 574 use crate::schema::pds::account::dsl as AccountSchema; 575 use crate::schema::pds::actor::dsl as ActorSchema; 576 let username_clone = username.to_owned(); 577 let account = db 578 .get() 579 .await 580 .expect("Failed to get database connection") 581 .interact(move |conn| { 582 AccountSchema::account 583 .filter(AccountSchema::email.eq(username_clone)) 584 .first::<crate::models::pds::Account>(conn) 585 .optional() 586 }) 587 .await 588 .expect("Failed to query account") 589 .expect("Failed to query account") 590 .expect("Failed to query account"); 591 // let actor = db 592 // .get() 593 // .await 594 // .expect("Failed to get database connection") 595 // .interact(move |conn| { 596 // ActorSchema::actor 597 // .filter(ActorSchema::did.eq(did)) 598 // .first::<rsky_pds::models::Actor>(conn) 599 // .optional() 600 // }) 601 // .await 602 // .expect("Failed to query actor") 603 // .expect("Failed to query actor") 604 // .expect("Failed to query actor"); 605 606 // Verify password - fixed to use equality check instead of pattern matching 607 if Argon2::default().verify_password( 608 password.as_bytes(), 609 &PasswordHash::new(account.password.as_str()).context("invalid password hash in db")?, 610 ) == Ok(()) 611 { 612 } else { 613 counter!(AUTH_FAILED).increment(1); 614 return Err(Error::with_status( 615 StatusCode::UNAUTHORIZED, 616 anyhow!("invalid credentials"), 617 )); 618 } 619 620 // Generate authorization code 621 let code = thread_rng() 622 .sample_iter(Alphanumeric) 623 .take(40) 624 .map(char::from) 625 .collect::<String>(); 626 627 // Determine redirect URI to use 628 let redirect_uri = if let Some(uri) = par_request.redirect_uri.as_ref() { 629 uri.clone() 630 } else { 631 let client_metadata = fetch_client_metadata(&client, client_id).await?; 632 client_metadata 633 .get("redirect_uris") 634 .and_then(|uris| uris.as_array()) 635 .and_then(|uris| uris.first()) 636 .and_then(|uri| uri.as_str()) 637 .context("No redirect_uri available")? 638 .to_owned() 639 }; 640 641 // Store the authorization code 642 let now = chrono::Utc::now(); 643 let created_at = now.timestamp(); 644 let expires_at = now 645 .checked_add_signed(chrono::Duration::minutes(10)) 646 .context("failed to compute expiration time")? 647 .timestamp(); 648 649 use crate::schema::pds::oauth_authorization_codes::dsl as AuthCodeSchema; 650 let code_cloned = code.to_owned(); 651 let client_id = client_id.to_owned(); 652 let subject = account.did.to_owned(); 653 let code_challenge = par_request.code_challenge.to_owned(); 654 let code_challenge_method = par_request.code_challenge_method.to_owned(); 655 let redirect_uri_cloned = redirect_uri.to_owned(); 656 let scope = par_request.scope.to_owned(); 657 let used = false; 658 _ = db 659 .get() 660 .await 661 .expect("Failed to get database connection") 662 .interact(move |conn| { 663 insert_into(AuthCodeSchema::oauth_authorization_codes) 664 .values(( 665 AuthCodeSchema::code.eq(code_cloned), 666 AuthCodeSchema::client_id.eq(client_id), 667 AuthCodeSchema::subject.eq(subject), 668 AuthCodeSchema::code_challenge.eq(code_challenge), 669 AuthCodeSchema::code_challenge_method.eq(code_challenge_method), 670 AuthCodeSchema::redirect_uri.eq(redirect_uri_cloned), 671 AuthCodeSchema::scope.eq(scope), 672 AuthCodeSchema::created_at.eq(created_at), 673 AuthCodeSchema::expires_at.eq(expires_at), 674 AuthCodeSchema::used.eq(used), 675 )) 676 .execute(conn) 677 }) 678 .await 679 .expect("Failed to store authorization code") 680 .expect("Failed to store authorization code"); 681 682 // Use state from the PAR request or generate one 683 let state = par_request.state.unwrap_or_else(|| { 684 thread_rng() 685 .sample_iter(Alphanumeric) 686 .take(16) 687 .map(char::from) 688 .collect::<String>() 689 }); 690 691 // Build redirect URL 692 let mut redirect_target = redirect_uri; 693 match par_request.response_mode.as_deref() { 694 Some("fragment") => redirect_target.push('#'), 695 _ => redirect_target.push('?'), 696 } 697 698 write!(redirect_target, "state={}", urlencoding::encode(&state)).unwrap(); 699 let host_name = format!("https://{}", &config.host_name); 700 write!(redirect_target, "&iss={}", urlencoding::encode(&host_name)).unwrap(); 701 write!(redirect_target, "&code={}", urlencoding::encode(&code)).unwrap(); 702 Ok(Redirect::to(&redirect_target)) 703} 704 705/// Verify a DPoP proof and return the JWK thumbprint 706/// RFC 7519 JSON Web Token (JWT) - 4.3. Checking DPoP Proofs 707/// To validate a DPoP proof, the receiving server MUST ensure the 708/// following: 709/// 1. There is not more than one DPoP HTTP request header field. 710/// 2. The DPoP HTTP request header field value is a single and well- 711/// formed JWT. 712/// 3. All required claims per Section 4.2 are contained in the JWT. 713/// 4. The typ JOSE Header Parameter has the value dpop+jwt. 714/// 5. The alg JOSE Header Parameter indicates a registered asymmetric 715/// digital signature algorithm [IANA.JOSE.ALGS], is not none, is 716/// supported by the application, and is acceptable per local 717/// policy. 718/// 6. The JWT signature verifies with the public key contained in the 719/// jwk JOSE Header Parameter. 720/// 7. The jwk JOSE Header Parameter does not contain a private key. 721/// 8. The htm claim matches the HTTP method of the current request. 722/// 9. The htu claim matches the HTTP URI value for the HTTP request in 723/// which the JWT was received, ignoring any query and fragment 724/// parts. 725/// 10. If the server provided a nonce value to the client, the nonce 726/// claim matches the server-provided nonce value. 727/// 11. The creation time of the JWT, as determined by either the iat 728/// claim or a server managed timestamp via the nonce claim, is 729/// within an acceptable window (see Section 11.1). 730/// 12. If presented to a protected resource in conjunction with an 731/// access token, 732/// * ensure that the value of the ath claim equals the hash of 733/// that access token, and 734/// * confirm that the public key to which the access token is 735/// bound matches the public key from the DPoP proof. 736#[expect(clippy::too_many_lines)] 737async fn verify_dpop_proof( 738 dpop_token: &str, 739 http_method: &str, 740 http_uri: &str, 741 db: &Pool, 742 access_token: Option<&str>, 743 bound_key_thumbprint: Option<&str>, 744) -> Result<String> { 745 // Parse the DPoP token 746 let (header, claims) = parse_jwt(dpop_token)?; 747 748 // 1. Verify "typ" is "dpop+jwt" (requirement #4) 749 if header.get("typ").and_then(Value::as_str) != Some("dpop+jwt") { 750 return Err(Error::with_status( 751 StatusCode::BAD_REQUEST, 752 anyhow!("Invalid DPoP token type"), 753 )); 754 } 755 756 // 2. Check alg header (requirement #5) 757 let alg = header 758 .get("alg") 759 .and_then(Value::as_str) 760 .context("Missing alg in DPoP header")?; 761 if alg == "none" || !["RS256", "ES256", "ES256K", "PS256"].contains(&alg) { 762 return Err(Error::with_status( 763 StatusCode::BAD_REQUEST, 764 anyhow!("Unsupported or insecure signature algorithm"), 765 )); 766 } 767 768 // 3. Extract JWK and verify no private key components (requirement #7) 769 let jwk = header.get("jwk").context("missing jwk in DPoP header")?; 770 if jwk.get("d").is_some() || jwk.get("p").is_some() || jwk.get("q").is_some() { 771 return Err(Error::with_status( 772 StatusCode::BAD_REQUEST, 773 anyhow!("JWK contains private key components"), 774 )); 775 } 776 777 // 4. Calculate JWK thumbprint 778 let thumbprint = calculate_jwk_thumbprint(jwk)?; 779 780 // 5. Verify JWT signature (requirement #6) 781 // TODO: Implement signature verification with the JWK 782 783 // 6. Verify required claims (requirement #3) 784 let jti = claims 785 .get("jti") 786 .and_then(Value::as_str) 787 .context("Missing jti claim in DPoP token")?; 788 789 // 7. Check HTTP method matches htm claim (requirement #8) 790 if claims.get("htm").and_then(Value::as_str) != Some(http_method) { 791 return Err(Error::with_status( 792 StatusCode::BAD_REQUEST, 793 anyhow!("DPoP token HTTP method mismatch"), 794 )); 795 } 796 797 // 8. Check HTTP URI matches htu claim (requirement #9) 798 // Should perform proper URI normalization for production use 799 if claims.get("htu").and_then(Value::as_str) != Some(http_uri) { 800 return Err(Error::with_status( 801 StatusCode::BAD_REQUEST, 802 anyhow!( 803 "DPoP token HTTP URI mismatch: expected {}, got {}", 804 http_uri, 805 claims.get("htu").and_then(Value::as_str).unwrap_or("None") 806 ), 807 )); 808 } 809 810 // 9. Verify token timestamps (requirement #11) 811 let now = chrono::Utc::now().timestamp(); 812 813 // Check creation time (iat) 814 if let Some(iat) = claims.get("iat").and_then(Value::as_i64) { 815 // Token not too old (5 minute max age) 816 if iat < now.saturating_sub(300) { 817 return Err(Error::with_status( 818 StatusCode::BAD_REQUEST, 819 anyhow!("DPoP token too old"), 820 )); 821 } 822 823 // Token not in the future (with clock skew allowance) 824 if iat > now.saturating_add(60) { 825 return Err(Error::with_status( 826 StatusCode::BAD_REQUEST, 827 anyhow!("DPoP token creation time is in the future"), 828 )); 829 } 830 } 831 832 // Check expiration (exp) if present 833 let exp = claims 834 .get("exp") 835 .and_then(Value::as_i64) 836 .unwrap_or_else(|| now.saturating_add(60)); // Default 60s if not present 837 838 if now >= exp { 839 return Err(Error::with_status( 840 StatusCode::BAD_REQUEST, 841 anyhow!("DPoP token has expired"), 842 )); 843 } 844 845 // 10. Verify access token binding (requirement #12) 846 if let Some(token) = access_token { 847 // Verify ath claim matches token hash 848 if let Some(ath) = claims.get("ath").and_then(Value::as_str) { 849 let mut hasher = Sha256::new(); 850 hasher.update(token.as_bytes()); 851 let token_hash = 852 base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize()); 853 854 if ath != token_hash { 855 return Err(Error::with_status( 856 StatusCode::BAD_REQUEST, 857 anyhow!("DPoP access token hash mismatch"), 858 )); 859 } 860 } else { 861 return Err(Error::with_status( 862 StatusCode::BAD_REQUEST, 863 anyhow!("Missing ath claim for DPoP with access token"), 864 )); 865 } 866 867 // Verify key binding matches 868 if let Some(expected_thumbprint) = bound_key_thumbprint { 869 if thumbprint != expected_thumbprint { 870 return Err(Error::with_status( 871 StatusCode::BAD_REQUEST, 872 anyhow!("DPoP key doesn't match token-bound key"), 873 )); 874 } 875 } 876 } 877 878 // 11. Check for replay attacks via JTI tracking 879 use crate::schema::pds::oauth_used_jtis::dsl as JtiSchema; 880 let jti_clone = jti.to_owned(); 881 let jti_used = db 882 .get() 883 .await 884 .expect("Failed to get database connection") 885 .interact(move |conn| { 886 JtiSchema::oauth_used_jtis 887 .filter(JtiSchema::jti.eq(jti_clone)) 888 .count() 889 .get_result::<i64>(conn) 890 .optional() 891 }) 892 .await 893 .expect("Failed to check JTI") 894 .expect("Failed to check JTI") 895 .unwrap_or(0); 896 897 if jti_used > 0 { 898 return Err(Error::with_status( 899 StatusCode::BAD_REQUEST, 900 anyhow!("DPoP token has been replayed"), 901 )); 902 } 903 904 // 12. Store the JTI to prevent replay attacks 905 let jti_cloned = jti.to_owned(); 906 let issuer = thumbprint.to_owned(); 907 let created_at = now; 908 let expires_at = exp; 909 _ = db 910 .get() 911 .await 912 .expect("Failed to get database connection") 913 .interact(move |conn| { 914 insert_into(JtiSchema::oauth_used_jtis) 915 .values(( 916 JtiSchema::jti.eq(jti_cloned), 917 JtiSchema::issuer.eq(issuer), 918 JtiSchema::created_at.eq(created_at), 919 JtiSchema::expires_at.eq(expires_at), 920 )) 921 .execute(conn) 922 }) 923 .await 924 .expect("Failed to store JTI") 925 .expect("Failed to store JTI"); 926 927 // 13. Cleanup expired JTIs periodically (1% chance on each request) 928 if thread_rng().gen_range(0_i32..100_i32) == 0_i32 { 929 let now_clone = now.to_owned(); 930 _ = db 931 .get() 932 .await 933 .expect("Failed to get database connection") 934 .interact(move |conn| { 935 delete(JtiSchema::oauth_used_jtis) 936 .filter(JtiSchema::expires_at.lt(now_clone)) 937 .execute(conn) 938 }) 939 .await 940 .expect("Failed to clean up expired JTIs") 941 .expect("Failed to clean up expired JTIs"); 942 } 943 944 Ok(thumbprint) 945} 946 947/// Verify a `code_verifier` against stored `code_challenge` 948fn verify_pkce(code_verifier: &str, stored_challenge: &str, method: &str) -> Result<()> { 949 // Only S256 is supported currently 950 if method != "S256" { 951 return Err(Error::with_status( 952 StatusCode::BAD_REQUEST, 953 anyhow!("Unsupported code_challenge_method: {}", method), 954 )); 955 } 956 957 // Calculate the code challenge from verifier 958 let mut hasher = Sha256::new(); 959 hasher.update(code_verifier.as_bytes()); 960 let calculated = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize()); 961 962 // Compare with stored challenge 963 if calculated != stored_challenge { 964 return Err(Error::with_status( 965 StatusCode::BAD_REQUEST, 966 anyhow!("Code verifier doesn't match challenge"), 967 )); 968 } 969 970 Ok(()) 971} 972 973/// OAuth token endpoint 974/// - POST `/oauth/token` 975/// 976/// Handles both `authorization_code` and `refresh_token` grants 977#[expect(clippy::too_many_lines)] 978async fn token( 979 State(db): State<Pool>, 980 State(skey): State<SigningKey>, 981 State(config): State<AppConfig>, 982 State(client): State<Client>, 983 headers: HeaderMap, 984 Json(form_data): Json<HashMap<String, String>>, 985) -> Result<Json<Value>> { 986 // Extract form parameters 987 let grant_type = form_data 988 .get("grant_type") 989 .context("grant_type is required")?; 990 let client_id = form_data 991 .get("client_id") 992 .context("client_id is required")?; 993 994 // Validate DPoP header (Rule 1: Ensure DPoP is used) 995 let dpop_token = headers 996 .get("DPoP") 997 .context("DPoP header is required")? 998 .to_str() 999 .context("Invalid DPoP header")?; 1000 1001 // Get client metadata to determine client type (public vs confidential) 1002 let client_metadata = fetch_client_metadata(&client, client_id).await?; 1003 let is_confidential_client = client_metadata 1004 .get("token_endpoint_auth_method") 1005 .and_then(Value::as_str) 1006 .unwrap_or("none") 1007 == "private_key_jwt"; 1008 1009 // Verify DPoP proof 1010 let dpop_thumbprint_res = verify_dpop_proof( 1011 dpop_token, 1012 "POST", 1013 &format!("https://{}/oauth/token", config.host_name), 1014 &db, 1015 None, 1016 None, 1017 ) 1018 .await?; 1019 1020 if is_confidential_client { 1021 // Rule 3: Check client authentication consistency 1022 // For confidential clients, verify client_assertion 1023 let client_assertion_type = form_data 1024 .get("client_assertion_type") 1025 .context("client_assertion_type required for private_key_jwt auth")?; 1026 let _client_assertion = form_data 1027 .get("client_assertion") 1028 .context("client_assertion required for private_key_jwt auth")?; 1029 1030 if client_assertion_type != "urn:ietf:params:oauth:client-assertion-type:jwt-bearer" { 1031 return Err(Error::with_status( 1032 StatusCode::BAD_REQUEST, 1033 anyhow!("Invalid client_assertion_type"), 1034 )); 1035 } 1036 1037 // Verify client assertion JWT 1038 // This would involve similar logic to verify_dpop_proof but for client auth 1039 // 1040 // WIP: Practically unimplemented 1041 // 1042 // TODO: Figure out how this actually works 1043 1044 // verify_client_assertion(&client, client_id, client_assertion).await?; 1045 1046 // Rule 4: Ensure DPoP and client_assertion use different keys 1047 // let client_assertion_thumbprint = calculate_client_assertion_thumbprint(client_assertion)?; 1048 // if client_assertion_thumbprint == dpop_thumbprint { 1049 // return Err(Error::with_status( 1050 // StatusCode::BAD_REQUEST, 1051 // anyhow!("DPoP proof and client assertion must use different keypairs"), 1052 // )); 1053 // } 1054 } else { 1055 // Rule 2: For public clients, check if this DPoP key has been used before 1056 use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1057 let dpop_thumbprint_clone = dpop_thumbprint_res.to_owned(); 1058 let client_id_clone = client_id.to_owned(); 1059 let is_key_reused = db 1060 .get() 1061 .await 1062 .expect("Failed to get database connection") 1063 .interact(move |conn| { 1064 RefreshTokenSchema::oauth_refresh_tokens 1065 .filter(RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_clone)) 1066 .filter(RefreshTokenSchema::client_id.eq(client_id_clone)) 1067 .count() 1068 .get_result::<i64>(conn) 1069 .optional() 1070 }) 1071 .await 1072 .expect("Failed to check key usage history") 1073 .expect("Failed to check key usage history") 1074 .unwrap_or(0) 1075 > 0; 1076 1077 if is_key_reused && grant_type == "authorization_code" { 1078 return Err(Error::with_status( 1079 StatusCode::BAD_REQUEST, 1080 anyhow!("Public clients must use a new key for each token request"), 1081 )); 1082 } 1083 } 1084 1085 match grant_type.as_str() { 1086 "authorization_code" => { 1087 // Process authorization code grant 1088 let code = form_data.get("code").context("code is required")?; 1089 let code_verifier = form_data 1090 .get("code_verifier") 1091 .context("code_verifier is required")?; 1092 let redirect_uri = form_data 1093 .get("redirect_uri") 1094 .context("redirect_uri is required")?; 1095 1096 let timestamp = chrono::Utc::now().timestamp(); 1097 1098 // Retrieve and validate the authorization code 1099 use crate::schema::pds::oauth_authorization_codes::dsl as AuthCodeSchema; 1100 #[derive(Queryable, Selectable, Serialize)] 1101 #[diesel(table_name = crate::schema::pds::oauth_authorization_codes)] 1102 #[diesel(check_for_backend(sqlite::Sqlite))] 1103 struct AuthCode { 1104 code: String, 1105 client_id: String, 1106 subject: String, 1107 code_challenge: String, 1108 code_challenge_method: String, 1109 redirect_uri: String, 1110 scope: Option<String>, 1111 created_at: i64, 1112 expires_at: i64, 1113 used: bool, 1114 } 1115 let code_clone = code.to_owned(); 1116 let client_id_clone = client_id.to_owned(); 1117 let redirect_uri_clone = redirect_uri.to_owned(); 1118 let auth_code = db 1119 .get() 1120 .await 1121 .expect("Failed to get database connection") 1122 .interact(move |conn| { 1123 AuthCodeSchema::oauth_authorization_codes 1124 .filter(AuthCodeSchema::code.eq(code_clone)) 1125 .filter(AuthCodeSchema::client_id.eq(client_id_clone)) 1126 .filter(AuthCodeSchema::redirect_uri.eq(redirect_uri_clone)) 1127 .filter(AuthCodeSchema::expires_at.gt(timestamp)) 1128 .filter(AuthCodeSchema::used.eq(false)) 1129 .first::<AuthCode>(conn) 1130 .optional() 1131 }) 1132 .await 1133 .expect("Failed to query authorization code") 1134 .expect("Failed to query authorization code") 1135 .expect("Failed to query authorization code"); 1136 1137 // Verify PKCE code challenge 1138 verify_pkce( 1139 code_verifier, 1140 &auth_code.code_challenge, 1141 &auth_code.code_challenge_method, 1142 )?; 1143 1144 // Mark the code as used 1145 let code_cloned = code.to_owned(); 1146 _ = db 1147 .get() 1148 .await 1149 .expect("Failed to get database connection") 1150 .interact(move |conn| { 1151 update(AuthCodeSchema::oauth_authorization_codes) 1152 .filter(AuthCodeSchema::code.eq(code_cloned)) 1153 .set(AuthCodeSchema::used.eq(true)) 1154 .execute(conn) 1155 }) 1156 .await 1157 .expect("Failed to mark code as used") 1158 .expect("Failed to mark code as used"); 1159 1160 // Generate tokens with appropriate lifetimes 1161 let now = chrono::Utc::now().timestamp(); 1162 1163 // Rule 5: Access token valid for short period 1164 let access_token_expires_in = 3600_i64; // 1 hour (maximum allowed) 1165 let access_token_expires_at = now.saturating_add(access_token_expires_in); 1166 1167 // Rule 11 & 12: Different refresh token lifetimes based on client type 1168 let refresh_token_expires_at = if is_confidential_client { 1169 now.saturating_add(15_552_000_i64) // 6 months for confidential clients 1170 } else { 1171 now.saturating_add(604_800_i64) // 1 week maximum for public clients 1172 }; 1173 1174 // Rule 5: Include subject claim with user DID 1175 let access_token_claims = json!({ 1176 "iss": format!("https://{}", config.host_name), 1177 "sub": auth_code.subject, // User's DID 1178 "aud": client_id, 1179 "exp": access_token_expires_at, 1180 "iat": now, 1181 "cnf": { 1182 "jkt": dpop_thumbprint_res // Rule 1: Bind to DPoP key 1183 }, 1184 "scope": auth_code.scope 1185 }); 1186 1187 let access_token = crate::auth::sign(&skey, "at+jwt", &access_token_claims) 1188 .context("failed to sign access token")?; 1189 1190 // Create refresh token with similar binding 1191 let refresh_token_claims = json!({ 1192 "iss": format!("https://{}", config.host_name), 1193 "sub": auth_code.subject, 1194 "aud": client_id, 1195 "exp": refresh_token_expires_at, 1196 "iat": now, 1197 "cnf": { 1198 "jkt": dpop_thumbprint_res // Rule 1: Bind to DPoP key 1199 }, 1200 "scope": auth_code.scope 1201 }); 1202 1203 let refresh_token = crate::auth::sign(&skey, "rt+jwt", &refresh_token_claims) 1204 .context("failed to sign refresh token")?; 1205 1206 // Store the refresh token with DPoP binding 1207 use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1208 let refresh_token_cloned = refresh_token.to_owned(); 1209 let client_id_cloned = client_id.to_owned(); 1210 let subject = auth_code.subject.to_owned(); 1211 let dpop_thumbprint_cloned = dpop_thumbprint_res.to_owned(); 1212 let scope = auth_code.scope.to_owned(); 1213 let created_at = now; 1214 let expires_at = refresh_token_expires_at; 1215 _ = db 1216 .get() 1217 .await 1218 .expect("Failed to get database connection") 1219 .interact(move |conn| { 1220 insert_into(RefreshTokenSchema::oauth_refresh_tokens) 1221 .values(( 1222 RefreshTokenSchema::token.eq(refresh_token_cloned), 1223 RefreshTokenSchema::client_id.eq(client_id_cloned), 1224 RefreshTokenSchema::subject.eq(subject), 1225 RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_cloned), 1226 RefreshTokenSchema::scope.eq(scope), 1227 RefreshTokenSchema::created_at.eq(created_at), 1228 RefreshTokenSchema::expires_at.eq(expires_at), 1229 RefreshTokenSchema::revoked.eq(false), 1230 )) 1231 .execute(conn) 1232 }) 1233 .await 1234 .expect("Failed to store refresh token") 1235 .expect("Failed to store refresh token"); 1236 1237 // Return token response with the subject claim 1238 Ok(Json(json!({ 1239 "access_token": access_token, 1240 "token_type": "DPoP", 1241 "expires_in": access_token_expires_in, 1242 "refresh_token": refresh_token, 1243 "scope": auth_code.scope, 1244 "sub": auth_code.subject // Rule 5: Include subject claim 1245 }))) 1246 } 1247 "refresh_token" => { 1248 // Process refresh token grant 1249 let refresh_token = form_data 1250 .get("refresh_token") 1251 .context("refresh_token is required")?; 1252 1253 let timestamp = chrono::Utc::now().timestamp(); 1254 1255 // Rules 7 & 8: Verify refresh token and DPoP consistency 1256 // Retrieve the refresh token 1257 use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1258 #[derive(Queryable, Selectable, Serialize)] 1259 #[diesel(table_name = crate::schema::pds::oauth_refresh_tokens)] 1260 #[diesel(check_for_backend(sqlite::Sqlite))] 1261 struct TokenData { 1262 token: String, 1263 client_id: String, 1264 subject: String, 1265 dpop_thumbprint: String, 1266 scope: Option<String>, 1267 created_at: i64, 1268 expires_at: i64, 1269 revoked: bool, 1270 } 1271 let dpop_thumbprint_clone = dpop_thumbprint_res.to_owned(); 1272 let refresh_token_clone = refresh_token.to_owned(); 1273 let client_id_clone = client_id.to_owned(); 1274 let token_data = db 1275 .get() 1276 .await 1277 .expect("Failed to get database connection") 1278 .interact(move |conn| { 1279 RefreshTokenSchema::oauth_refresh_tokens 1280 .filter(RefreshTokenSchema::token.eq(refresh_token_clone)) 1281 .filter(RefreshTokenSchema::client_id.eq(client_id_clone)) 1282 .filter(RefreshTokenSchema::expires_at.gt(timestamp)) 1283 .filter(RefreshTokenSchema::revoked.eq(false)) 1284 .filter(RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_clone)) 1285 .first::<TokenData>(conn) 1286 .optional() 1287 }) 1288 .await 1289 .expect("Failed to query refresh token") 1290 .expect("Failed to query refresh token") 1291 .expect("Failed to query refresh token"); 1292 1293 // Rule 10: For confidential clients, verify key is still advertised in their jwks 1294 if is_confidential_client { 1295 let client_still_advertises_key = true; // Implement actual check against client jwks 1296 if !client_still_advertises_key { 1297 // Revoke all tokens bound to this key 1298 let client_id_cloned = client_id.to_owned(); 1299 let dpop_thumbprint_cloned = dpop_thumbprint_res.to_owned(); 1300 _ = db 1301 .get() 1302 .await 1303 .expect("Failed to get database connection") 1304 .interact(move |conn| { 1305 update(RefreshTokenSchema::oauth_refresh_tokens) 1306 .filter(RefreshTokenSchema::client_id.eq(client_id_cloned)) 1307 .filter( 1308 RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_cloned), 1309 ) 1310 .set(RefreshTokenSchema::revoked.eq(true)) 1311 .execute(conn) 1312 }) 1313 .await 1314 .expect("Failed to revoke tokens") 1315 .expect("Failed to revoke tokens"); 1316 1317 return Err(Error::with_status( 1318 StatusCode::BAD_REQUEST, 1319 anyhow!("Key no longer advertised by client"), 1320 )); 1321 } 1322 } 1323 1324 // Rotate the refresh token 1325 let refresh_token_cloned = refresh_token.to_owned(); 1326 _ = db 1327 .get() 1328 .await 1329 .expect("Failed to get database connection") 1330 .interact(move |conn| { 1331 update(RefreshTokenSchema::oauth_refresh_tokens) 1332 .filter(RefreshTokenSchema::token.eq(refresh_token_cloned)) 1333 .set(RefreshTokenSchema::revoked.eq(true)) 1334 .execute(conn) 1335 }) 1336 .await 1337 .expect("Failed to revoke old refresh token") 1338 .expect("Failed to revoke old refresh token"); 1339 1340 // Generate new tokens 1341 let now = chrono::Utc::now().timestamp(); 1342 let access_token_expires_in = 3600_i64; 1343 let access_token_expires_at = now.saturating_add(access_token_expires_in); 1344 1345 // Maintain the original expiry policy for refresh tokens 1346 let original_duration = token_data.expires_at.saturating_sub(token_data.created_at); 1347 let refresh_token_expires_at = now.saturating_add(original_duration); 1348 1349 // Create access token 1350 let access_token_claims = json!({ 1351 "iss": format!("https://{}", config.host_name), 1352 "sub": token_data.subject, 1353 "aud": client_id, 1354 "exp": access_token_expires_at, 1355 "iat": now, 1356 "cnf": { 1357 "jkt": dpop_thumbprint_res 1358 }, 1359 "scope": token_data.scope 1360 }); 1361 1362 let access_token = crate::auth::sign(&skey, "at+jwt", &access_token_claims) 1363 .context("failed to sign access token")?; 1364 1365 // Create new refresh token 1366 let new_refresh_token_claims = json!({ 1367 "iss": format!("https://{}", config.host_name), 1368 "sub": token_data.subject, 1369 "aud": client_id, 1370 "exp": refresh_token_expires_at, 1371 "iat": now, 1372 "cnf": { 1373 "jkt": dpop_thumbprint_res 1374 }, 1375 "scope": token_data.scope 1376 }); 1377 1378 let new_refresh_token = crate::auth::sign(&skey, "rt+jwt", &new_refresh_token_claims) 1379 .context("failed to sign refresh token")?; 1380 1381 // Store the new refresh token 1382 let new_refresh_token_cloned = new_refresh_token.to_owned(); 1383 let client_id_cloned = client_id.to_owned(); 1384 let subject = token_data.subject.to_owned(); 1385 let dpop_thumbprint_cloned = dpop_thumbprint_res.to_owned(); 1386 let scope = token_data.scope.to_owned(); 1387 let created_at = now; 1388 let expires_at = refresh_token_expires_at; 1389 _ = db 1390 .get() 1391 .await 1392 .expect("Failed to get database connection") 1393 .interact(move |conn| { 1394 insert_into(RefreshTokenSchema::oauth_refresh_tokens) 1395 .values(( 1396 RefreshTokenSchema::token.eq(new_refresh_token_cloned), 1397 RefreshTokenSchema::client_id.eq(client_id_cloned), 1398 RefreshTokenSchema::subject.eq(subject), 1399 RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_cloned), 1400 RefreshTokenSchema::scope.eq(scope), 1401 RefreshTokenSchema::created_at.eq(created_at), 1402 RefreshTokenSchema::expires_at.eq(expires_at), 1403 RefreshTokenSchema::revoked.eq(false), 1404 )) 1405 .execute(conn) 1406 }) 1407 .await 1408 .expect("Failed to store refresh token") 1409 .expect("Failed to store refresh token"); 1410 1411 // Return token response 1412 Ok(Json(json!({ 1413 "access_token": access_token, 1414 "token_type": "DPoP", 1415 "expires_in": access_token_expires_in, 1416 "refresh_token": new_refresh_token, 1417 "scope": token_data.scope, 1418 "sub": token_data.subject 1419 }))) 1420 } 1421 _ => Err(Error::with_status( 1422 StatusCode::BAD_REQUEST, 1423 anyhow!("unsupported grant_type: {}", grant_type), 1424 )), 1425 } 1426} 1427 1428/// JWKS (JSON Web Key Set) endpoint 1429/// - GET `/oauth/jwks` 1430/// 1431/// Returns the server's public keys as a JWKS document 1432/// 1433/// WIP: Practically unimplemented 1434/// 1435/// TODO: Figure out if/how this actually works 1436async fn jwks(State(skey): State<SigningKey>) -> Result<Json<Value>> { 1437 let did_string = skey.did(); 1438 1439 // Extract the public key data from the DID string 1440 let (_, public_key_bytes) = 1441 atrium_crypto::did::parse_did_key(&did_string).context("failed to parse did key")?; 1442 1443 // Secp256k1 uncompressed public keys should be 65 bytes: 0x04 + x(32 bytes) + y(32 bytes) 1444 if public_key_bytes.len() != 65 || public_key_bytes.first().copied() != Some(0x04) { 1445 return Err(Error::with_status( 1446 StatusCode::INTERNAL_SERVER_ERROR, 1447 anyhow!("unexpected public key format"), 1448 )); 1449 } 1450 1451 // Extract and encode the X and Y coordinates 1452 let x_coord = public_key_bytes 1453 .get(1..33) 1454 .map(|slice| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(slice)) 1455 .context("failed to extract X coordinate")?; 1456 1457 let y_coord = public_key_bytes 1458 .get(33..65) 1459 .map(|slice| base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(slice)) 1460 .context("failed to extract Y coordinate")?; 1461 1462 // Create a stable key ID based on the DID 1463 let key_id = did_string.strip_prefix("did:key:").unwrap_or(&did_string); 1464 1465 let jwk = json!({ 1466 "kty": "EC", 1467 "crv": "secp256k1", 1468 "kid": key_id, 1469 "use": "sig", 1470 "alg": "ES256K", 1471 "x": x_coord, 1472 "y": y_coord 1473 }); 1474 1475 // Return the JWKS document 1476 Ok(Json(json!({ 1477 "keys": [jwk] 1478 }))) 1479} 1480 1481/// Token Revocation endpoint 1482/// - POST `/oauth/revoke` 1483/// 1484/// Implements RFC7009 for revoking refresh tokens 1485async fn revoke( 1486 State(db): State<Pool>, 1487 Json(form_data): Json<HashMap<String, String>>, 1488) -> Result<Json<Value>> { 1489 // Extract required parameters 1490 let token = form_data.get("token").context("token is required")?; 1491 let refresh_token_string = "refresh_token".to_owned(); 1492 let token_type_hint = form_data 1493 .get("token_type_hint") 1494 .unwrap_or(&refresh_token_string); 1495 1496 // We only support revoking refresh tokens 1497 if token_type_hint != "refresh_token" { 1498 return Err(Error::with_status( 1499 StatusCode::BAD_REQUEST, 1500 anyhow!("unsupported token_type_hint: {}", token_type_hint), 1501 )); 1502 } 1503 1504 // Revoke the token 1505 use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1506 let token_cloned = token.to_owned(); 1507 _ = db 1508 .get() 1509 .await 1510 .expect("Failed to get database connection") 1511 .interact(move |conn| { 1512 update(RefreshTokenSchema::oauth_refresh_tokens) 1513 .filter(RefreshTokenSchema::token.eq(token_cloned)) 1514 .set(RefreshTokenSchema::revoked.eq(true)) 1515 .execute(conn) 1516 }) 1517 .await 1518 .expect("Failed to revoke token") 1519 .expect("Failed to revoke token"); 1520 1521 // RFC7009 requires a 200 OK with an empty response 1522 Ok(Json(json!({}))) 1523} 1524 1525/// Token Introspection endpoint 1526/// - POST `/oauth/introspect` 1527/// 1528/// Implements RFC7662 for introspecting tokens 1529async fn introspect( 1530 State(db): State<Pool>, 1531 State(skey): State<SigningKey>, 1532 Json(form_data): Json<HashMap<String, String>>, 1533) -> Result<Json<Value>> { 1534 // Extract required parameters 1535 let token = form_data.get("token").context("token is required")?; 1536 let token_type_hint = form_data.get("token_type_hint"); 1537 1538 // Parse the token 1539 let Ok((typ, claims)) = crate::auth::verify(&skey.did(), token) else { 1540 // Per RFC7662, invalid tokens return { "active": false } 1541 return Ok(Json(json!({"active": false}))); 1542 }; 1543 1544 // Check token type 1545 let is_refresh_token = typ == "rt+jwt"; 1546 let is_access_token = typ == "at+jwt"; 1547 1548 if !is_access_token && !is_refresh_token { 1549 return Ok(Json(json!({"active": false}))); 1550 } 1551 1552 // If token_type_hint is provided, verify it matches 1553 if let Some(hint) = token_type_hint { 1554 if (hint == "refresh_token" && !is_refresh_token) 1555 || (hint == "access_token" && !is_access_token) 1556 { 1557 return Ok(Json(json!({"active": false}))); 1558 } 1559 } 1560 1561 // Check expiration 1562 if let Some(exp) = claims.get("exp").and_then(Value::as_i64) { 1563 let now = chrono::Utc::now().timestamp(); 1564 if now >= exp { 1565 return Ok(Json(json!({"active": false}))); 1566 } 1567 } else { 1568 return Ok(Json(json!({"active": false}))); 1569 } 1570 1571 // For refresh tokens, check if it's been revoked 1572 if is_refresh_token { 1573 use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1574 let token_cloned = token.to_owned(); 1575 let is_revoked = db 1576 .get() 1577 .await 1578 .expect("Failed to get database connection") 1579 .interact(move |conn| { 1580 RefreshTokenSchema::oauth_refresh_tokens 1581 .filter(RefreshTokenSchema::token.eq(token_cloned)) 1582 .select(RefreshTokenSchema::revoked) 1583 .first::<bool>(conn) 1584 .optional() 1585 }) 1586 .await 1587 .expect("Failed to query token") 1588 .expect("Failed to query token") 1589 .unwrap_or(true); 1590 1591 if is_revoked { 1592 return Ok(Json(json!({"active": false}))); 1593 } 1594 } 1595 1596 // Token is valid, return introspection info 1597 let subject = claims.get("sub").and_then(Value::as_str); 1598 let client_id = claims.get("aud").and_then(Value::as_str); 1599 let scope = claims.get("scope").and_then(Value::as_str); 1600 let expiration = claims.get("exp").and_then(Value::as_i64); 1601 let issued_at = claims.get("iat").and_then(Value::as_i64); 1602 1603 Ok(Json(json!({ 1604 "active": true, 1605 "sub": subject, 1606 "client_id": client_id, 1607 "scope": scope, 1608 "exp": expiration, 1609 "iat": issued_at, 1610 "token_type": if is_access_token { "access_token" } else { "refresh_token" } 1611 }))) 1612} 1613 1614/// Register OAuth routes 1615pub(crate) fn routes() -> Router<AppState> { 1616 Router::new() 1617 .route( 1618 "/.well-known/oauth-protected-resource", 1619 get(protected_resource), 1620 ) 1621 .route( 1622 "/.well-known/oauth-authorization-server", 1623 get(authorization_server), 1624 ) 1625 .route("/oauth/par", post(par)) 1626 .route("/oauth/authorize", get(authorize)) 1627 .route("/oauth/authorize/sign-in", post(authorize_signin)) 1628 .route("/oauth/token", post(token)) 1629 .route("/oauth/jwks", get(jwks)) 1630 .route("/oauth/revoke", post(revoke)) 1631 .route("/oauth/introspect", post(introspect)) 1632}