Built for people who think better out loud.
at main 1683 lines 61 kB view raw
use std::collections::HashMap;
use std::env;
use std::num::NonZeroUsize;

use anyhow::{Result, anyhow};
use atproto_identity::key::{KeyType, generate_key};
use atproto_identity::model::Document;
use atproto_identity::storage_lru::LruDidDocumentStorage;
use atproto_identity::traits::IdentityResolver;
use axum::body::Body;
use axum::http::{Request, StatusCode};
use httpmock::prelude::*;
use slipnote_backend::config::Config;
use slipnote_backend::http::routers;
use slipnote_backend::infrastructure::oauth::{DbOAuthRequestStorage, build_oauth_dependencies};
use slipnote_backend::infrastructure::whisper::WhisperClient;
use slipnote_backend::state::AppState;
use sqlx::{PgPool, Row};
use tower::ServiceExt;
use uuid::Uuid;

/// Scopes advertised by a well-behaved mock PDS / authorization server.
const SUPPORTED_SCOPES: &[&str] = &["atproto", "transition:generic"];

/// An isolated Postgres schema for one test, with the auth migrations applied.
///
/// NOTE(review): `cleanup` only runs when a test reaches its final line; if an
/// assertion panics earlier, the `oauth_flow_*` schema is left in the database.
struct TestDb {
    /// Pool on the caller's `DATABASE_URL`; used to create and drop the schema.
    base_pool: PgPool,
    /// Pool whose connections use the test schema as their `search_path`.
    pool: PgPool,
    /// Generated schema name, `oauth_flow_<uuid in simple form>`.
    schema: String,
    /// `DATABASE_URL` with the `search_path` option appended.
    schema_url: String,
}

impl TestDb {
    /// Creates a uniquely named schema, connects a pool scoped to it, and
    /// applies the migrations inside it.
    ///
    /// # Errors
    /// Returns any connection, DDL, or migration error reported by sqlx.
    async fn setup(database_url: &str) -> Result<Self> {
        let base_pool = PgPool::connect(database_url).await?;
        let schema = format!("oauth_flow_{}", Uuid::new_v4().simple());
        // The name is spliced into DDL (identifiers cannot be bound), so validate it first.
        ensure_safe_schema_name(&schema);

        sqlx::query(&format!("CREATE SCHEMA {schema}"))
            .execute(&base_pool)
            .await?;

        let schema_url = append_search_path(database_url, &schema);
        let pool = PgPool::connect(&schema_url).await?;
        apply_migrations(&pool).await?;

        Ok(Self {
            base_pool,
            pool,
            schema,
            schema_url,
        })
    }

    /// Drops the test schema and everything in it.
    ///
    /// # Errors
    /// Returns the sqlx error if `DROP SCHEMA` fails.
    async fn cleanup(self) -> Result<()> {
        sqlx::query(&format!("DROP SCHEMA {} CASCADE", self.schema))
            .execute(&self.base_pool)
            .await?;
        Ok(())
    }
}

/// Panics unless `schema` consists only of ASCII lowercase letters, digits, and `_`.
fn ensure_safe_schema_name(schema: &str) {
    let ok = schema
        .chars()
        .all(|ch| ch.is_ascii_lowercase() || ch.is_ascii_digit() || ch == '_');
    assert!(ok, "invalid schema name");
}

/// Appends `options=-csearch_path=<schema>` to `database_url`, joining with `&`
/// when the URL already has a query string and with `?` otherwise.
fn append_search_path(database_url: &str, schema: &str) -> String {
    let separator = if database_url.contains('?') { '&' } else { '?' };
    format!("{database_url}{separator}options=-csearch_path={schema}")
}

/// Runs the embedded migration files in order, one statement at a time.
///
/// NOTE(review): statements are found by splitting on `;`, which would break a
/// migration containing `;` inside a string literal or function body.
async fn apply_migrations(pool: &PgPool) -> Result<()> {
    let migrations = [
        include_str!("../migrations/20250101000000_create_auth_tables.sql"),
        include_str!("../migrations/20260228000000_use_did_primary_key.sql"),
        include_str!("../migrations/20260228001000_create_entitlements_tables.sql"),
    ];
    for migration in migrations {
        for statement in migration.split(';') {
            let trimmed = statement.trim();
            if trimmed.is_empty() {
                continue;
            }
            sqlx::query(trimmed).execute(pool).await?;
        }
    }
    Ok(())
}

/// In-memory identity resolver: returns the document registered for an exact
/// subject string (handle or DID), and errors for anything else.
#[derive(Clone, Default)]
struct TestIdentityResolver {
    documents: HashMap<String, Document>,
}

impl TestIdentityResolver {
    /// Registers `document` as the resolution result for `subject`.
    fn insert(&mut self, subject: &str, document: Document) {
        self.documents.insert(subject.to_string(), document);
    }
}

#[async_trait::async_trait]
impl IdentityResolver for TestIdentityResolver {
    async fn resolve(&self, subject: &str) -> Result<Document> {
        self.documents
            .get(subject)
            .cloned()
            .ok_or_else(|| anyhow!("unknown subject {subject}"))
    }
}

/// Builds a local-environment `Config` pointing at `database_url`, with a freshly
/// generated P-256 OAuth signing key and `client.example` client metadata.
fn test_config(database_url: &str) -> Config {
    let oauth_signing_key =
        generate_key(KeyType::P256Private).expect("oauth signing key").to_string();
    Config {
        openai_api_key: "test-key".to_string(),
        openai_base_url: "http://localhost".to_string(),
        openai_whisper_model: "whisper-1".to_string(),
        openai_whisper_response_format: "json".to_string(),
        bind_addr: "0.0.0.0:3001".to_string(),
        database_url: database_url.to_string(),
        slipnote_env: "local".to_string(),
        log_sample_rate: 1.0,
        transcription_cost_per_second_dollars: 0.0,
        axiom_token: None,
        axiom_dataset: None,
        axiom_url: None,
        oauth_client_id: "https://client.example/oauth/client-metadata.json".to_string(),
        oauth_redirect_uri: "https://client.example/oauth/callback".to_string(),
        oauth_client_name: Some("Slipnote".to_string()),
        oauth_client_uri: Some("https://client.example".to_string()),
        oauth_jwks_uri: Some("https://client.example/.well-known/jwks.json".to_string()),
        oauth_signing_key,
        oauth_scopes: "atproto transition:generic".to_string(),
        oauth_post_auth_redirect: "/".to_string(),
        oauth_cookie_name: "slipnote_session".to_string(),
        oauth_session_ttl_seconds: 60 * 60,
        plc_hostname: "plc.directory".to_string(),
        oauth_base_url: None,
    }
}

