Alternative ATProto PDS implementation

prototype oauth sqlx to diesel conversion

Changed files
+690 -234
src
+690 -234
src/oauth.rs
··· 1 1 //! OAuth endpoints 2 2 3 3 use crate::metrics::AUTH_FAILED; 4 - use crate::{AppConfig, AppState, Client, Db, Error, Result, SigningKey}; 4 + use crate::{AppConfig, AppState, Client, Error, Result, SigningKey}; 5 5 use anyhow::{Context as _, anyhow}; 6 6 use argon2::{Argon2, PasswordHash, PasswordVerifier as _}; 7 7 use atrium_crypto::keypair::Did as _; ··· 14 14 routing::{get, post}, 15 15 }; 16 16 use base64::Engine as _; 17 + use deadpool_diesel::sqlite::Pool; 18 + use diesel::*; 17 19 use metrics::counter; 18 20 use rand::distributions::Alphanumeric; 19 21 use rand::{Rng as _, thread_rng}; ··· 252 254 /// POST `/oauth/par` 253 255 #[expect(clippy::too_many_lines)] 254 256 async fn par( 255 - State(db): State<Db>, 257 + State(db): State<Pool>, 256 258 State(client): State<Client>, 257 259 Json(form_data): Json<HashMap<String, String>>, 258 260 ) -> Result<Json<Value>> { ··· 357 359 .context("failed to compute expiration time")? 358 360 .timestamp(); 359 361 360 - _ = sqlx::query!( 361 - r#" 362 - INSERT INTO oauth_par_requests ( 363 - request_uri, client_id, response_type, code_challenge, code_challenge_method, 364 - state, login_hint, scope, redirect_uri, response_mode, display, 365 - created_at, expires_at 366 - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 367 - "#, 368 - request_uri, 369 - client_id, 370 - response_type, 371 - code_challenge, 372 - code_challenge_method, 373 - state, 374 - login_hint, 375 - scope, 376 - redirect_uri, 377 - response_mode, 378 - display, 379 - created_at, 380 - expires_at 381 - ) 382 - .execute(&db) 383 - .await 384 - .context("failed to store PAR request")?; 362 + // _ = sqlx::query!( 363 + // r#" 364 + // INSERT INTO oauth_par_requests ( 365 + // request_uri, client_id, response_type, code_challenge, code_challenge_method, 366 + // state, login_hint, scope, redirect_uri, response_mode, display, 367 + // created_at, expires_at 368 + // ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 369 + // "#, 370 + // request_uri, 371 + // client_id, 372 + // response_type, 373 + // code_challenge, 374 + // code_challenge_method, 375 + // state, 376 + // login_hint, 377 + // scope, 378 + // redirect_uri, 379 + // response_mode, 380 + // display, 381 + // created_at, 382 + // expires_at 383 + // ) 384 + // .execute(&db) 385 + // .await 386 + // .context("failed to store PAR request")?; 387 + use crate::schema::pds::oauth_par_requests::dsl as ParRequestSchema; 388 + let client_id = client_id.to_owned(); 389 + let request_uri_cloned = request_uri.to_owned(); 390 + let response_type = response_type.to_owned(); 391 + let code_challenge = code_challenge.to_owned(); 392 + let code_challenge_method = code_challenge_method.to_owned(); 393 + let state = state.map(|s| s.to_owned()); 394 + let login_hint = login_hint.map(|s| s.to_owned()); 395 + let scope = scope.map(|s| s.to_owned()); 396 + let redirect_uri = redirect_uri.map(|s| s.to_owned()); 397 + let response_mode = response_mode.map(|s| s.to_owned()); 398 + let display = display.map(|s| s.to_owned()); 399 + let created_at = created_at; 400 + let expires_at = expires_at; 401 + _ = db 402 + .get() 403 + .await 404 + .expect("Failed to get database connection") 405 + .interact(move |conn| { 406 + insert_into(ParRequestSchema::oauth_par_requests) 407 + .values(( 408 + ParRequestSchema::request_uri.eq(&request_uri_cloned), 409 + ParRequestSchema::client_id.eq(client_id), 410 + ParRequestSchema::response_type.eq(response_type), 411 + ParRequestSchema::code_challenge.eq(code_challenge), 412 + ParRequestSchema::code_challenge_method.eq(code_challenge_method), 413 + ParRequestSchema::state.eq(state), 414 + ParRequestSchema::login_hint.eq(login_hint), 415 + ParRequestSchema::scope.eq(scope), 416 + ParRequestSchema::redirect_uri.eq(redirect_uri), 417 + ParRequestSchema::response_mode.eq(response_mode), 418 + ParRequestSchema::display.eq(display), 419 + ParRequestSchema::created_at.eq(created_at), 420 + ParRequestSchema::expires_at.eq(expires_at), 421 + )) 422 + .execute(conn) 423 + }) 424 + .await 425 + .expect("Failed to store PAR request") 426 + .expect("Failed to store PAR request"); 385 427 386 428 Ok(Json(json!({ 387 429 "request_uri": request_uri, ··· 392 434 /// OAuth Authorization endpoint 393 435 /// GET `/oauth/authorize` 394 436 async fn authorize( 395 - State(db): State<Db>, 437 + State(db): State<Pool>, 396 438 State(client): State<Client>, 397 439 Query(params): Query<HashMap<String, String>>, 398 440 ) -> Result<impl IntoResponse> { ··· 407 449 let timestamp = chrono::Utc::now().timestamp(); 408 450 409 451 // Retrieve the PAR request from the database 410 - let par_request = sqlx::query!( 411 - r#" 412 - SELECT * FROM oauth_par_requests 413 - WHERE request_uri = ? AND client_id = ? AND expires_at > ? 414 - "#, 415 - request_uri, 416 - client_id, 417 - timestamp 418 - ) 419 - .fetch_optional(&db) 420 - .await 421 - .context("failed to query PAR request")? 422 - .context("PAR request not found or expired")?; 452 + // let par_request = sqlx::query!( 453 + // r#" 454 + // SELECT * FROM oauth_par_requests 455 + // WHERE request_uri = ? AND client_id = ? AND expires_at > ? 456 + // "#, 457 + // request_uri, 458 + // client_id, 459 + // timestamp 460 + // ) 461 + // .fetch_optional(&db) 462 + // .await 463 + // .context("failed to query PAR request")? 464 + // .context("PAR request not found or expired")?; 465 + use crate::schema::pds::oauth_par_requests::dsl as ParRequestSchema; 466 + 467 + let request_uri_clone = request_uri.to_owned(); 468 + let client_id_clone = client_id.to_owned(); 469 + let timestamp_clone = timestamp.clone(); 470 + let login_hint = db 471 + .get() 472 + .await 473 + .expect("Failed to get database connection") 474 + .interact(move |conn| { 475 + ParRequestSchema::oauth_par_requests 476 + .select(ParRequestSchema::login_hint) 477 + .filter(ParRequestSchema::request_uri.eq(request_uri_clone)) 478 + .filter(ParRequestSchema::client_id.eq(client_id_clone)) 479 + .filter(ParRequestSchema::expires_at.gt(timestamp_clone)) 480 + .first::<Option<String>>(conn) 481 + .optional() 482 + }) 483 + .await 484 + .expect("Failed to query PAR request") 485 + .expect("Failed to query PAR request") 486 + .expect("Failed to query PAR request"); 423 487 424 488 // Validate client metadata 425 489 let client_metadata = fetch_client_metadata(&client, client_id).await?; 426 490 427 491 // Authorization page with login form 428 - let login_hint = par_request.login_hint.unwrap_or_default(); 492 + let login_hint = login_hint.unwrap_or_default(); 429 493 let html = format!( 430 494 r#"<!DOCTYPE html> 431 495 <html> ··· 491 555 /// POST `/oauth/authorize/sign-in` 492 556 #[expect(clippy::too_many_lines)] 493 557 async fn authorize_signin( 494 - State(db): State<Db>, 558 + State(db): State<Pool>, 495 559 State(config): State<AppConfig>, 496 560 State(client): State<Client>, 497 561 extract::Form(form_data): extract::Form<HashMap<String, String>>, ··· 511 575 let timestamp = chrono::Utc::now().timestamp(); 512 576 513 577 // Retrieve the PAR request 514 - let par_request = sqlx::query!( 515 - r#" 516 - SELECT * FROM oauth_par_requests 517 - WHERE request_uri = ? AND client_id = ? AND expires_at > ? 518 - "#, 519 - request_uri, 520 - client_id, 521 - timestamp 522 - ) 523 - .fetch_optional(&db) 524 - .await 525 - .context("failed to query PAR request")? 526 - .context("PAR request not found or expired")?; 578 + // let par_request = sqlx::query!( 579 + // r#" 580 + // SELECT * FROM oauth_par_requests 581 + // WHERE request_uri = ? AND client_id = ? AND expires_at > ? 582 + // "#, 583 + // request_uri, 584 + // client_id, 585 + // timestamp 586 + // ) 587 + // .fetch_optional(&db) 588 + // .await 589 + // .context("failed to query PAR request")? 590 + // .context("PAR request not found or expired")?; 591 + use crate::schema::pds::oauth_par_requests::dsl as ParRequestSchema; 592 + // diesel::table! { 593 + // pds.oauth_par_requests (request_uri) { 594 + // request_uri -> Varchar, 595 + // client_id -> Varchar, 596 + // response_type -> Varchar, 597 + // code_challenge -> Varchar, 598 + // code_challenge_method -> Varchar, 599 + // state -> Nullable<Varchar>, 600 + // login_hint -> Nullable<Varchar>, 601 + // scope -> Nullable<Varchar>, 602 + // redirect_uri -> Nullable<Varchar>, 603 + // response_mode -> Nullable<Varchar>, 604 + // display -> Nullable<Varchar>, 605 + // created_at -> Int8, 606 + // expires_at -> Int8, 607 + // } 608 + // } 609 + #[derive(Queryable, Selectable)] 610 + #[diesel(table_name = crate::schema::pds::oauth_par_requests)] 611 + #[diesel(check_for_backend(sqlite::Sqlite))] 612 + struct ParRequest { 613 + request_uri: String, 614 + client_id: String, 615 + response_type: String, 616 + code_challenge: String, 617 + code_challenge_method: String, 618 + state: Option<String>, 619 + login_hint: Option<String>, 620 + scope: Option<String>, 621 + redirect_uri: Option<String>, 622 + response_mode: Option<String>, 623 + display: Option<String>, 624 + created_at: i64, 625 + expires_at: i64, 626 + } 627 + let request_uri_clone = request_uri.to_owned(); 628 + let client_id_clone = client_id.to_owned(); 629 + let timestamp_clone = timestamp.clone(); 630 + let par_request = db 631 + .get() 632 + .await 633 + .expect("Failed to get database connection") 634 + .interact(move |conn| { 635 + ParRequestSchema::oauth_par_requests 636 + .filter(ParRequestSchema::request_uri.eq(request_uri_clone)) 637 + .filter(ParRequestSchema::client_id.eq(client_id_clone)) 638 + .filter(ParRequestSchema::expires_at.gt(timestamp_clone)) 639 + .first::<ParRequest>(conn) 640 + .optional() 641 + }) 642 + .await 643 + .expect("Failed to query PAR request") 644 + .expect("Failed to query PAR request") 645 + .expect("Failed to query PAR request"); 527 646 528 647 // Authenticate the user 529 - let account = sqlx::query!( 530 - r#" 531 - WITH LatestHandles AS ( 532 - SELECT did, handle 533 - FROM handles 534 - WHERE (did, created_at) IN ( 535 - SELECT did, MAX(created_at) AS max_created_at 536 - FROM handles 537 - GROUP BY did 538 - ) 539 - ) 540 - SELECT a.did, a.email, a.password, h.handle 541 - FROM accounts a 542 - LEFT JOIN LatestHandles h ON a.did = h.did 543 - WHERE h.handle = ? 544 - "#, 545 - username 546 - ) 547 - .fetch_optional(&db) 548 - .await 549 - .context("failed to query database")? 550 - .context("user not found")?; 648 + use rsky_pds::schema::pds::account::dsl as AccountSchema; 649 + use rsky_pds::schema::pds::actor::dsl as ActorSchema; 650 + let username_clone = username.to_owned(); 651 + let account = db 652 + .get() 653 + .await 654 + .expect("Failed to get database connection") 655 + .interact(move |conn| { 656 + AccountSchema::account 657 + .filter(AccountSchema::email.eq(username_clone)) 658 + .first::<rsky_pds::models::Account>(conn) 659 + .optional() 660 + }) 661 + .await 662 + .expect("Failed to query account") 663 + .expect("Failed to query account") 664 + .expect("Failed to query account"); 665 + // let actor = db 666 + // .get() 667 + // .await 668 + // .expect("Failed to get database connection") 669 + // .interact(move |conn| { 670 + // ActorSchema::actor 671 + // .filter(ActorSchema::did.eq(did)) 672 + // .first::<rsky_pds::models::Actor>(conn) 673 + // .optional() 674 + // }) 675 + // .await 676 + // .expect("Failed to query actor") 677 + // .expect("Failed to query actor") 678 + // .expect("Failed to query actor"); 551 679 552 680 // Verify password - fixed to use equality check instead of pattern matching 553 681 if Argon2::default().verify_password( ··· 592 720 .context("failed to compute expiration time")? 593 721 .timestamp(); 594 722 595 - _ = sqlx::query!( 596 - r#" 597 - INSERT INTO oauth_authorization_codes ( 598 - code, client_id, subject, code_challenge, code_challenge_method, 599 - redirect_uri, scope, created_at, expires_at, used 600 - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 601 - "#, 602 - code, 603 - client_id, 604 - account.did, 605 - par_request.code_challenge, 606 - par_request.code_challenge_method, 607 - redirect_uri, 608 - par_request.scope, 609 - created_at, 610 - expires_at, 611 - false 612 - ) 613 - .execute(&db) 614 - .await 615 - .context("failed to store authorization code")?; 723 + // _ = sqlx::query!( 724 + // r#" 725 + // INSERT INTO oauth_authorization_codes ( 726 + // code, client_id, subject, code_challenge, code_challenge_method, 727 + // redirect_uri, scope, created_at, expires_at, used 728 + // ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 729 + // "#, 730 + // code, 731 + // client_id, 732 + // account.did, 733 + // par_request.code_challenge, 734 + // par_request.code_challenge_method, 735 + // redirect_uri, 736 + // par_request.scope, 737 + // created_at, 738 + // expires_at, 739 + // false 740 + // ) 741 + // .execute(&db) 742 + // .await 743 + // .context("failed to store authorization code")?; 744 + use crate::schema::pds::oauth_authorization_codes::dsl as AuthCodeSchema; 745 + let code_cloned = code.to_owned(); 746 + let client_id = client_id.to_owned(); 747 + let subject = account.did.to_owned(); 748 + let code_challenge = par_request.code_challenge.to_owned(); 749 + let code_challenge_method = par_request.code_challenge_method.to_owned(); 750 + let redirect_uri_cloned = redirect_uri.to_owned(); 751 + let scope = par_request.scope.to_owned(); 752 + let used = false; 753 + _ = db 754 + .get() 755 + .await 756 + .expect("Failed to get database connection") 757 + .interact(move |conn| { 758 + insert_into(AuthCodeSchema::oauth_authorization_codes) 759 + .values(( 760 + AuthCodeSchema::code.eq(code_cloned), 761 + AuthCodeSchema::client_id.eq(client_id), 762 + AuthCodeSchema::subject.eq(subject), 763 + AuthCodeSchema::code_challenge.eq(code_challenge), 764 + AuthCodeSchema::code_challenge_method.eq(code_challenge_method), 765 + AuthCodeSchema::redirect_uri.eq(redirect_uri_cloned), 766 + AuthCodeSchema::scope.eq(scope), 767 + AuthCodeSchema::created_at.eq(created_at), 768 + AuthCodeSchema::expires_at.eq(expires_at), 769 + AuthCodeSchema::used.eq(used), 770 + )) 771 + .execute(conn) 772 + }) 773 + .await 774 + .expect("Failed to store authorization code") 775 + .expect("Failed to store authorization code"); 616 776 617 777 // Use state from the PAR request or generate one 618 778 let state = par_request.state.unwrap_or_else(|| { ··· 673 833 dpop_token: &str, 674 834 http_method: &str, 675 835 http_uri: &str, 676 - db: &Db, 836 + db: &Pool, 677 837 access_token: Option<&str>, 678 838 bound_key_thumbprint: Option<&str>, 679 839 ) -> Result<String> { ··· 811 971 } 812 972 813 973 // 11. Check for replay attacks via JTI tracking 814 - let jti_used = 815 - sqlx::query_scalar!(r#"SELECT COUNT(*) FROM oauth_used_jtis WHERE jti = ?"#, jti) 816 - .fetch_one(db) 817 - .await 818 - .context("failed to check JTI")?; 974 + // let jti_used = 975 + // sqlx::query_scalar!(r#"SELECT COUNT(*) FROM oauth_used_jtis WHERE jti = ?"#, jti) 976 + // .fetch_one(db) 977 + // .await 978 + // .context("failed to check JTI")?; 979 + use crate::schema::pds::oauth_used_jtis::dsl as JtiSchema; 980 + let jti_clone = jti.to_owned(); 981 + let jti_used = db 982 + .get() 983 + .await 984 + .expect("Failed to get database connection") 985 + .interact(move |conn| { 986 + JtiSchema::oauth_used_jtis 987 + .filter(JtiSchema::jti.eq(jti_clone)) 988 + .count() 989 + .get_result::<i64>(conn) 990 + .optional() 991 + }) 992 + .await 993 + .expect("Failed to check JTI") 994 + .expect("Failed to check JTI") 995 + .unwrap_or(0); 819 996 820 997 if jti_used > 0 { 821 998 return Err(Error::with_status( ··· 825 1002 } 826 1003 827 1004 // 12. Store the JTI to prevent replay attacks 828 - _ = sqlx::query!( 829 - r#" 830 - INSERT INTO oauth_used_jtis (jti, issuer, created_at, expires_at) 831 - VALUES (?, ?, ?, ?) 832 - "#, 833 - jti, 834 - thumbprint, // Use thumbprint as issuer identifier 835 - now, 836 - exp 837 - ) 838 - .execute(db) 839 - .await 840 - .context("failed to store JTI")?; 1005 + // _ = sqlx::query!( 1006 + // r#" 1007 + // INSERT INTO oauth_used_jtis (jti, issuer, created_at, expires_at) 1008 + // VALUES (?, ?, ?, ?) 1009 + // "#, 1010 + // jti, 1011 + // thumbprint, // Use thumbprint as issuer identifier 1012 + // now, 1013 + // exp 1014 + // ) 1015 + // .execute(db) 1016 + // .await 1017 + // .context("failed to store JTI")?; 1018 + let jti_cloned = jti.to_owned(); 1019 + let issuer = thumbprint.to_owned(); 1020 + let created_at = now; 1021 + let expires_at = exp; 1022 + _ = db 1023 + .get() 1024 + .await 1025 + .expect("Failed to get database connection") 1026 + .interact(move |conn| { 1027 + insert_into(JtiSchema::oauth_used_jtis) 1028 + .values(( 1029 + JtiSchema::jti.eq(jti_cloned), 1030 + JtiSchema::issuer.eq(issuer), 1031 + JtiSchema::created_at.eq(created_at), 1032 + JtiSchema::expires_at.eq(expires_at), 1033 + )) 1034 + .execute(conn) 1035 + }) 1036 + .await 1037 + .expect("Failed to store JTI") 1038 + .expect("Failed to store JTI"); 841 1039 842 1040 // 13. Cleanup expired JTIs periodically (1% chance on each request) 843 1041 if thread_rng().gen_range(0_i32..100_i32) == 0_i32 { 844 - _ = sqlx::query!(r#"DELETE FROM oauth_used_jtis WHERE expires_at < ?"#, now) 845 - .execute(db) 1042 + // _ = sqlx::query!(r#"DELETE FROM oauth_used_jtis WHERE expires_at < ?"#, now) 1043 + // .execute(db) 1044 + // .await 1045 + // .context("failed to clean up expired JTIs")?; 1046 + let now_clone = now.to_owned(); 1047 + _ = db 1048 + .get() 846 1049 .await 847 - .context("failed to clean up expired JTIs")?; 1050 + .expect("Failed to get database connection") 1051 + .interact(move |conn| { 1052 + delete(JtiSchema::oauth_used_jtis) 1053 + .filter(JtiSchema::expires_at.lt(now_clone)) 1054 + .execute(conn) 1055 + }) 1056 + .await 1057 + .expect("Failed to clean up expired JTIs") 1058 + .expect("Failed to clean up expired JTIs"); 848 1059 } 849 1060 850 1061 Ok(thumbprint) ··· 882 1093 /// Handles both `authorization_code` and `refresh_token` grants 883 1094 #[expect(clippy::too_many_lines)] 884 1095 async fn token( 885 - State(db): State<Db>, 1096 + State(db): State<Pool>, 886 1097 State(skey): State<SigningKey>, 887 1098 State(config): State<AppConfig>, 888 1099 State(client): State<Client>, ··· 959 1170 // } 960 1171 } else { 961 1172 // Rule 2: For public clients, check if this DPoP key has been used before 962 - let is_key_reused = sqlx::query_scalar!( 963 - r#"SELECT COUNT(*) FROM oauth_refresh_tokens WHERE dpop_thumbprint = ? AND client_id = ?"#, 964 - dpop_thumbprint, 965 - client_id 966 - ) 967 - .fetch_one(&db) 968 - .await 969 - .context("failed to check key usage history")? > 0; 1173 + // let is_key_reused = sqlx::query_scalar!( 1174 + // r#"SELECT COUNT(*) FROM oauth_refresh_tokens WHERE dpop_thumbprint = ? AND client_id = ?"#, 1175 + // dpop_thumbprint, 1176 + // client_id 1177 + // ) 1178 + // .fetch_one(&db) 1179 + // .await 1180 + // .context("failed to check key usage history")? > 0; 1181 + use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1182 + let is_key_reused = db 1183 + .get() 1184 + .await 1185 + .expect("Failed to get database connection") 1186 + .interact(move |conn| { 1187 + RefreshTokenSchema::oauth_refresh_tokens 1188 + .filter(RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint)) 1189 + .filter(RefreshTokenSchema::client_id.eq(client_id)) 1190 + .count() 1191 + .get_result::<i64>(conn) 1192 + .optional() 1193 + }) 1194 + .await 1195 + .expect("Failed to check key usage history") 1196 + .expect("Failed to check key usage history") 1197 + .unwrap_or(0) 1198 + > 0; 970 1199 971 1200 if is_key_reused && grant_type == "authorization_code" { 972 1201 return Err(Error::with_status( ··· 990 1219 let timestamp = chrono::Utc::now().timestamp(); 991 1220 992 1221 // Retrieve and validate the authorization code 993 - let auth_code = sqlx::query!( 994 - r#" 995 - SELECT * FROM oauth_authorization_codes 996 - WHERE code = ? AND client_id = ? AND redirect_uri = ? AND expires_at > ? AND used = FALSE 997 - "#, 998 - code, 999 - client_id, 1000 - redirect_uri, 1001 - timestamp 1002 - ) 1003 - .fetch_optional(&db) 1004 - .await 1005 - .context("failed to query authorization code")? 1006 - .context("authorization code not found, expired, or already used")?; 1222 + // let auth_code = sqlx::query!( 1223 + // r#" 1224 + // SELECT * FROM oauth_authorization_codes 1225 + // WHERE code = ? AND client_id = ? AND redirect_uri = ? AND expires_at > ? AND used = FALSE 1226 + // "#, 1227 + // code, 1228 + // client_id, 1229 + // redirect_uri, 1230 + // timestamp 1231 + // ) 1232 + // .fetch_optional(&db) 1233 + // .await 1234 + // .context("failed to query authorization code")? 1235 + // .context("authorization code not found, expired, or already used")?; 1236 + use crate::schema::pds::oauth_authorization_codes::dsl as AuthCodeSchema; 1237 + // diesel::table! { 1238 + // pds.oauth_authorization_codes (code) { 1239 + // code -> Varchar, 1240 + // client_id -> Varchar, 1241 + // subject -> Varchar, 1242 + // code_challenge -> Varchar, 1243 + // code_challenge_method -> Varchar, 1244 + // redirect_uri -> Varchar, 1245 + // scope -> Nullable<Varchar>, 1246 + // created_at -> Int8, 1247 + // expires_at -> Int8, 1248 + // used -> Bool, 1249 + // } 1250 + // } 1251 + #[derive(Queryable, Selectable)] 1252 + #[diesel(table_name = crate::schema::pds::oauth_authorization_codes)] 1253 + #[diesel(check_for_backend(sqlite::Sqlite))] 1254 + struct AuthCode { 1255 + code: String, 1256 + client_id: String, 1257 + subject: String, 1258 + code_challenge: String, 1259 + code_challenge_method: String, 1260 + redirect_uri: String, 1261 + scope: Option<String>, 1262 + created_at: i64, 1263 + expires_at: i64, 1264 + used: bool, 1265 + } 1266 + let auth_code = db 1267 + .get() 1268 + .await 1269 + .expect("Failed to get database connection") 1270 + .interact(move |conn| { 1271 + AuthCodeSchema::oauth_authorization_codes 1272 + .filter(AuthCodeSchema::code.eq(code)) 1273 + .filter(AuthCodeSchema::client_id.eq(client_id)) 1274 + .filter(AuthCodeSchema::redirect_uri.eq(redirect_uri)) 1275 + .filter(AuthCodeSchema::expires_at.gt(timestamp)) 1276 + .filter(AuthCodeSchema::used.eq(false)) 1277 + .first::<AuthCode>(conn) 1278 + .optional() 1279 + }) 1280 + .await 1281 + .expect("Failed to query authorization code") 1282 + .expect("Failed to query authorization code") 1283 + .expect("Failed to query authorization code"); 1007 1284 1008 1285 // Verify PKCE code challenge 1009 1286 verify_pkce( ··· 1013 1290 )?; 1014 1291 1015 1292 // Mark the code as used 1016 - _ = sqlx::query!( 1017 - r#"UPDATE oauth_authorization_codes SET used = TRUE WHERE code = ?"#, 1018 - code 1019 - ) 1020 - .execute(&db) 1021 - .await 1022 - .context("failed to mark code as used")?; 1293 + // _ = sqlx::query!( 1294 + // r#"UPDATE oauth_authorization_codes SET used = TRUE WHERE code = ?"#, 1295 + // code 1296 + // ) 1297 + // .execute(&db) 1298 + // .await 1299 + // .context("failed to mark code as used")?; 1300 + let code_cloned = code.to_owned(); 1301 + _ = db 1302 + .get() 1303 + .await 1304 + .expect("Failed to get database connection") 1305 + .interact(move |conn| { 1306 + update(AuthCodeSchema::oauth_authorization_codes) 1307 + .filter(AuthCodeSchema::code.eq(code_cloned)) 1308 + .set(AuthCodeSchema::used.eq(true)) 1309 + .execute(conn) 1310 + }) 1311 + .await 1312 + .expect("Failed to mark code as used") 1313 + .expect("Failed to mark code as used"); 1023 1314 1024 1315 // Generate tokens with appropriate lifetimes 1025 1316 let now = chrono::Utc::now().timestamp(); ··· 1068 1359 .context("failed to sign refresh token")?; 1069 1360 1070 1361 // Store the refresh token with DPoP binding 1071 - _ = sqlx::query!( 1072 - r#" 1073 - INSERT INTO oauth_refresh_tokens ( 1074 - token, client_id, subject, dpop_thumbprint, scope, created_at, expires_at, revoked 1075 - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 1076 - "#, 1077 - refresh_token, 1078 - client_id, 1079 - auth_code.subject, 1080 - dpop_thumbprint, 1081 - auth_code.scope, 1082 - now, 1083 - refresh_token_expires_at, 1084 - false 1085 - ) 1086 - .execute(&db) 1087 - .await 1088 - .context("failed to store refresh token")?; 1362 + // _ = sqlx::query!( 1363 + // r#" 1364 + // INSERT INTO oauth_refresh_tokens ( 1365 + // token, client_id, subject, dpop_thumbprint, scope, created_at, expires_at, revoked 1366 + // ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 1367 + // "#, 1368 + // refresh_token, 1369 + // client_id, 1370 + // auth_code.subject, 1371 + // dpop_thumbprint, 1372 + // auth_code.scope, 1373 + // now, 1374 + // refresh_token_expires_at, 1375 + // false 1376 + // ) 1377 + // .execute(&db) 1378 + // .await 1379 + // .context("failed to store refresh token")?; 1380 + use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1381 + let refresh_token_cloned = refresh_token.to_owned(); 1382 + let client_id_cloned = client_id.to_owned(); 1383 + let subject = auth_code.subject.to_owned(); 1384 + let dpop_thumbprint_cloned = dpop_thumbprint.to_owned(); 1385 + let scope = auth_code.scope.to_owned(); 1386 + let created_at = now; 1387 + let expires_at = refresh_token_expires_at; 1388 + _ = db 1389 + .get() 1390 + .await 1391 + .expect("Failed to get database connection") 1392 + .interact(move |conn| { 1393 + insert_into(RefreshTokenSchema::oauth_refresh_tokens) 1394 + .values(( 1395 + RefreshTokenSchema::token.eq(refresh_token_cloned), 1396 + RefreshTokenSchema::client_id.eq(client_id_cloned), 1397 + RefreshTokenSchema::subject.eq(subject), 1398 + RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_cloned), 1399 + RefreshTokenSchema::scope.eq(scope), 1400 + RefreshTokenSchema::created_at.eq(created_at), 1401 + RefreshTokenSchema::expires_at.eq(expires_at), 1402 + RefreshTokenSchema::revoked.eq(false), 1403 + )) 1404 + .execute(conn) 1405 + }) 1406 + .await 1407 + .expect("Failed to store refresh token") 1408 + .expect("Failed to store refresh token"); 1089 1409 1090 1410 // Return token response with the subject claim 1091 1411 Ok(Json(json!({ ··· 1107 1427 1108 1428 // Rules 7 & 8: Verify refresh token and DPoP consistency 1109 1429 // Retrieve the refresh token 1110 - let token_data = sqlx::query!( 1111 - r#" 1112 - SELECT * FROM oauth_refresh_tokens 1113 - WHERE token = ? AND client_id = ? AND expires_at > ? AND revoked = FALSE AND dpop_thumbprint = ? 1114 - "#, 1115 - refresh_token, 1116 - client_id, 1117 - timestamp, 1118 - dpop_thumbprint // Rule 8: Must use same DPoP key 1119 - ) 1120 - .fetch_optional(&db) 1121 - .await 1122 - .context("failed to query refresh token")? 1123 - .context("refresh token not found, expired, revoked, or invalid for this DPoP key")?; 1430 + // let token_data = sqlx::query!( 1431 + // r#" 1432 + // SELECT * FROM oauth_refresh_tokens 1433 + // WHERE token = ? AND client_id = ? AND expires_at > ? AND revoked = FALSE AND dpop_thumbprint = ? 1434 + // "#, 1435 + // refresh_token, 1436 + // client_id, 1437 + // timestamp, 1438 + // dpop_thumbprint // Rule 8: Must use same DPoP key 1439 + // ) 1440 + // .fetch_optional(&db) 1441 + // .await 1442 + // .context("failed to query refresh token")? 1443 + // .context("refresh token not found, expired, revoked, or invalid for this DPoP key")?; 1444 + use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1445 + // diesel::table! { 1446 + // pds.oauth_refresh_tokens (token) { 1447 + // token -> Varchar, 1448 + // client_id -> Varchar, 1449 + // subject -> Varchar, 1450 + // dpop_thumbprint -> Varchar, 1451 + // scope -> Nullable<Varchar>, 1452 + // created_at -> Int8, 1453 + // expires_at -> Int8, 1454 + // revoked -> Bool, 1455 + // } 1456 + // } 1457 + #[derive(Queryable, Selectable)] 1458 + #[diesel(table_name = crate::schema::pds::oauth_refresh_tokens)] 1459 + #[diesel(check_for_backend(sqlite::Sqlite))] 1460 + struct TokenData { 1461 + token: String, 1462 + client_id: String, 1463 + subject: String, 1464 + dpop_thumbprint: String, 1465 + scope: Option<String>, 1466 + created_at: i64, 1467 + expires_at: i64, 1468 + revoked: bool, 1469 + } 1470 + let token_data = db 1471 + .get() 1472 + .await 1473 + .expect("Failed to get database connection") 1474 + .interact(move |conn| { 1475 + RefreshTokenSchema::oauth_refresh_tokens 1476 + .filter(RefreshTokenSchema::token.eq(refresh_token)) 1477 + .filter(RefreshTokenSchema::client_id.eq(client_id)) 1478 + .filter(RefreshTokenSchema::expires_at.gt(timestamp)) 1479 + .filter(RefreshTokenSchema::revoked.eq(false)) 1480 + .filter(RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint)) 1481 + .first::<TokenData>(conn) 1482 + .optional() 1483 + }) 1484 + .await 1485 + .expect("Failed to query refresh token") 1486 + .expect("Failed to query refresh token") 1487 + .expect("Failed to query refresh token"); 1124 1488 1125 1489 // Rule 10: For confidential clients, verify key is still advertised in their jwks 1126 1490 if is_confidential_client { 1127 1491 let client_still_advertises_key = true; // Implement actual check against client jwks 1128 1492 if !client_still_advertises_key { 1129 1493 // Revoke all tokens bound to this key 1130 - _ = sqlx::query!( 1131 - r#"UPDATE oauth_refresh_tokens SET revoked = TRUE 1132 - WHERE client_id = ? AND dpop_thumbprint = ?"#, 1133 - client_id, 1134 - dpop_thumbprint 1135 - ) 1136 - .execute(&db) 1137 - .await 1138 - .context("failed to revoke tokens")?; 1494 + // _ = sqlx::query!( 1495 + // r#"UPDATE oauth_refresh_tokens SET revoked = TRUE 1496 + // WHERE client_id = ? AND dpop_thumbprint = ?"#, 1497 + // client_id, 1498 + // dpop_thumbprint 1499 + // ) 1500 + // .execute(&db) 1501 + // .await 1502 + // .context("failed to revoke tokens")?; 1503 + let client_id_cloned = client_id.to_owned(); 1504 + let dpop_thumbprint_cloned = dpop_thumbprint.to_owned(); 1505 + _ = db 1506 + .get() 1507 + .await 1508 + .expect("Failed to get database connection") 1509 + .interact(move |conn| { 1510 + update(RefreshTokenSchema::oauth_refresh_tokens) 1511 + .filter(RefreshTokenSchema::client_id.eq(client_id_cloned)) 1512 + .filter( 1513 + RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_cloned), 1514 + ) 1515 + .set(RefreshTokenSchema::revoked.eq(true)) 1516 + .execute(conn) 1517 + }) 1518 + .await 1519 + .expect("Failed to revoke tokens") 1520 + .expect("Failed to revoke tokens"); 1139 1521 1140 1522 return Err(Error::with_status( 1141 1523 StatusCode::BAD_REQUEST, ··· 1145 1527 } 1146 1528 1147 1529 // Rotate the refresh token 1148 - _ = sqlx::query!( 1149 - r#"UPDATE oauth_refresh_tokens SET revoked = TRUE WHERE token = ?"#, 1150 - refresh_token 1151 - ) 1152 - .execute(&db) 1153 - .await 1154 - .context("failed to revoke old refresh token")?; 1530 + // _ = sqlx::query!( 1531 + // r#"UPDATE oauth_refresh_tokens SET revoked = TRUE WHERE token = ?"#, 1532 + // refresh_token 1533 + // ) 1534 + // .execute(&db) 1535 + // .await 1536 + // .context("failed to revoke old refresh token")?; 1537 + let refresh_token_cloned = refresh_token.to_owned(); 1538 + _ = db 1539 + .get() 1540 + .await 1541 + .expect("Failed to get database connection") 1542 + .interact(move |conn| { 1543 + update(RefreshTokenSchema::oauth_refresh_tokens) 1544 + .filter(RefreshTokenSchema::token.eq(refresh_token_cloned)) 1545 + .set(RefreshTokenSchema::revoked.eq(true)) 1546 + .execute(conn) 1547 + }) 1548 + .await 1549 + .expect("Failed to revoke old refresh token") 1550 + .expect("Failed to revoke old refresh token"); 1155 1551 1156 1552 // Generate new tokens 1157 1553 let now = chrono::Utc::now().timestamp(); ··· 1195 1591 .context("failed to sign refresh token")?; 1196 1592 1197 1593 // Store the new refresh token 1198 - _ = sqlx::query!( 1199 - r#" 1200 - INSERT INTO oauth_refresh_tokens ( 1201 - token, client_id, subject, dpop_thumbprint, scope, created_at, expires_at, revoked 1202 - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 1203 - "#, 1204 - new_refresh_token, 1205 - client_id, 1206 - token_data.subject, 1207 - dpop_thumbprint, 1208 - token_data.scope, 1209 - now, 1210 - refresh_token_expires_at, 1211 - false 1212 - ) 1213 - .execute(&db) 1214 - .await 1215 - .context("failed to store refresh token")?; 1594 + // _ = sqlx::query!( 1595 + // r#" 1596 + // INSERT INTO oauth_refresh_tokens ( 1597 + // token, client_id, subject, dpop_thumbprint, scope, created_at, expires_at, revoked 1598 + // ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 1599 + // "#, 1600 + // new_refresh_token, 1601 + // client_id, 1602 + // token_data.subject, 1603 + // dpop_thumbprint, 1604 + // token_data.scope, 1605 + // now, 1606 + // refresh_token_expires_at, 1607 + // false 1608 + // ) 1609 + // .execute(&db) 1610 + // .await 1611 + // .context("failed to store refresh token")?; 1612 + let new_refresh_token_cloned = new_refresh_token.to_owned(); 1613 + let client_id_cloned = client_id.to_owned(); 1614 + let subject = token_data.subject.to_owned(); 1615 + let dpop_thumbprint_cloned = dpop_thumbprint.to_owned(); 1616 + let scope = token_data.scope.to_owned(); 1617 + let created_at = now; 1618 + let expires_at = refresh_token_expires_at; 1619 + _ = db 1620 + .get() 1621 + .await 1622 + .expect("Failed to get database connection") 1623 + .interact(move |conn| { 1624 + insert_into(RefreshTokenSchema::oauth_refresh_tokens) 1625 + .values(( 1626 + RefreshTokenSchema::token.eq(new_refresh_token_cloned), 1627 + RefreshTokenSchema::client_id.eq(client_id_cloned), 1628 + RefreshTokenSchema::subject.eq(subject), 1629 + RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_cloned), 1630 + RefreshTokenSchema::scope.eq(scope), 1631 + RefreshTokenSchema::created_at.eq(created_at), 1632 + RefreshTokenSchema::expires_at.eq(expires_at), 1633 + RefreshTokenSchema::revoked.eq(false), 1634 + )) 1635 + .execute(conn) 1636 + }) 1637 + .await 1638 + .expect("Failed to store refresh token") 1639 + .expect("Failed to store refresh token"); 1216 1640 1217 1641 // Return token response 1218 1642 Ok(Json(json!({ ··· 1289 1713 /// 1290 1714 /// Implements RFC7009 for revoking refresh tokens 1291 1715 async fn revoke( 1292 - State(db): State<Db>, 1716 + State(db): State<Pool>, 1293 1717 Json(form_data): Json<HashMap<String, String>>, 1294 1718 ) -> Result<Json<Value>> { 1295 1719 // Extract required parameters ··· 1308 1732 } 1309 1733 1310 1734 // Revoke the token 1311 - _ = sqlx::query!( 1312 - r#"UPDATE oauth_refresh_tokens SET revoked = TRUE WHERE token = ?"#, 1313 - token 1314 - ) 1315 - .execute(&db) 1316 - .await 1317 - .context("failed to revoke token")?; 1735 + // _ = sqlx::query!( 1736 + // r#"UPDATE oauth_refresh_tokens SET revoked = TRUE WHERE token = ?"#, 1737 + // token 1738 + // ) 1739 + // .execute(&db) 1740 + // .await 1741 + // .context("failed to revoke token")?; 1742 + use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1743 + let token_cloned = token.to_owned(); 1744 + _ = db 1745 + .get() 1746 + .await 1747 + .expect("Failed to get database connection") 1748 + .interact(move |conn| { 1749 + update(RefreshTokenSchema::oauth_refresh_tokens) 1750 + .filter(RefreshTokenSchema::token.eq(token_cloned)) 1751 + .set(RefreshTokenSchema::revoked.eq(true)) 1752 + .execute(conn) 1753 + }) 1754 + .await 1755 + .expect("Failed to revoke token") 1756 + .expect("Failed to revoke token"); 1318 1757 1319 1758 // RFC7009 requires a 200 OK with an empty response 1320 1759 Ok(Json(json!({}))) ··· 1325 1764 /// 1326 1765 /// Implements RFC7662 for introspecting tokens 1327 1766 async fn introspect( 1328 - State(db): State<Db>, 1767 + State(db): State<Pool>, 1329 1768 State(skey): State<SigningKey>, 1330 1769 Json(form_data): Json<HashMap<String, String>>, 1331 1770 ) -> Result<Json<Value>> { ··· 1368 1807 1369 1808 // For refresh tokens, check if it's been revoked 1370 1809 if is_refresh_token { 1371 - let is_revoked = sqlx::query_scalar!( 1372 - r#"SELECT revoked FROM oauth_refresh_tokens WHERE token = ?"#, 1373 - token 1374 - ) 1375 - .fetch_optional(&db) 1376 - .await 1377 - .context("failed to query token")? 1378 - .unwrap_or(true); 1810 + // let is_revoked = sqlx::query_scalar!( 1811 + // r#"SELECT revoked FROM oauth_refresh_tokens WHERE token = ?"#, 1812 + // token 1813 + // ) 1814 + // .fetch_optional(&db) 1815 + // .await 1816 + // .context("failed to query token")? 1817 + // .unwrap_or(true); 1818 + use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1819 + let token_cloned = token.to_owned(); 1820 + let is_revoked = db 1821 + .get() 1822 + .await 1823 + .expect("Failed to get database connection") 1824 + .interact(move |conn| { 1825 + RefreshTokenSchema::oauth_refresh_tokens 1826 + .filter(RefreshTokenSchema::token.eq(token_cloned)) 1827 + .select(RefreshTokenSchema::revoked) 1828 + .first::<bool>(conn) 1829 + .optional() 1830 + }) 1831 + .await 1832 + .expect("Failed to query token") 1833 + .expect("Failed to query token") 1834 + .unwrap_or(true); 1379 1835 1380 1836 if is_revoked { 1381 1837 return Ok(Json(json!({"active": false})));