A library for ATProtocol identities.

refactor: adjusting oauth components to support blind oauth workflows

Signed-off-by: Nick Gerakines <nick.gerakines@gmail.com>

Changed files
+70 -65
crates
atproto-oauth
atproto-oauth-aip
atproto-oauth-axum
+4 -4
crates/atproto-oauth-aip/src/lib.rs
··· 62 62 //! &http_client, 63 63 //! &oauth_client, 64 64 //! Some("user_handle"), 65 - //! &authorization_server, 65 + //! &authorization_server.pushed_authorization_request_endpoint, 66 66 //! &oauth_request_state 67 67 //! ).await?; 68 68 //! ··· 70 70 //! # let oauth_request = atproto_oauth::workflow::OAuthRequest { 71 71 //! # oauth_state: "state".to_string(), 72 72 //! # issuer: "https://auth.example.com".to_string(), 73 - //! # did: "did:plc:example".to_string(), 73 + //! # authorization_server: "https://auth.example.com".to_string(), 74 74 //! # nonce: "nonce".to_string(), 75 75 //! # signing_public_key: "public_key".to_string(), 76 76 //! # pkce_verifier: "verifier".to_string(), ··· 81 81 //! let token_response = oauth_complete( 82 82 //! &http_client, 83 83 //! &oauth_client, 84 - //! &authorization_server, 84 + //! &authorization_server.token_endpoint, 85 85 //! "authorization_code", 86 86 //! &oauth_request 87 87 //! ).await?; ··· 95 95 //! # }; 96 96 //! let session = session_exchange( 97 97 //! &http_client, 98 - //! &protected_resource, 98 + //! &protected_resource.resource, 99 99 //! &token_response.access_token 100 100 //! ).await?; 101 101 //! # Ok(())
+19 -23
crates/atproto-oauth-aip/src/workflow.rs
··· 69 69 //! &http_client, 70 70 //! &oauth_client, 71 71 //! Some("user.bsky.social"), 72 - //! &authorization_server, 72 + //! &authorization_server.pushed_authorization_request_endpoint, 73 73 //! &oauth_request_state 74 74 //! ).await?; 75 75 //! ··· 79 79 //! # let oauth_request = OAuthRequest { 80 80 //! # oauth_state: "state".to_string(), 81 81 //! # issuer: "https://auth.example.com".to_string(), 82 - //! # did: "did:plc:example".to_string(), 82 + //! # authorization_server: "https://auth.example.com".to_string(), 83 83 //! # nonce: "nonce".to_string(), 84 84 //! # signing_public_key: "public_key".to_string(), 85 85 //! # pkce_verifier: "verifier".to_string(), ··· 90 90 //! let token_response = oauth_complete( 91 91 //! &http_client, 92 92 //! &oauth_client, 93 - //! &authorization_server, 93 + //! &authorization_server.token_endpoint, 94 94 //! "received_auth_code", 95 95 //! &oauth_request 96 96 //! ).await?; ··· 104 104 //! # }; 105 105 //! let session = session_exchange( 106 106 //! &http_client, 107 - //! &protected_resource, 107 + //! &protected_resource.resource, 108 108 //! &token_response.access_token 109 109 //! ).await?; 110 110 //! # Ok(()) ··· 118 118 //! and protocol violations. 119 119 120 120 use anyhow::Result; 121 - use atproto_oauth::{ 122 - resources::{AuthorizationServer, OAuthProtectedResource}, 123 - workflow::{OAuthRequest, OAuthRequestState, ParResponse, TokenResponse}, 124 - }; 121 + use atproto_oauth::workflow::{OAuthRequest, OAuthRequestState, ParResponse, TokenResponse}; 125 122 use serde::Deserialize; 126 123 127 124 use crate::errors::OAuthWorkflowError; ··· 235 232 /// client_id: "client123".to_string(), 236 233 /// client_secret: "secret456".to_string(), 237 234 /// }; 238 - /// # let authorization_server = todo!(); 235 + /// # let authorization_server = "https://auth.example.com/par"; 239 236 /// let oauth_request_state = OAuthRequestState { 240 237 /// state: "random-state".to_string(), 241 238 /// nonce: "random-nonce".to_string(), ··· 246 243 /// &http_client, 247 244 /// &oauth_client, 248 245 /// Some("alice.bsky.social"), 249 - /// &authorization_server, 246 + /// authorization_server, 250 247 /// &oauth_request_state, 251 248 /// ).await?; 252 249 /// # Ok(()) ··· 256 253 http_client: &reqwest::Client, 257 254 oauth_client: &OAuthClient, 258 255 login_hint: Option<&str>, 259 - authorization_server: &AuthorizationServer, 256 + par_url: &str, 260 257 oauth_request_state: &OAuthRequestState, 261 258 ) -> Result<ParResponse> { 262 - let par_url = authorization_server 263 - .pushed_authorization_request_endpoint 264 - .clone(); 265 - 266 259 let scope = &oauth_request_state.scope; 267 260 268 261 let mut params = vec![ ··· 337 330 /// use atproto_oauth_aip::workflow::oauth_complete; 338 331 /// # let http_client = reqwest::Client::new(); 339 332 /// # let oauth_client = todo!(); 340 - /// # let authorization_server = todo!(); 333 + /// # let token_endpoint = "https://auth.example.com/token"; 341 334 /// # let oauth_request = todo!(); 342 335 /// let token_response = oauth_complete( 343 336 /// &http_client, 344 337 /// &oauth_client, 345 - /// &authorization_server, 338 + /// token_endpoint, 346 339 /// "auth_code_from_callback", 347 340 /// &oauth_request, 348 341 /// ).await?; ··· 353 346 pub async fn oauth_complete( 354 347 http_client: &reqwest::Client, 355 348 oauth_client: &OAuthClient, 356 - authorization_server: &AuthorizationServer, 349 + token_endpoint: &str, 357 350 callback_code: &str, 358 351 oauth_request: &OAuthRequest, 359 352 ) -> Result<TokenResponse> { ··· 366 359 ]; 367 360 368 361 http_client 369 - .post(&authorization_server.token_endpoint) 362 + .post(token_endpoint) 370 363 .basic_auth( 371 364 oauth_client.client_id.as_str(), 372 365 Some(oauth_client.client_secret.as_str()), ··· 374 367 .form(&params) 375 368 .send() 376 369 .await 370 + .inspect(|value| { 371 + println!("{value:?}"); 372 + }) 377 373 .map_err(OAuthWorkflowError::TokenRequestFailed)? 378 374 .json() 379 375 .await ··· 404 400 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> { 405 401 /// use atproto_oauth_aip::workflow::session_exchange; 406 402 /// # let http_client = reqwest::Client::new(); 407 - /// # let protected_resource = todo!(); 403 + /// # let protected_resource = "https://pds.example.com"; 408 404 /// # let access_token = "example_token"; 409 405 /// let session = session_exchange( 410 406 /// &http_client, 411 - /// &protected_resource, 407 + /// protected_resource, 412 408 /// access_token, 413 409 /// ).await?; 414 410 /// println!("Authenticated as {} ({})", session.handle, session.did); ··· 418 414 /// ``` 419 415 pub async fn session_exchange( 420 416 http_client: &reqwest::Client, 421 - protected_resource: &OAuthProtectedResource, 417 + protected_resource_base: &str, 422 418 access_token: &str, 423 419 ) -> Result<ATProtocolSession> { 424 420 let response = http_client 425 421 .get(format!( 426 422 "{}/api/atprotocol/session", 427 - protected_resource.resource 423 + protected_resource_base 428 424 )) 429 425 .bearer_auth(access_token) 430 426 .send()
+1 -1
crates/atproto-oauth-axum/src/bin/atproto-oauth-tool.rs
··· 457 457 let oauth_request = OAuthRequest { 458 458 oauth_state: state.clone(), 459 459 issuer: authorization_server.issuer.clone(), 460 - did: did.clone(), 460 + authorization_server: authorization_server.issuer.clone(), 461 461 nonce: nonce.clone(), 462 462 pkce_verifier, 463 463 signing_public_key: public_signing_key.to_string(),
+18 -9
crates/atproto-oauth-axum/src/handle_complete.rs
··· 10 10 }; 11 11 use atproto_oauth::{ 12 12 axum::state::OAuthRequestStorageExtractor, 13 + resources::pds_resources, 13 14 workflow::{OAuthClient, oauth_complete}, 14 15 }; 15 16 use axum::{ ··· 74 75 }); 75 76 } 76 77 77 - let document = did_document_storage 78 - .0 79 - .get_document_by_did(&oauth_request.did) 80 - .await?; 81 - 82 - let document = document.ok_or(OAuthCallbackError::NoDIDDocumentFound)?; 83 - 84 78 let private_signing_key_data = key_provider 85 79 .0 86 80 .get_private_key_by_id(&oauth_request.signing_public_key) ··· 97 91 private_signing_key_data, 98 92 }; 99 93 94 + // We need to get the DID from the token response after OAuth completion 95 + // First, get authorization server from the issuer to complete OAuth 96 + let (_, authorization_server) = 97 + pds_resources(&client, &oauth_request.authorization_server).await?; 98 + 100 99 let token_response = oauth_complete( 101 100 &client, 102 101 &oauth_client, 103 102 &private_dpop_key_data, 104 103 &callback_form.code, 105 104 &oauth_request, 106 - &document, 105 + &authorization_server, 107 106 ) 108 107 .await?; 108 + 109 + // Now get the DID from the token response subject claim 110 + let did = token_response 111 + .sub 112 + .clone() 113 + .ok_or(OAuthCallbackError::NoDIDDocumentFound)?; 114 + 115 + let document = did_document_storage.0.get_document_by_did(&did).await?; 116 + 117 + let document = document.ok_or(OAuthCallbackError::NoDIDDocumentFound)?; 109 118 110 119 // Format the response with OAuth tokens and DPoP key information 111 120 let response_body = format!( ··· 120 129 Scope: {}\n\ 121 130 Subject: {}\n\ 122 131 Private DPoP Key: {}\n", 123 - oauth_request.did, 132 + document.id, 124 133 oauth_request.issuer, 125 134 token_response.access_token, 126 135 token_response.refresh_token,
+3 -3
crates/atproto-oauth/src/storage.rs
··· 114 114 /// async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()> { 115 115 /// sqlx::query!( 116 116 /// "INSERT INTO oauth_requests 117 - /// (oauth_state, issuer, did, nonce, pkce_verifier, signing_public_key, 117 + /// (oauth_state, issuer, authorization_server, nonce, pkce_verifier, signing_public_key, 118 118 /// dpop_private_key, created_at, expires_at) 119 119 /// VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)", 120 120 /// request.oauth_state, 121 121 /// request.issuer, 122 - /// request.did, 122 + /// request.authorization_server, 123 123 /// request.nonce, 124 124 /// request.pkce_verifier, 125 125 /// request.signing_public_key, ··· 177 177 /// let request = storage.get_oauth_request_by_state("unique-state-value").await?; 178 178 /// match request { 179 179 /// Some(req) => { 180 - /// println!("Found OAuth request for DID: {}", req.did); 180 + /// println!("Found OAuth request for issuer: {}", req.issuer); 181 181 /// // Continue with OAuth flow 182 182 /// }, 183 183 /// None => {
+16 -16
crates/atproto-oauth/src/storage_lru.rs
··· 73 73 /// let request = OAuthRequest { 74 74 /// oauth_state: "unique-state-123".to_string(), 75 75 /// issuer: "https://pds.example.com".to_string(), 76 - /// did: "did:plc:bv6ggog3tya2z3vxsub7hnal".to_string(), 76 + /// authorization_server: "https://pds.example.com".to_string(), 77 77 /// nonce: "secure-nonce".to_string(), 78 78 /// pkce_verifier: "code-verifier".to_string(), 79 79 /// signing_public_key: "public-key-data".to_string(), ··· 185 185 /// let request = OAuthRequest { 186 186 /// oauth_state: "state1".to_string(), 187 187 /// issuer: "https://pds.example.com".to_string(), 188 - /// did: "did:plc:example1".to_string(), 188 + /// authorization_server: "https://pds.example.com".to_string(), 189 189 /// nonce: "nonce1".to_string(), 190 190 /// pkce_verifier: "verifier1".to_string(), 191 191 /// signing_public_key: "pubkey1".to_string(), ··· 262 262 /// let request = OAuthRequest { 263 263 /// oauth_state: "test-state".to_string(), 264 264 /// issuer: "https://pds.example.com".to_string(), 265 - /// did: "did:plc:example".to_string(), 265 + /// authorization_server: "https://pds.example.com".to_string(), 266 266 /// nonce: "test-nonce".to_string(), 267 267 /// pkce_verifier: "test-verifier".to_string(), 268 268 /// signing_public_key: "test-pubkey".to_string(), ··· 329 329 /// let oauth_req = OAuthRequest { 330 330 /// oauth_state: "valid-state-123".to_string(), 331 331 /// issuer: "https://pds.example.com".to_string(), 332 - /// did: "did:plc:bv6ggog3tya2z3vxsub7hnal".to_string(), 332 + /// authorization_server: "https://pds.example.com".to_string(), 333 333 /// nonce: "secure-nonce".to_string(), 334 334 /// pkce_verifier: "code-verifier".to_string(), 335 335 /// signing_public_key: "public-key-data".to_string(), ··· 402 402 /// let request = OAuthRequest { 403 403 /// oauth_state: "deletable-state".to_string(), 404 404 /// issuer: "https://pds.example.com".to_string(), 405 - /// did: "did:plc:bv6ggog3tya2z3vxsub7hnal".to_string(), 405 + /// authorization_server: "https://pds.example.com".to_string(), 406 406 /// nonce: "test-nonce".to_string(), 407 407 /// pkce_verifier: "test-verifier".to_string(), 408 408 /// signing_public_key: "test-pubkey".to_string(), ··· 473 473 /// let req1 = OAuthRequest { 474 474 /// oauth_state: "state1".to_string(), 475 475 /// issuer: "https://pds.example.com".to_string(), 476 - /// did: "did:plc:user1".to_string(), 476 + /// authorization_server: "https://pds.example.com".to_string(), 477 477 /// nonce: "nonce1".to_string(), 478 478 /// pkce_verifier: "verifier1".to_string(), 479 479 /// signing_public_key: "pubkey1".to_string(), ··· 488 488 /// let req2 = OAuthRequest { 489 489 /// oauth_state: "state2".to_string(), 490 490 /// issuer: "https://pds.example.com".to_string(), 491 - /// did: "did:plc:user2".to_string(), 491 + /// authorization_server: "https://pds.example.com".to_string(), 492 492 /// nonce: "nonce2".to_string(), 493 493 /// pkce_verifier: "verifier2".to_string(), 494 494 /// signing_public_key: "pubkey2".to_string(), ··· 503 503 /// let req3 = OAuthRequest { 504 504 /// oauth_state: "state3".to_string(), 505 505 /// issuer: "https://pds.example.com".to_string(), 506 - /// did: "did:plc:user3".to_string(), 506 + /// authorization_server: "https://pds.example.com".to_string(), 507 507 /// nonce: "nonce3".to_string(), 508 508 /// pkce_verifier: "verifier3".to_string(), 509 509 /// signing_public_key: "pubkey3".to_string(), ··· 571 571 /// let expired_request = OAuthRequest { 572 572 /// oauth_state: "soon-expired".to_string(), 573 573 /// issuer: "https://pds.example.com".to_string(), 574 - /// did: "did:plc:user1".to_string(), 574 + /// authorization_server: "https://pds.example.com".to_string(), 575 575 /// nonce: "nonce1".to_string(), 576 576 /// pkce_verifier: "verifier1".to_string(), 577 577 /// signing_public_key: "pubkey1".to_string(), ··· 585 585 /// let valid_request = OAuthRequest { 586 586 /// oauth_state: "still-valid".to_string(), 587 587 /// issuer: "https://pds.example.com".to_string(), 588 - /// did: "did:plc:user2".to_string(), 588 + /// authorization_server: "https://pds.example.com".to_string(), 589 589 /// nonce: "nonce2".to_string(), 590 590 /// pkce_verifier: "verifier2".to_string(), 591 591 /// signing_public_key: "pubkey2".to_string(), ··· 645 645 use chrono::{Duration, Utc}; 646 646 use std::num::NonZeroUsize; 647 647 648 - fn create_test_oauth_request(state: &str, issuer: &str, did: &str) -> OAuthRequest { 648 + fn create_test_oauth_request(state: &str, issuer: &str, _did: &str) -> OAuthRequest { 649 649 OAuthRequest { 650 650 oauth_state: state.to_string(), 651 651 issuer: issuer.to_string(), 652 - did: did.to_string(), 652 + authorization_server: issuer.to_string(), 653 653 nonce: format!("nonce-{}", state), 654 654 pkce_verifier: format!("verifier-{}", state), 655 655 signing_public_key: format!("pubkey-{}", state), ··· 659 659 } 660 660 } 661 661 662 - fn create_expired_oauth_request(state: &str, issuer: &str, did: &str) -> OAuthRequest { 662 + fn create_expired_oauth_request(state: &str, issuer: &str, _did: &str) -> OAuthRequest { 663 663 OAuthRequest { 664 664 oauth_state: state.to_string(), 665 665 issuer: issuer.to_string(), 666 - did: did.to_string(), 666 + authorization_server: issuer.to_string(), 667 667 nonce: format!("nonce-{}", state), 668 668 pkce_verifier: format!("verifier-{}", state), 669 669 signing_public_key: format!("pubkey-{}", state), ··· 764 764 // req1 and req3 should be present, req2 should be evicted 765 765 let result1 = storage.get_oauth_request_by_state("state1").await?; 766 766 assert!(result1.is_some()); 767 - assert_eq!(result1.unwrap().did, req1.did); 767 + assert_eq!(result1.unwrap().oauth_state, req1.oauth_state); 768 768 769 769 let result3 = storage.get_oauth_request_by_state("state3").await?; 770 770 assert!(result3.is_some()); 771 - assert_eq!(result3.unwrap().did, req3.did); 771 + assert_eq!(result3.unwrap().oauth_state, req3.oauth_state); 772 772 773 773 assert_eq!(storage.get_oauth_request_by_state("state2").await?, None); 774 774
+9 -9
crates/atproto-oauth/src/workflow.rs
··· 158 158 #[cfg_attr(feature = "zeroize", zeroize(skip))] 159 159 pub issuer: String, 160 160 161 - /// The DID (Decentralized Identifier) of the user. 161 + /// The authorization server identifier. 162 162 #[cfg_attr(feature = "zeroize", zeroize(skip))] 163 - pub did: String, 163 + pub authorization_server: String, 164 164 165 165 /// The nonce value for additional security. 166 166 #[cfg_attr(feature = "zeroize", zeroize(skip))] ··· 191 191 f.debug_struct("OAuthRequest") 192 192 .field("oauth_state", &self.oauth_state) 193 193 .field("issuer", &self.issuer) 194 - .field("did", &self.did) 194 + .field("authorization_server", &self.authorization_server) 195 195 .field("nonce", &self.nonce) 196 196 .field("pkce_verifier", &"[REDACTED]") 197 197 .field("signing_public_key", &self.signing_public_key) ··· 356 356 dpop_key_data: &KeyData, 357 357 callback_code: &str, 358 358 oauth_request: &OAuthRequest, 359 - document: &atproto_identity::model::Document, 359 + authorization_server: &AuthorizationServer, 360 360 ) -> Result<TokenResponse, OAuthClientError> { 361 - let pds_endpoints = document.pds_endpoints(); 362 - let pds_endpoint = pds_endpoints 363 - .first() 364 - .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?; 365 - let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?; 361 + // let pds_endpoints = document.pds_endpoints(); 362 + // let pds_endpoint = pds_endpoints 363 + // .first() 364 + // .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?; 365 + // let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?; 366 366 367 367 let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone()) 368 368 .try_into()