/// Serves `body` from `GET /.well-known/oauth-protected-resource`.
fn mock_protected_resource(server: &MockServer, body: serde_json::Value) {
    server.mock(|when, then| {
        when.method(GET).path("/.well-known/oauth-protected-resource");
        then.status(200).json_body(body);
    });
}

/// Authorization-server metadata rooted at `base_url` that advertises `scopes`.
/// Every other field describes a PAR + DPoP + PKCE(S256) server.
fn authorization_server_metadata(base_url: &str, scopes: &[&str]) -> serde_json::Value {
    serde_json::json!({
        "issuer": base_url,
        "authorization_endpoint": format!("{base_url}/authorize"),
        "token_endpoint": format!("{base_url}/token"),
        "pushed_authorization_request_endpoint": format!("{base_url}/par"),
        "authorization_response_iss_parameter_supported": true,
        "client_id_metadata_document_supported": true,
        "code_challenge_methods_supported": ["S256"],
        "dpop_signing_alg_values_supported": ["ES256"],
        "grant_types_supported": ["authorization_code", "refresh_token"],
        "response_types_supported": ["code"],
        "scopes_supported": scopes,
        "token_endpoint_auth_methods_supported": ["none", "private_key_jwt"],
        "token_endpoint_auth_signing_alg_values_supported": ["ES256"],
        "require_pushed_authorization_requests": true,
        "request_parameter_supported": true
    })
}

/// Serves [`authorization_server_metadata`] for this server from
/// `GET /.well-known/oauth-authorization-server`.
fn mock_authorization_server(server: &MockServer, scopes: &[&str]) {
    let metadata = authorization_server_metadata(&server.base_url(), scopes);
    server.mock(|when, then| {
        when.method(GET).path("/.well-known/oauth-authorization-server");
        then.status(200).json_body(metadata);
    });
}

/// Answers `POST /par` with status 200 and `body`.
fn mock_par(server: &MockServer, body: serde_json::Value) {
    server.mock(|when, then| {
        when.method(POST).path("/par");
        then.status(200).json_body(body);
    });
}

/// Answers `POST /token` with `status` and `body`.
fn mock_token(server: &MockServer, status: u16, body: serde_json::Value) {
    server.mock(|when, then| {
        when.method(POST).path("/token");
        then.status(status).json_body(body);
    });
}

/// A successful PAR response.
fn par_success_body() -> serde_json::Value {
    serde_json::json!({
        "request_uri": "urn:example:par",
        "expires_in": 90
    })
}

/// Mocks a fully working PDS + authorization server on `server`. The token
/// endpoint issues DPoP tokens for `did:plc:alice` with the `atproto` scope.
fn mock_oauth_server(server: &MockServer) {
    let base_url = server.base_url();
    mock_protected_resource(
        server,
        serde_json::json!({
            "resource": base_url,
            "authorization_servers": [base_url],
            "scopes_supported": SUPPORTED_SCOPES,
            "bearer_methods_supported": ["header"]
        }),
    );
    mock_authorization_server(server, SUPPORTED_SCOPES);
    mock_par(server, par_success_body());
    mock_token(
        server,
        200,
        serde_json::json!({
            "access_token": "access123",
            "token_type": "DPoP",
            "refresh_token": "refresh123",
            "scope": "atproto transition:generic",
            "expires_in": 3600,
            "sub": "did:plc:alice"
        }),
    );
}

/// Like [`mock_oauth_server`], but `POST /token` answers with `token_status` and
/// `token_body`. The protected-resource metadata omits `bearer_methods_supported`.
fn mock_oauth_server_with_token(
    server: &MockServer,
    token_status: u16,
    token_body: serde_json::Value,
) {
    let base_url = server.base_url();
    mock_protected_resource(
        server,
        serde_json::json!({
            "resource": base_url,
            "authorization_servers": [base_url],
            "scopes_supported": SUPPORTED_SCOPES
        }),
    );
    mock_authorization_server(server, SUPPORTED_SCOPES);
    mock_par(server, par_success_body());
    mock_token(server, token_status, token_body);
}

/// Returns `DATABASE_URL`, or logs that `test_name` is skipped and returns `None`.
fn database_url_or_skip(test_name: &str) -> Option<String> {
    let database_url = env::var("DATABASE_URL").ok();
    if database_url.is_none() {
        eprintln!("DATABASE_URL not set, skipping {test_name}");
    }
    database_url
}

/// Builds the `AppState` shared by every test.
///
/// It combines the OAuth dependencies built from `config` with DB-backed
/// request storage on `pool`, an empty 128-entry DID document cache, and
/// `identity_resolver` for handle/DID resolution.
fn test_state(config: &Config, pool: &PgPool, identity_resolver: TestIdentityResolver) -> AppState {
    let oauth_dependencies = build_oauth_dependencies(config, pool.clone());
    let did_document_storage =
        LruDidDocumentStorage::new(NonZeroUsize::new(128).expect("storage size"));

    AppState {
        config: config.clone(),
        whisper_client: WhisperClient::new(config.clone()),
        logger: slipnote_backend::logging::Logger::disabled(),
        db_pool: pool.clone(),
        http_client: oauth_dependencies.http_client,
        oauth_client_config: oauth_dependencies.oauth_client_config,
        oauth_request_storage: std::sync::Arc::new(DbOAuthRequestStorage::new(pool.clone())),
        did_document_storage: std::sync::Arc::new(did_document_storage),
        key_resolver: oauth_dependencies.key_resolver,
        identity_resolver: std::sync::Arc::new(identity_resolver),
        oauth_signing_key: oauth_dependencies.oauth_signing_key,
    }
}

/// A DID document for `did` that claims `handle` and lists `pds_endpoint` as its PDS.
fn build_did_document(did: &str, handle: &str, pds_endpoint: &str) -> Document {
    Document::builder()
        .id(did)
        .add_also_known_as(format!("at://{handle}"))
        .add_pds_service(pds_endpoint)
        .build()
        .expect("build did document")
}

/// An empty-bodied request for `uri` (method defaults to GET).
fn get_request(uri: &str) -> Request<Body> {
    Request::builder()
        .uri(uri)
        .body(Body::empty())
        .expect("build request")
}

/// URI of the login-start endpoint for `subject`.
fn start_uri(subject: &str) -> String {
    format!("/api/auth/atproto/start?subject={subject}")
}

/// URI of the OAuth callback with a fixed code and the given `state` and `iss`.
fn callback_uri(state: &str, issuer: &str) -> String {
    format!("/oauth/callback?code=authcode&state={state}&iss={issuer}")
}

/// Columns of the pending row the start endpoint writes to `oauth_requests`.
struct PendingOAuthRequest {
    oauth_state: String,
    issuer: String,
    authorization_server: String,
}

/// Reads the first row of `oauth_requests`. Each test uses its own schema, so
/// this is the request created by that test's start call.
async fn fetch_pending_oauth_request(pool: &PgPool) -> PendingOAuthRequest {
    let row = sqlx::query("SELECT oauth_state, issuer, authorization_server FROM oauth_requests")
        .fetch_one(pool)
        .await
        .expect("fetch oauth request");
    PendingOAuthRequest {
        oauth_state: row.try_get("oauth_state").expect("oauth_state"),
        issuer: row.try_get("issuer").expect("issuer"),
        authorization_server: row
            .try_get("authorization_server")
            .expect("authorization_server"),
    }
}

/// Happy path: start redirects to the authorization endpoint via PAR.
///
/// The callback then exchanges the code, consumes the pending request, stores
/// the token, creates one session, and sets the session cookie.
#[tokio::test]
async fn oauth_flow_roundtrip() {
    let Some(database_url) = database_url_or_skip("oauth_flow_roundtrip") else {
        return;
    };

    let server = MockServer::start();
    mock_oauth_server(&server);

    let test_db = TestDb::setup(&database_url).await.expect("setup db");
    let config = test_config(&test_db.schema_url);

    let pds_endpoint = server.base_url();
    let subject = "alice.test";
    let did = "did:plc:alice";
    let document = build_did_document(did, subject, &pds_endpoint);

    let mut identity_resolver = TestIdentityResolver::default();
    identity_resolver.insert(subject, document.clone());
    identity_resolver.insert(did, document);

    let app = routers::router(test_state(&config, &test_db.pool, identity_resolver));

    let start_response = app
        .clone()
        .oneshot(get_request(&start_uri(subject)))
        .await
        .expect("start response");
    assert_eq!(start_response.status(), StatusCode::FOUND);

    let location = start_response
        .headers()
        .get(axum::http::header::LOCATION)
        .and_then(|value| value.to_str().ok())
        .expect("missing location header");
    assert!(location.contains("/authorize"), "unexpected location {location}");
    assert!(location.contains("request_uri="), "missing request_uri");

    let pending = fetch_pending_oauth_request(&test_db.pool).await;
    assert_eq!(pending.issuer, pds_endpoint);
    assert_eq!(pending.authorization_server, pds_endpoint);

    let callback_response = app
        .oneshot(get_request(&callback_uri(&pending.oauth_state, &pending.issuer)))
        .await
        .expect("callback response");
    assert_eq!(callback_response.status(), StatusCode::FOUND);

    let redirect = callback_response
        .headers()
        .get(axum::http::header::LOCATION)
        .and_then(|value| value.to_str().ok())
        .expect("missing redirect location");
    assert_eq!(redirect, "/");

    let cookie = callback_response
        .headers()
        .get(axum::http::header::SET_COOKIE)
        .and_then(|value| value.to_str().ok())
        .expect("missing set-cookie header");
    assert!(cookie.contains("slipnote_session="));

    // The pending request is single-use and must be consumed by the callback.
    let remaining: i64 = sqlx::query("SELECT COUNT(*) as total FROM oauth_requests")
        .fetch_one(&test_db.pool)
        .await
        .expect("count oauth requests")
        .try_get("total")
        .expect("count");
    assert_eq!(remaining, 0);

    let token_row = sqlx::query(
        "SELECT access_token, refresh_token, token_type, scopes, issuer FROM oauth_tokens",
    )
    .fetch_one(&test_db.pool)
    .await
    .expect("fetch oauth token");

    let access_token: String = token_row.try_get("access_token").expect("access_token");
    let refresh_token: String = token_row.try_get("refresh_token").expect("refresh_token");
    let token_type: String = token_row.try_get("token_type").expect("token_type");
    let scopes: String = token_row.try_get("scopes").expect("scopes");
    let stored_issuer: String = token_row.try_get("issuer").expect("issuer");

    assert_eq!(access_token, "access123");
    assert_eq!(refresh_token, "refresh123");
    assert_eq!(token_type, "DPoP");
    assert_eq!(scopes, "atproto transition:generic");
    assert_eq!(stored_issuer, pds_endpoint);

    let session_count: i64 = sqlx::query("SELECT COUNT(*) as total FROM sessions")
        .fetch_one(&test_db.pool)
        .await
        .expect("count sessions")
        .try_get("total")
        .expect("session count");
    assert_eq!(session_count, 1);

    test_db.cleanup().await.expect("cleanup db");
}

/// The start endpoint answers 400 when the `subject` query parameter is absent.
#[tokio::test]
async fn oauth_start_requires_subject() {
    let Some(database_url) = database_url_or_skip("oauth_start_requires_subject") else {
        return;
    };

    let server = MockServer::start();
    mock_oauth_server(&server);

    let test_db = TestDb::setup(&database_url).await.expect("setup db");
    let config = test_config(&test_db.schema_url);

    let app = routers::router(test_state(
        &config,
        &test_db.pool,
        TestIdentityResolver::default(),
    ));

    let response = app
        .oneshot(get_request("/api/auth/atproto/start"))
        .await
        .expect("start response");
    assert_eq!(response.status(), StatusCode::BAD_REQUEST);

    test_db.cleanup().await.expect("cleanup db");
}

/// The callback answers 400 when `iss` differs from the issuer recorded at start.
#[tokio::test]
async fn oauth_callback_rejects_invalid_issuer() {
    let Some(database_url) = database_url_or_skip("oauth_callback_rejects_invalid_issuer")
    else {
        return;
    };

    let server = MockServer::start();
    mock_oauth_server(&server);

    let test_db = TestDb::setup(&database_url).await.expect("setup db");
    let config = test_config(&test_db.schema_url);

    let subject = "bob.test";
    let did = "did:plc:bob";
    let document = build_did_document(did, subject, &server.base_url());

    let mut identity_resolver = TestIdentityResolver::default();
    identity_resolver.insert(subject, document.clone());
    identity_resolver.insert(did, document);

    let app = routers::router(test_state(&config, &test_db.pool, identity_resolver));

    let start_response = app
        .clone()
        .oneshot(get_request(&start_uri(subject)))
        .await
        .expect("start response");
    assert_eq!(start_response.status(), StatusCode::FOUND);

    let pending = fetch_pending_oauth_request(&test_db.pool).await;

    let callback_response = app
        .oneshot(get_request(&callback_uri(
            &pending.oauth_state,
            "https://wrong-issuer.example",
        )))
        .await
        .expect("callback response");
    assert_eq!(callback_response.status(), StatusCode::BAD_REQUEST);

    test_db.cleanup().await.expect("cleanup db");
}

/// The callback answers 400 when the granted token scope lacks `atproto`.
#[tokio::test]
async fn oauth_callback_rejects_scope_without_atproto() {
    let Some(database_url) = database_url_or_skip("oauth_callback_rejects_scope_without_atproto")
    else {
        return;
    };

    let server = MockServer::start();
    mock_oauth_server_with_token(
        &server,
        200,
        serde_json::json!({
            "access_token": "access123",
            "token_type": "DPoP",
            "refresh_token": "refresh123",
            "scope": "transition:generic",
            "expires_in": 3600,
            "sub": "did:plc:alice"
        }),
    );

    let test_db = TestDb::setup(&database_url).await.expect("setup db");
    let config = test_config(&test_db.schema_url);

    let subject = "alice.test";
    let did = "did:plc:alice";
    let document = build_did_document(did, subject, &server.base_url());

    let mut identity_resolver = TestIdentityResolver::default();
    identity_resolver.insert(subject, document.clone());
    identity_resolver.insert(did, document);

    let app = routers::router(test_state(&config, &test_db.pool, identity_resolver));

    let start_response = app
        .clone()
        .oneshot(get_request(&start_uri(subject)))
        .await
        .expect("start response");
    assert_eq!(start_response.status(), StatusCode::FOUND);

    let pending = fetch_pending_oauth_request(&test_db.pool).await;

    let callback_response = app
        .oneshot(get_request(&callback_uri(&pending.oauth_state, &pending.issuer)))
        .await
        .expect("callback response");
    assert_eq!(callback_response.status(), StatusCode::BAD_REQUEST);

    test_db.cleanup().await.expect("cleanup db");
}

/// The callback answers 400 when the DID's document names a different PDS than
/// the one the handle resolved to at start time.
#[tokio::test]
async fn oauth_callback_rejects_did_with_mismatched_pds() {
    let Some(database_url) =
        database_url_or_skip("oauth_callback_rejects_did_with_mismatched_pds")
    else {
        return;
    };

    let server = MockServer::start();
    mock_oauth_server(&server);

    let other_server = MockServer::start();
    mock_oauth_server(&other_server);

    let test_db = TestDb::setup(&database_url).await.expect("setup db");
    let config = test_config(&test_db.schema_url);

    let subject = "carol.test";
    let did = "did:plc:carol";

    // Resolving the handle yields `server` as PDS; resolving the DID yields `other_server`.
    let mut identity_resolver = TestIdentityResolver::default();
    identity_resolver.insert(subject, build_did_document(did, subject, &server.base_url()));
    identity_resolver.insert(did, build_did_document(did, subject, &other_server.base_url()));

    let app = routers::router(test_state(&config, &test_db.pool, identity_resolver));

    let start_response = app
        .clone()
        .oneshot(get_request(&start_uri(subject)))
        .await
        .expect("start response");
    assert_eq!(start_response.status(), StatusCode::FOUND);

    let pending = fetch_pending_oauth_request(&test_db.pool).await;

    let callback_response = app
        .oneshot(get_request(&callback_uri(&pending.oauth_state, &pending.issuer)))
        .await
        .expect("callback response");
    assert_eq!(callback_response.status(), StatusCode::BAD_REQUEST);

    test_db.cleanup().await.expect("cleanup db");
}

/// The start endpoint answers 502 when the PDS's protected-resource metadata
/// names a different `resource` than the PDS itself.
#[tokio::test]
async fn oauth_start_rejects_invalid_pds_metadata() {
    let Some(database_url) = database_url_or_skip("oauth_start_rejects_invalid_pds_metadata")
    else {
        return;
    };

    let server = MockServer::start();
    let base_url = server.base_url();
    mock_protected_resource(
        &server,
        serde_json::json!({
            "resource": "https://wrong.example",
            "authorization_servers": [base_url]
        }),
    );
    mock_authorization_server(&server, SUPPORTED_SCOPES);

    let test_db = TestDb::setup(&database_url).await.expect("setup db");
    let config = test_config(&test_db.schema_url);

    let subject = "dave.test";
    let did = "did:plc:dave";
    let mut identity_resolver = TestIdentityResolver::default();
    identity_resolver.insert(subject, build_did_document(did, subject, &base_url));

    let app = routers::router(test_state(&config, &test_db.pool, identity_resolver));

    let start_response = app
        .oneshot(get_request(&start_uri(subject)))
        .await
        .expect("start response");
    assert_eq!(start_response.status(), StatusCode::BAD_GATEWAY);

    test_db.cleanup().await.expect("cleanup db");
}

/// The callback answers 400 when `state` matches no pending request.
#[tokio::test]
async fn oauth_callback_rejects_missing_state() {
    let Some(database_url) = database_url_or_skip("oauth_callback_rejects_missing_state") else {
        return;
    };

    let server = MockServer::start();
    mock_oauth_server(&server);

    let test_db = TestDb::setup(&database_url).await.expect("setup db");
    let config = test_config(&test_db.schema_url);

    let app = routers::router(test_state(
        &config,
        &test_db.pool,
        TestIdentityResolver::default(),
    ));

    let callback_response = app
        .oneshot(get_request(&callback_uri("missing", "https://issuer.example")))
        .await
        .expect("callback response");
    assert_eq!(callback_response.status(), StatusCode::BAD_REQUEST);

    test_db.cleanup().await.expect("cleanup db");
}

/// The start endpoint answers 502 when neither the PDS nor its authorization
/// server advertises the `atproto` scope.
#[tokio::test]
async fn oauth_start_rejects_auth_server_without_atproto_scope() {
    let Some(database_url) =
        database_url_or_skip("oauth_start_rejects_auth_server_without_atproto_scope")
    else {
        return;
    };

    let server = MockServer::start();
    let base_url = server.base_url();
    mock_protected_resource(
        &server,
        serde_json::json!({
            "resource": base_url,
            "authorization_servers": [base_url],
            "scopes_supported": ["transition:generic"]
        }),
    );
    mock_authorization_server(&server, &["transition:generic"]);

    let test_db = TestDb::setup(&database_url).await.expect("setup db");
    let config = test_config(&test_db.schema_url);

    let subject = "erin.test";
    let did = "did:plc:erin";
    let mut identity_resolver = TestIdentityResolver::default();
    identity_resolver.insert(subject, build_did_document(did, subject, &base_url));

    let app = routers::router(test_state(&config, &test_db.pool, identity_resolver));

    let start_response = app
        .oneshot(get_request(&start_uri(subject)))
        .await
        .expect("start response");
    assert_eq!(start_response.status(), StatusCode::BAD_GATEWAY);

    test_db.cleanup().await.expect("cleanup db");
}

/// The start endpoint answers 502 when the PAR response lacks `request_uri`.
#[tokio::test]
async fn oauth_start_rejects_malformed_par_response() {
    let Some(database_url) = database_url_or_skip("oauth_start_rejects_malformed_par_response")
    else {
        return;
    };

    let server = MockServer::start();
    let base_url = server.base_url();
    mock_protected_resource(
        &server,
        serde_json::json!({
            "resource": base_url,
            "authorization_servers": [base_url]
        }),
    );
    mock_authorization_server(&server, SUPPORTED_SCOPES);
    mock_par(&server, serde_json::json!({ "expires_in": 90 }));

    let test_db = TestDb::setup(&database_url).await.expect("setup db");
    let config = test_config(&test_db.schema_url);

    let subject = "frank.test";
    let did = "did:plc:frank";
    let mut identity_resolver = TestIdentityResolver::default();
    identity_resolver.insert(subject, build_did_document(did, subject, &base_url));

    let app = routers::router(test_state(&config, &test_db.pool, identity_resolver));

    let start_response = app
        .oneshot(get_request(&start_uri(subject)))
        .await
        .expect("start response");
    assert_eq!(start_response.status(), StatusCode::BAD_GATEWAY);

    test_db.cleanup().await.expect("cleanup db");
}

#[tokio::test]
async fn oauth_callback_rejects_missing_code() {
    let Some(database_url) = env::var("DATABASE_URL").ok() else {
        eprintln!("DATABASE_URL not set, skipping oauth_callback_rejects_missing_code");
        return;
    };

    let server =
MockServer::start(); 1106 mock_oauth_server(&server); 1107 1108 let test_db = TestDb::setup(&database_url) 1109 .await 1110 .expect("setup db"); 1111 1112 let config = test_config(&test_db.schema_url); 1113 let oauth_dependencies = build_oauth_dependencies(&config, test_db.pool.clone()); 1114 1115 let did_document_storage = 1116 LruDidDocumentStorage::new(NonZeroUsize::new(128).expect("storage size")); 1117 1118 let state = AppState { 1119 config: config.clone(), 1120 whisper_client: WhisperClient::new(config.clone()), 1121 logger: slipnote_backend::logging::Logger::disabled(), 1122 db_pool: test_db.pool.clone(), 1123 http_client: oauth_dependencies.http_client, 1124 oauth_client_config: oauth_dependencies.oauth_client_config, 1125 oauth_request_storage: std::sync::Arc::new(DbOAuthRequestStorage::new( 1126 test_db.pool.clone(), 1127 )), 1128 did_document_storage: std::sync::Arc::new(did_document_storage), 1129 key_resolver: oauth_dependencies.key_resolver, 1130 identity_resolver: std::sync::Arc::new(TestIdentityResolver::default()), 1131 oauth_signing_key: oauth_dependencies.oauth_signing_key, 1132 }; 1133 1134 let app = routers::router(state.clone()); 1135 1136 let callback_request = Request::builder() 1137 .uri("/oauth/callback?state=state&iss=https://issuer.example") 1138 .body(Body::empty()) 1139 .expect("callback request"); 1140 1141 let callback_response = app 1142 .oneshot(callback_request) 1143 .await 1144 .expect("callback response"); 1145 1146 assert_eq!(callback_response.status(), StatusCode::BAD_REQUEST); 1147 1148 test_db.cleanup().await.expect("cleanup db"); 1149} 1150 1151#[tokio::test] 1152async fn oauth_callback_rejects_missing_iss() { 1153 let Some(database_url) = env::var("DATABASE_URL").ok() else { 1154 eprintln!("DATABASE_URL not set, skipping oauth_callback_rejects_missing_iss"); 1155 return; 1156 }; 1157 1158 let server = MockServer::start(); 1159 mock_oauth_server(&server); 1160 1161 let test_db = TestDb::setup(&database_url) 1162 .await 
1163 .expect("setup db"); 1164 1165 let config = test_config(&test_db.schema_url); 1166 let oauth_dependencies = build_oauth_dependencies(&config, test_db.pool.clone()); 1167 1168 let did_document_storage = 1169 LruDidDocumentStorage::new(NonZeroUsize::new(128).expect("storage size")); 1170 1171 let state = AppState { 1172 config: config.clone(), 1173 whisper_client: WhisperClient::new(config.clone()), 1174 logger: slipnote_backend::logging::Logger::disabled(), 1175 db_pool: test_db.pool.clone(), 1176 http_client: oauth_dependencies.http_client, 1177 oauth_client_config: oauth_dependencies.oauth_client_config, 1178 oauth_request_storage: std::sync::Arc::new(DbOAuthRequestStorage::new( 1179 test_db.pool.clone(), 1180 )), 1181 did_document_storage: std::sync::Arc::new(did_document_storage), 1182 key_resolver: oauth_dependencies.key_resolver, 1183 identity_resolver: std::sync::Arc::new(TestIdentityResolver::default()), 1184 oauth_signing_key: oauth_dependencies.oauth_signing_key, 1185 }; 1186 1187 let app = routers::router(state.clone()); 1188 1189 let callback_request = Request::builder() 1190 .uri("/oauth/callback?state=state&code=authcode") 1191 .body(Body::empty()) 1192 .expect("callback request"); 1193 1194 let callback_response = app 1195 .oneshot(callback_request) 1196 .await 1197 .expect("callback response"); 1198 1199 assert_eq!(callback_response.status(), StatusCode::BAD_REQUEST); 1200 1201 test_db.cleanup().await.expect("cleanup db"); 1202} 1203 1204#[tokio::test] 1205async fn oauth_start_rejects_protected_resource_with_multiple_auth_servers() { 1206 let Some(database_url) = env::var("DATABASE_URL").ok() else { 1207 eprintln!( 1208 "DATABASE_URL not set, skipping oauth_start_rejects_protected_resource_with_multiple_auth_servers" 1209 ); 1210 return; 1211 }; 1212 1213 let server = MockServer::start(); 1214 let base_url = server.base_url(); 1215 1216 server.mock(|when, then| { 1217 when.method(GET) 1218 .path("/.well-known/oauth-protected-resource"); 1219 
then.status(200).json_body(serde_json::json!({ 1220 "resource": base_url, 1221 "authorization_servers": [base_url, "https://extra.example"] 1222 })); 1223 }); 1224 1225 server.mock(|when, then| { 1226 when.method(GET) 1227 .path("/.well-known/oauth-authorization-server"); 1228 then.status(200).json_body(serde_json::json!({ 1229 "issuer": base_url, 1230 "authorization_endpoint": format!("{}/authorize", base_url), 1231 "token_endpoint": format!("{}/token", base_url), 1232 "pushed_authorization_request_endpoint": format!("{}/par", base_url), 1233 "authorization_response_iss_parameter_supported": true, 1234 "client_id_metadata_document_supported": true, 1235 "code_challenge_methods_supported": ["S256"], 1236 "dpop_signing_alg_values_supported": ["ES256"], 1237 "grant_types_supported": ["authorization_code", "refresh_token"], 1238 "response_types_supported": ["code"], 1239 "scopes_supported": ["atproto", "transition:generic"], 1240 "token_endpoint_auth_methods_supported": ["none", "private_key_jwt"], 1241 "token_endpoint_auth_signing_alg_values_supported": ["ES256"], 1242 "require_pushed_authorization_requests": true, 1243 "request_parameter_supported": true 1244 })); 1245 }); 1246 1247 let test_db = TestDb::setup(&database_url) 1248 .await 1249 .expect("setup db"); 1250 let config = test_config(&test_db.schema_url); 1251 let oauth_dependencies = build_oauth_dependencies(&config, test_db.pool.clone()); 1252 1253 let subject = "gina.test"; 1254 let did = "did:plc:gina"; 1255 let did_document = Document::builder() 1256 .id(did) 1257 .add_also_known_as(format!("at://{subject}")) 1258 .add_pds_service(&base_url) 1259 .build() 1260 .expect("build did document"); 1261 1262 let mut identity_resolver = TestIdentityResolver::default(); 1263 identity_resolver.insert(subject, did_document); 1264 1265 let did_document_storage = 1266 LruDidDocumentStorage::new(NonZeroUsize::new(128).expect("storage size")); 1267 1268 let state = AppState { 1269 config: config.clone(), 1270 
whisper_client: WhisperClient::new(config.clone()), 1271 logger: slipnote_backend::logging::Logger::disabled(), 1272 db_pool: test_db.pool.clone(), 1273 http_client: oauth_dependencies.http_client, 1274 oauth_client_config: oauth_dependencies.oauth_client_config, 1275 oauth_request_storage: std::sync::Arc::new(DbOAuthRequestStorage::new( 1276 test_db.pool.clone(), 1277 )), 1278 did_document_storage: std::sync::Arc::new(did_document_storage), 1279 key_resolver: oauth_dependencies.key_resolver, 1280 identity_resolver: std::sync::Arc::new(identity_resolver), 1281 oauth_signing_key: oauth_dependencies.oauth_signing_key, 1282 }; 1283 1284 let app = routers::router(state.clone()); 1285 1286 let start_request = Request::builder() 1287 .uri(format!("/api/auth/atproto/start?subject={subject}")) 1288 .body(Body::empty()) 1289 .expect("start request"); 1290 1291 let start_response = app 1292 .oneshot(start_request) 1293 .await 1294 .expect("start response"); 1295 assert_eq!(start_response.status(), StatusCode::BAD_GATEWAY); 1296 1297 test_db.cleanup().await.expect("cleanup db"); 1298} 1299 1300#[tokio::test] 1301async fn oauth_start_rejects_auth_server_missing_par_requirement() { 1302 let Some(database_url) = env::var("DATABASE_URL").ok() else { 1303 eprintln!( 1304 "DATABASE_URL not set, skipping oauth_start_rejects_auth_server_missing_par_requirement" 1305 ); 1306 return; 1307 }; 1308 1309 let server = MockServer::start(); 1310 let base_url = server.base_url(); 1311 1312 server.mock(|when, then| { 1313 when.method(GET) 1314 .path("/.well-known/oauth-protected-resource"); 1315 then.status(200).json_body(serde_json::json!({ 1316 "resource": base_url, 1317 "authorization_servers": [base_url] 1318 })); 1319 }); 1320 1321 server.mock(|when, then| { 1322 when.method(GET) 1323 .path("/.well-known/oauth-authorization-server"); 1324 then.status(200).json_body(serde_json::json!({ 1325 "issuer": base_url, 1326 "authorization_endpoint": format!("{}/authorize", base_url), 1327 
"token_endpoint": format!("{}/token", base_url), 1328 "pushed_authorization_request_endpoint": format!("{}/par", base_url), 1329 "authorization_response_iss_parameter_supported": true, 1330 "client_id_metadata_document_supported": true, 1331 "code_challenge_methods_supported": ["S256"], 1332 "dpop_signing_alg_values_supported": ["ES256"], 1333 "grant_types_supported": ["authorization_code", "refresh_token"], 1334 "response_types_supported": ["code"], 1335 "scopes_supported": ["atproto", "transition:generic"], 1336 "token_endpoint_auth_methods_supported": ["none", "private_key_jwt"], 1337 "token_endpoint_auth_signing_alg_values_supported": ["ES256"], 1338 "require_pushed_authorization_requests": false, 1339 "request_parameter_supported": true 1340 })); 1341 }); 1342 1343 let test_db = TestDb::setup(&database_url) 1344 .await 1345 .expect("setup db"); 1346 let config = test_config(&test_db.schema_url); 1347 let oauth_dependencies = build_oauth_dependencies(&config, test_db.pool.clone()); 1348 1349 let subject = "henry.test"; 1350 let did = "did:plc:henry"; 1351 let did_document = Document::builder() 1352 .id(did) 1353 .add_also_known_as(format!("at://{subject}")) 1354 .add_pds_service(&base_url) 1355 .build() 1356 .expect("build did document"); 1357 1358 let mut identity_resolver = TestIdentityResolver::default(); 1359 identity_resolver.insert(subject, did_document); 1360 1361 let did_document_storage = 1362 LruDidDocumentStorage::new(NonZeroUsize::new(128).expect("storage size")); 1363 1364 let state = AppState { 1365 config: config.clone(), 1366 whisper_client: WhisperClient::new(config.clone()), 1367 logger: slipnote_backend::logging::Logger::disabled(), 1368 db_pool: test_db.pool.clone(), 1369 http_client: oauth_dependencies.http_client, 1370 oauth_client_config: oauth_dependencies.oauth_client_config, 1371 oauth_request_storage: std::sync::Arc::new(DbOAuthRequestStorage::new( 1372 test_db.pool.clone(), 1373 )), 1374 did_document_storage: 
std::sync::Arc::new(did_document_storage), 1375 key_resolver: oauth_dependencies.key_resolver, 1376 identity_resolver: std::sync::Arc::new(identity_resolver), 1377 oauth_signing_key: oauth_dependencies.oauth_signing_key, 1378 }; 1379 1380 let app = routers::router(state.clone()); 1381 1382 let start_request = Request::builder() 1383 .uri(format!("/api/auth/atproto/start?subject={subject}")) 1384 .body(Body::empty()) 1385 .expect("start request"); 1386 1387 let start_response = app 1388 .oneshot(start_request) 1389 .await 1390 .expect("start response"); 1391 assert_eq!(start_response.status(), StatusCode::BAD_GATEWAY); 1392 1393 test_db.cleanup().await.expect("cleanup db"); 1394} 1395 1396#[tokio::test] 1397async fn oauth_callback_rejects_token_response_missing_sub() { 1398 let Some(database_url) = env::var("DATABASE_URL").ok() else { 1399 eprintln!( 1400 "DATABASE_URL not set, skipping oauth_callback_rejects_token_response_missing_sub" 1401 ); 1402 return; 1403 }; 1404 1405 let server = MockServer::start(); 1406 mock_oauth_server_with_token( 1407 &server, 1408 200, 1409 serde_json::json!({ 1410 "access_token": "access123", 1411 "token_type": "DPoP", 1412 "refresh_token": "refresh123", 1413 "scope": "atproto transition:generic", 1414 "expires_in": 3600 1415 }), 1416 ); 1417 1418 let test_db = TestDb::setup(&database_url) 1419 .await 1420 .expect("setup db"); 1421 1422 let config = test_config(&test_db.schema_url); 1423 let oauth_dependencies = build_oauth_dependencies(&config, test_db.pool.clone()); 1424 1425 let pds_endpoint = server.base_url(); 1426 let subject = "ivy.test"; 1427 let did = "did:plc:ivy"; 1428 let did_document = Document::builder() 1429 .id(did) 1430 .add_also_known_as(format!("at://{subject}")) 1431 .add_pds_service(&pds_endpoint) 1432 .build() 1433 .expect("build did document"); 1434 1435 let mut identity_resolver = TestIdentityResolver::default(); 1436 identity_resolver.insert(subject, did_document.clone()); 1437 identity_resolver.insert(did, 
did_document); 1438 1439 let did_document_storage = 1440 LruDidDocumentStorage::new(NonZeroUsize::new(128).expect("storage size")); 1441 1442 let state = AppState { 1443 config: config.clone(), 1444 whisper_client: WhisperClient::new(config.clone()), 1445 logger: slipnote_backend::logging::Logger::disabled(), 1446 db_pool: test_db.pool.clone(), 1447 http_client: oauth_dependencies.http_client, 1448 oauth_client_config: oauth_dependencies.oauth_client_config, 1449 oauth_request_storage: std::sync::Arc::new(DbOAuthRequestStorage::new( 1450 test_db.pool.clone(), 1451 )), 1452 did_document_storage: std::sync::Arc::new(did_document_storage), 1453 key_resolver: oauth_dependencies.key_resolver, 1454 identity_resolver: std::sync::Arc::new(identity_resolver), 1455 oauth_signing_key: oauth_dependencies.oauth_signing_key, 1456 }; 1457 1458 let app = routers::router(state.clone()); 1459 1460 let start_request = Request::builder() 1461 .uri(format!("/api/auth/atproto/start?subject={subject}")) 1462 .body(Body::empty()) 1463 .expect("start request"); 1464 1465 let start_response = app 1466 .clone() 1467 .oneshot(start_request) 1468 .await 1469 .expect("start response"); 1470 assert_eq!(start_response.status(), StatusCode::FOUND); 1471 1472 let oauth_request = sqlx::query("SELECT oauth_state, issuer FROM oauth_requests") 1473 .fetch_one(&test_db.pool) 1474 .await 1475 .expect("fetch oauth request"); 1476 let oauth_state: String = oauth_request.try_get("oauth_state").expect("oauth_state"); 1477 let issuer: String = oauth_request.try_get("issuer").expect("issuer"); 1478 1479 let callback_request = Request::builder() 1480 .uri(format!( 1481 "/oauth/callback?code=authcode&state={}&iss={}", 1482 oauth_state, issuer 1483 )) 1484 .body(Body::empty()) 1485 .expect("callback request"); 1486 1487 let callback_response = app 1488 .oneshot(callback_request) 1489 .await 1490 .expect("callback response"); 1491 1492 assert_eq!(callback_response.status(), StatusCode::INTERNAL_SERVER_ERROR); 1493 
1494 test_db.cleanup().await.expect("cleanup db"); 1495} 1496 1497#[tokio::test] 1498async fn oauth_callback_rejects_token_endpoint_failure() { 1499 let Some(database_url) = env::var("DATABASE_URL").ok() else { 1500 eprintln!("DATABASE_URL not set, skipping oauth_callback_rejects_token_endpoint_failure"); 1501 return; 1502 }; 1503 1504 let server = MockServer::start(); 1505 mock_oauth_server_with_token( 1506 &server, 1507 500, 1508 serde_json::json!({ "error": "server_error" }), 1509 ); 1510 1511 let test_db = TestDb::setup(&database_url) 1512 .await 1513 .expect("setup db"); 1514 1515 let config = test_config(&test_db.schema_url); 1516 let oauth_dependencies = build_oauth_dependencies(&config, test_db.pool.clone()); 1517 1518 let pds_endpoint = server.base_url(); 1519 let subject = "jane.test"; 1520 let did = "did:plc:jane"; 1521 let did_document = Document::builder() 1522 .id(did) 1523 .add_also_known_as(format!("at://{subject}")) 1524 .add_pds_service(&pds_endpoint) 1525 .build() 1526 .expect("build did document"); 1527 1528 let mut identity_resolver = TestIdentityResolver::default(); 1529 identity_resolver.insert(subject, did_document.clone()); 1530 identity_resolver.insert(did, did_document); 1531 1532 let did_document_storage = 1533 LruDidDocumentStorage::new(NonZeroUsize::new(128).expect("storage size")); 1534 1535 let state = AppState { 1536 config: config.clone(), 1537 whisper_client: WhisperClient::new(config.clone()), 1538 logger: slipnote_backend::logging::Logger::disabled(), 1539 db_pool: test_db.pool.clone(), 1540 http_client: oauth_dependencies.http_client, 1541 oauth_client_config: oauth_dependencies.oauth_client_config, 1542 oauth_request_storage: std::sync::Arc::new(DbOAuthRequestStorage::new( 1543 test_db.pool.clone(), 1544 )), 1545 did_document_storage: std::sync::Arc::new(did_document_storage), 1546 key_resolver: oauth_dependencies.key_resolver, 1547 identity_resolver: std::sync::Arc::new(identity_resolver), 1548 oauth_signing_key: 
oauth_dependencies.oauth_signing_key, 1549 }; 1550 1551 let app = routers::router(state.clone()); 1552 1553 let start_request = Request::builder() 1554 .uri(format!("/api/auth/atproto/start?subject={subject}")) 1555 .body(Body::empty()) 1556 .expect("start request"); 1557 1558 let start_response = app 1559 .clone() 1560 .oneshot(start_request) 1561 .await 1562 .expect("start response"); 1563 assert_eq!(start_response.status(), StatusCode::FOUND); 1564 1565 let oauth_request = sqlx::query("SELECT oauth_state, issuer FROM oauth_requests") 1566 .fetch_one(&test_db.pool) 1567 .await 1568 .expect("fetch oauth request"); 1569 let oauth_state: String = oauth_request.try_get("oauth_state").expect("oauth_state"); 1570 let issuer: String = oauth_request.try_get("issuer").expect("issuer"); 1571 1572 let callback_request = Request::builder() 1573 .uri(format!( 1574 "/oauth/callback?code=authcode&state={}&iss={}", 1575 oauth_state, issuer 1576 )) 1577 .body(Body::empty()) 1578 .expect("callback request"); 1579 1580 let callback_response = app 1581 .oneshot(callback_request) 1582 .await 1583 .expect("callback response"); 1584 1585 assert_eq!(callback_response.status(), StatusCode::INTERNAL_SERVER_ERROR); 1586 1587 test_db.cleanup().await.expect("cleanup db"); 1588} 1589 1590#[tokio::test] 1591async fn oauth_callback_rejects_malformed_token_response() { 1592 let Some(database_url) = env::var("DATABASE_URL").ok() else { 1593 eprintln!( 1594 "DATABASE_URL not set, skipping oauth_callback_rejects_malformed_token_response" 1595 ); 1596 return; 1597 }; 1598 1599 let server = MockServer::start(); 1600 mock_oauth_server_with_token( 1601 &server, 1602 200, 1603 serde_json::json!({ "access_token": 123 }), 1604 ); 1605 1606 let test_db = TestDb::setup(&database_url) 1607 .await 1608 .expect("setup db"); 1609 1610 let config = test_config(&test_db.schema_url); 1611 let oauth_dependencies = build_oauth_dependencies(&config, test_db.pool.clone()); 1612 1613 let pds_endpoint = 
server.base_url(); 1614 let subject = "kate.test"; 1615 let did = "did:plc:kate"; 1616 let did_document = Document::builder() 1617 .id(did) 1618 .add_also_known_as(format!("at://{subject}")) 1619 .add_pds_service(&pds_endpoint) 1620 .build() 1621 .expect("build did document"); 1622 1623 let mut identity_resolver = TestIdentityResolver::default(); 1624 identity_resolver.insert(subject, did_document.clone()); 1625 identity_resolver.insert(did, did_document); 1626 1627 let did_document_storage = 1628 LruDidDocumentStorage::new(NonZeroUsize::new(128).expect("storage size")); 1629 1630 let state = AppState { 1631 config: config.clone(), 1632 whisper_client: WhisperClient::new(config.clone()), 1633 logger: slipnote_backend::logging::Logger::disabled(), 1634 db_pool: test_db.pool.clone(), 1635 http_client: oauth_dependencies.http_client, 1636 oauth_client_config: oauth_dependencies.oauth_client_config, 1637 oauth_request_storage: std::sync::Arc::new(DbOAuthRequestStorage::new( 1638 test_db.pool.clone(), 1639 )), 1640 did_document_storage: std::sync::Arc::new(did_document_storage), 1641 key_resolver: oauth_dependencies.key_resolver, 1642 identity_resolver: std::sync::Arc::new(identity_resolver), 1643 oauth_signing_key: oauth_dependencies.oauth_signing_key, 1644 }; 1645 1646 let app = routers::router(state.clone()); 1647 1648 let start_request = Request::builder() 1649 .uri(format!("/api/auth/atproto/start?subject={subject}")) 1650 .body(Body::empty()) 1651 .expect("start request"); 1652 1653 let start_response = app 1654 .clone() 1655 .oneshot(start_request) 1656 .await 1657 .expect("start response"); 1658 assert_eq!(start_response.status(), StatusCode::FOUND); 1659 1660 let oauth_request = sqlx::query("SELECT oauth_state, issuer FROM oauth_requests") 1661 .fetch_one(&test_db.pool) 1662 .await 1663 .expect("fetch oauth request"); 1664 let oauth_state: String = oauth_request.try_get("oauth_state").expect("oauth_state"); 1665 let issuer: String = 
oauth_request.try_get("issuer").expect("issuer"); 1666 1667 let callback_request = Request::builder() 1668 .uri(format!( 1669 "/oauth/callback?code=authcode&state={}&iss={}", 1670 oauth_state, issuer 1671 )) 1672 .body(Body::empty()) 1673 .expect("callback request"); 1674 1675 let callback_response = app 1676 .oneshot(callback_request) 1677 .await 1678 .expect("callback response"); 1679 1680 assert_eq!(callback_response.status(), StatusCode::INTERNAL_SERVER_ERROR); 1681 1682 test_db.cleanup().await.expect("cleanup db"); 1683}