+4
-4
crates/atproto-oauth-aip/src/lib.rs
+4
-4
crates/atproto-oauth-aip/src/lib.rs
···
62
62
//! &http_client,
63
63
//! &oauth_client,
64
64
//! Some("user_handle"),
65
-
//! &authorization_server,
65
+
//! &authorization_server.pushed_authorization_request_endpoint,
66
66
//! &oauth_request_state
67
67
//! ).await?;
68
68
//!
···
70
70
//! # let oauth_request = atproto_oauth::workflow::OAuthRequest {
71
71
//! # oauth_state: "state".to_string(),
72
72
//! # issuer: "https://auth.example.com".to_string(),
73
-
//! # did: "did:plc:example".to_string(),
73
+
//! # authorization_server: "https://auth.example.com".to_string(),
74
74
//! # nonce: "nonce".to_string(),
75
75
//! # signing_public_key: "public_key".to_string(),
76
76
//! # pkce_verifier: "verifier".to_string(),
···
81
81
//! let token_response = oauth_complete(
82
82
//! &http_client,
83
83
//! &oauth_client,
84
-
//! &authorization_server,
84
+
//! &authorization_server.token_endpoint,
85
85
//! "authorization_code",
86
86
//! &oauth_request
87
87
//! ).await?;
···
95
95
//! # };
96
96
//! let session = session_exchange(
97
97
//! &http_client,
98
-
//! &protected_resource,
98
+
//! &protected_resource.resource,
99
99
//! &token_response.access_token
100
100
//! ).await?;
101
101
//! # Ok(())
+19
-23
crates/atproto-oauth-aip/src/workflow.rs
+19
-23
crates/atproto-oauth-aip/src/workflow.rs
···
69
69
//! &http_client,
70
70
//! &oauth_client,
71
71
//! Some("user.bsky.social"),
72
-
//! &authorization_server,
72
+
//! &authorization_server.pushed_authorization_request_endpoint,
73
73
//! &oauth_request_state
74
74
//! ).await?;
75
75
//!
···
79
79
//! # let oauth_request = OAuthRequest {
80
80
//! # oauth_state: "state".to_string(),
81
81
//! # issuer: "https://auth.example.com".to_string(),
82
-
//! # did: "did:plc:example".to_string(),
82
+
//! # authorization_server: "https://auth.example.com".to_string(),
83
83
//! # nonce: "nonce".to_string(),
84
84
//! # signing_public_key: "public_key".to_string(),
85
85
//! # pkce_verifier: "verifier".to_string(),
···
90
90
//! let token_response = oauth_complete(
91
91
//! &http_client,
92
92
//! &oauth_client,
93
-
//! &authorization_server,
93
+
//! &authorization_server.token_endpoint,
94
94
//! "received_auth_code",
95
95
//! &oauth_request
96
96
//! ).await?;
···
104
104
//! # };
105
105
//! let session = session_exchange(
106
106
//! &http_client,
107
-
//! &protected_resource,
107
+
//! &protected_resource.resource,
108
108
//! &token_response.access_token
109
109
//! ).await?;
110
110
//! # Ok(())
···
118
118
//! and protocol violations.
119
119
120
120
use anyhow::Result;
121
-
use atproto_oauth::{
122
-
resources::{AuthorizationServer, OAuthProtectedResource},
123
-
workflow::{OAuthRequest, OAuthRequestState, ParResponse, TokenResponse},
124
-
};
121
+
use atproto_oauth::workflow::{OAuthRequest, OAuthRequestState, ParResponse, TokenResponse};
125
122
use serde::Deserialize;
126
123
127
124
use crate::errors::OAuthWorkflowError;
···
235
232
/// client_id: "client123".to_string(),
236
233
/// client_secret: "secret456".to_string(),
237
234
/// };
238
-
/// # let authorization_server = todo!();
235
+
/// # let authorization_server = "https://auth.example.com/par";
239
236
/// let oauth_request_state = OAuthRequestState {
240
237
/// state: "random-state".to_string(),
241
238
/// nonce: "random-nonce".to_string(),
···
246
243
/// &http_client,
247
244
/// &oauth_client,
248
245
/// Some("alice.bsky.social"),
249
-
/// &authorization_server,
246
+
/// authorization_server,
250
247
/// &oauth_request_state,
251
248
/// ).await?;
252
249
/// # Ok(())
···
256
253
http_client: &reqwest::Client,
257
254
oauth_client: &OAuthClient,
258
255
login_hint: Option<&str>,
259
-
authorization_server: &AuthorizationServer,
256
+
par_url: &str,
260
257
oauth_request_state: &OAuthRequestState,
261
258
) -> Result<ParResponse> {
262
-
let par_url = authorization_server
263
-
.pushed_authorization_request_endpoint
264
-
.clone();
265
-
266
259
let scope = &oauth_request_state.scope;
267
260
268
261
let mut params = vec![
···
337
330
/// use atproto_oauth_aip::workflow::oauth_complete;
338
331
/// # let http_client = reqwest::Client::new();
339
332
/// # let oauth_client = todo!();
340
-
/// # let authorization_server = todo!();
333
+
/// # let token_endpoint = "https://auth.example.com/token";
341
334
/// # let oauth_request = todo!();
342
335
/// let token_response = oauth_complete(
343
336
/// &http_client,
344
337
/// &oauth_client,
345
-
/// &authorization_server,
338
+
/// token_endpoint,
346
339
/// "auth_code_from_callback",
347
340
/// &oauth_request,
348
341
/// ).await?;
···
353
346
pub async fn oauth_complete(
354
347
http_client: &reqwest::Client,
355
348
oauth_client: &OAuthClient,
356
-
authorization_server: &AuthorizationServer,
349
+
token_endpoint: &str,
357
350
callback_code: &str,
358
351
oauth_request: &OAuthRequest,
359
352
) -> Result<TokenResponse> {
···
366
359
];
367
360
368
361
http_client
369
-
.post(&authorization_server.token_endpoint)
362
+
.post(token_endpoint)
370
363
.basic_auth(
371
364
oauth_client.client_id.as_str(),
372
365
Some(oauth_client.client_secret.as_str()),
···
374
367
.form(¶ms)
375
368
.send()
376
369
.await
370
+
.inspect(|value| {
371
+
println!("{value:?}");
372
+
})
377
373
.map_err(OAuthWorkflowError::TokenRequestFailed)?
378
374
.json()
379
375
.await
···
404
400
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
405
401
/// use atproto_oauth_aip::workflow::session_exchange;
406
402
/// # let http_client = reqwest::Client::new();
407
-
/// # let protected_resource = todo!();
403
+
/// # let protected_resource = "https://pds.example.com";
408
404
/// # let access_token = "example_token";
409
405
/// let session = session_exchange(
410
406
/// &http_client,
411
-
/// &protected_resource,
407
+
/// protected_resource,
412
408
/// access_token,
413
409
/// ).await?;
414
410
/// println!("Authenticated as {} ({})", session.handle, session.did);
···
418
414
/// ```
419
415
pub async fn session_exchange(
420
416
http_client: &reqwest::Client,
421
-
protected_resource: &OAuthProtectedResource,
417
+
protected_resource_base: &str,
422
418
access_token: &str,
423
419
) -> Result<ATProtocolSession> {
424
420
let response = http_client
425
421
.get(format!(
426
422
"{}/api/atprotocol/session",
427
-
protected_resource.resource
423
+
protected_resource_base
428
424
))
429
425
.bearer_auth(access_token)
430
426
.send()
+1
-1
crates/atproto-oauth-axum/src/bin/atproto-oauth-tool.rs
+1
-1
crates/atproto-oauth-axum/src/bin/atproto-oauth-tool.rs
···
457
457
let oauth_request = OAuthRequest {
458
458
oauth_state: state.clone(),
459
459
issuer: authorization_server.issuer.clone(),
460
-
did: did.clone(),
460
+
authorization_server: authorization_server.issuer.clone(),
461
461
nonce: nonce.clone(),
462
462
pkce_verifier,
463
463
signing_public_key: public_signing_key.to_string(),
+18
-9
crates/atproto-oauth-axum/src/handle_complete.rs
+18
-9
crates/atproto-oauth-axum/src/handle_complete.rs
···
10
10
};
11
11
use atproto_oauth::{
12
12
axum::state::OAuthRequestStorageExtractor,
13
+
resources::pds_resources,
13
14
workflow::{OAuthClient, oauth_complete},
14
15
};
15
16
use axum::{
···
74
75
});
75
76
}
76
77
77
-
let document = did_document_storage
78
-
.0
79
-
.get_document_by_did(&oauth_request.did)
80
-
.await?;
81
-
82
-
let document = document.ok_or(OAuthCallbackError::NoDIDDocumentFound)?;
83
-
84
78
let private_signing_key_data = key_provider
85
79
.0
86
80
.get_private_key_by_id(&oauth_request.signing_public_key)
···
97
91
private_signing_key_data,
98
92
};
99
93
94
+
// We need to get the DID from the token response after OAuth completion
95
+
// First, get authorization server from the issuer to complete OAuth
96
+
let (_, authorization_server) =
97
+
pds_resources(&client, &oauth_request.authorization_server).await?;
98
+
100
99
let token_response = oauth_complete(
101
100
&client,
102
101
&oauth_client,
103
102
&private_dpop_key_data,
104
103
&callback_form.code,
105
104
&oauth_request,
106
-
&document,
105
+
&authorization_server,
107
106
)
108
107
.await?;
108
+
109
+
// Now get the DID from the token response subject claim
110
+
let did = token_response
111
+
.sub
112
+
.clone()
113
+
.ok_or(OAuthCallbackError::NoDIDDocumentFound)?;
114
+
115
+
let document = did_document_storage.0.get_document_by_did(&did).await?;
116
+
117
+
let document = document.ok_or(OAuthCallbackError::NoDIDDocumentFound)?;
109
118
110
119
// Format the response with OAuth tokens and DPoP key information
111
120
let response_body = format!(
···
120
129
Scope: {}\n\
121
130
Subject: {}\n\
122
131
Private DPoP Key: {}\n",
123
-
oauth_request.did,
132
+
document.id,
124
133
oauth_request.issuer,
125
134
token_response.access_token,
126
135
token_response.refresh_token,
+3
-3
crates/atproto-oauth/src/storage.rs
+3
-3
crates/atproto-oauth/src/storage.rs
···
114
114
/// async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<()> {
115
115
/// sqlx::query!(
116
116
/// "INSERT INTO oauth_requests
117
-
/// (oauth_state, issuer, did, nonce, pkce_verifier, signing_public_key,
117
+
/// (oauth_state, issuer, authorization_server, nonce, pkce_verifier, signing_public_key,
118
118
/// dpop_private_key, created_at, expires_at)
119
119
/// VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
120
120
/// request.oauth_state,
121
121
/// request.issuer,
122
-
/// request.did,
122
+
/// request.authorization_server,
123
123
/// request.nonce,
124
124
/// request.pkce_verifier,
125
125
/// request.signing_public_key,
···
177
177
/// let request = storage.get_oauth_request_by_state("unique-state-value").await?;
178
178
/// match request {
179
179
/// Some(req) => {
180
-
/// println!("Found OAuth request for DID: {}", req.did);
180
+
/// println!("Found OAuth request for issuer: {}", req.issuer);
181
181
/// // Continue with OAuth flow
182
182
/// },
183
183
/// None => {
+16
-16
crates/atproto-oauth/src/storage_lru.rs
+16
-16
crates/atproto-oauth/src/storage_lru.rs
···
73
73
/// let request = OAuthRequest {
74
74
/// oauth_state: "unique-state-123".to_string(),
75
75
/// issuer: "https://pds.example.com".to_string(),
76
-
/// did: "did:plc:bv6ggog3tya2z3vxsub7hnal".to_string(),
76
+
/// authorization_server: "https://pds.example.com".to_string(),
77
77
/// nonce: "secure-nonce".to_string(),
78
78
/// pkce_verifier: "code-verifier".to_string(),
79
79
/// signing_public_key: "public-key-data".to_string(),
···
185
185
/// let request = OAuthRequest {
186
186
/// oauth_state: "state1".to_string(),
187
187
/// issuer: "https://pds.example.com".to_string(),
188
-
/// did: "did:plc:example1".to_string(),
188
+
/// authorization_server: "https://pds.example.com".to_string(),
189
189
/// nonce: "nonce1".to_string(),
190
190
/// pkce_verifier: "verifier1".to_string(),
191
191
/// signing_public_key: "pubkey1".to_string(),
···
262
262
/// let request = OAuthRequest {
263
263
/// oauth_state: "test-state".to_string(),
264
264
/// issuer: "https://pds.example.com".to_string(),
265
-
/// did: "did:plc:example".to_string(),
265
+
/// authorization_server: "https://pds.example.com".to_string(),
266
266
/// nonce: "test-nonce".to_string(),
267
267
/// pkce_verifier: "test-verifier".to_string(),
268
268
/// signing_public_key: "test-pubkey".to_string(),
···
329
329
/// let oauth_req = OAuthRequest {
330
330
/// oauth_state: "valid-state-123".to_string(),
331
331
/// issuer: "https://pds.example.com".to_string(),
332
-
/// did: "did:plc:bv6ggog3tya2z3vxsub7hnal".to_string(),
332
+
/// authorization_server: "https://pds.example.com".to_string(),
333
333
/// nonce: "secure-nonce".to_string(),
334
334
/// pkce_verifier: "code-verifier".to_string(),
335
335
/// signing_public_key: "public-key-data".to_string(),
···
402
402
/// let request = OAuthRequest {
403
403
/// oauth_state: "deletable-state".to_string(),
404
404
/// issuer: "https://pds.example.com".to_string(),
405
-
/// did: "did:plc:bv6ggog3tya2z3vxsub7hnal".to_string(),
405
+
/// authorization_server: "https://pds.example.com".to_string(),
406
406
/// nonce: "test-nonce".to_string(),
407
407
/// pkce_verifier: "test-verifier".to_string(),
408
408
/// signing_public_key: "test-pubkey".to_string(),
···
473
473
/// let req1 = OAuthRequest {
474
474
/// oauth_state: "state1".to_string(),
475
475
/// issuer: "https://pds.example.com".to_string(),
476
-
/// did: "did:plc:user1".to_string(),
476
+
/// authorization_server: "https://pds.example.com".to_string(),
477
477
/// nonce: "nonce1".to_string(),
478
478
/// pkce_verifier: "verifier1".to_string(),
479
479
/// signing_public_key: "pubkey1".to_string(),
···
488
488
/// let req2 = OAuthRequest {
489
489
/// oauth_state: "state2".to_string(),
490
490
/// issuer: "https://pds.example.com".to_string(),
491
-
/// did: "did:plc:user2".to_string(),
491
+
/// authorization_server: "https://pds.example.com".to_string(),
492
492
/// nonce: "nonce2".to_string(),
493
493
/// pkce_verifier: "verifier2".to_string(),
494
494
/// signing_public_key: "pubkey2".to_string(),
···
503
503
/// let req3 = OAuthRequest {
504
504
/// oauth_state: "state3".to_string(),
505
505
/// issuer: "https://pds.example.com".to_string(),
506
-
/// did: "did:plc:user3".to_string(),
506
+
/// authorization_server: "https://pds.example.com".to_string(),
507
507
/// nonce: "nonce3".to_string(),
508
508
/// pkce_verifier: "verifier3".to_string(),
509
509
/// signing_public_key: "pubkey3".to_string(),
···
571
571
/// let expired_request = OAuthRequest {
572
572
/// oauth_state: "soon-expired".to_string(),
573
573
/// issuer: "https://pds.example.com".to_string(),
574
-
/// did: "did:plc:user1".to_string(),
574
+
/// authorization_server: "https://pds.example.com".to_string(),
575
575
/// nonce: "nonce1".to_string(),
576
576
/// pkce_verifier: "verifier1".to_string(),
577
577
/// signing_public_key: "pubkey1".to_string(),
···
585
585
/// let valid_request = OAuthRequest {
586
586
/// oauth_state: "still-valid".to_string(),
587
587
/// issuer: "https://pds.example.com".to_string(),
588
-
/// did: "did:plc:user2".to_string(),
588
+
/// authorization_server: "https://pds.example.com".to_string(),
589
589
/// nonce: "nonce2".to_string(),
590
590
/// pkce_verifier: "verifier2".to_string(),
591
591
/// signing_public_key: "pubkey2".to_string(),
···
645
645
use chrono::{Duration, Utc};
646
646
use std::num::NonZeroUsize;
647
647
648
-
fn create_test_oauth_request(state: &str, issuer: &str, did: &str) -> OAuthRequest {
648
+
fn create_test_oauth_request(state: &str, issuer: &str, _did: &str) -> OAuthRequest {
649
649
OAuthRequest {
650
650
oauth_state: state.to_string(),
651
651
issuer: issuer.to_string(),
652
-
did: did.to_string(),
652
+
authorization_server: issuer.to_string(),
653
653
nonce: format!("nonce-{}", state),
654
654
pkce_verifier: format!("verifier-{}", state),
655
655
signing_public_key: format!("pubkey-{}", state),
···
659
659
}
660
660
}
661
661
662
-
fn create_expired_oauth_request(state: &str, issuer: &str, did: &str) -> OAuthRequest {
662
+
fn create_expired_oauth_request(state: &str, issuer: &str, _did: &str) -> OAuthRequest {
663
663
OAuthRequest {
664
664
oauth_state: state.to_string(),
665
665
issuer: issuer.to_string(),
666
-
did: did.to_string(),
666
+
authorization_server: issuer.to_string(),
667
667
nonce: format!("nonce-{}", state),
668
668
pkce_verifier: format!("verifier-{}", state),
669
669
signing_public_key: format!("pubkey-{}", state),
···
764
764
// req1 and req3 should be present, req2 should be evicted
765
765
let result1 = storage.get_oauth_request_by_state("state1").await?;
766
766
assert!(result1.is_some());
767
-
assert_eq!(result1.unwrap().did, req1.did);
767
+
assert_eq!(result1.unwrap().oauth_state, req1.oauth_state);
768
768
769
769
let result3 = storage.get_oauth_request_by_state("state3").await?;
770
770
assert!(result3.is_some());
771
-
assert_eq!(result3.unwrap().did, req3.did);
771
+
assert_eq!(result3.unwrap().oauth_state, req3.oauth_state);
772
772
773
773
assert_eq!(storage.get_oauth_request_by_state("state2").await?, None);
774
774
+9
-9
crates/atproto-oauth/src/workflow.rs
+9
-9
crates/atproto-oauth/src/workflow.rs
···
158
158
#[cfg_attr(feature = "zeroize", zeroize(skip))]
159
159
pub issuer: String,
160
160
161
-
/// The DID (Decentralized Identifier) of the user.
161
+
/// The authorization server identifier.
162
162
#[cfg_attr(feature = "zeroize", zeroize(skip))]
163
-
pub did: String,
163
+
pub authorization_server: String,
164
164
165
165
/// The nonce value for additional security.
166
166
#[cfg_attr(feature = "zeroize", zeroize(skip))]
···
191
191
f.debug_struct("OAuthRequest")
192
192
.field("oauth_state", &self.oauth_state)
193
193
.field("issuer", &self.issuer)
194
-
.field("did", &self.did)
194
+
.field("authorization_server", &self.authorization_server)
195
195
.field("nonce", &self.nonce)
196
196
.field("pkce_verifier", &"[REDACTED]")
197
197
.field("signing_public_key", &self.signing_public_key)
···
356
356
dpop_key_data: &KeyData,
357
357
callback_code: &str,
358
358
oauth_request: &OAuthRequest,
359
-
document: &atproto_identity::model::Document,
359
+
authorization_server: &AuthorizationServer,
360
360
) -> Result<TokenResponse, OAuthClientError> {
361
-
let pds_endpoints = document.pds_endpoints();
362
-
let pds_endpoint = pds_endpoints
363
-
.first()
364
-
.ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
365
-
let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
361
+
// let pds_endpoints = document.pds_endpoints();
362
+
// let pds_endpoint = pds_endpoints
363
+
// .first()
364
+
// .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
365
+
// let (_, authorization_server) = pds_resources(http_client, pds_endpoint).await?;
366
366
367
367
let client_assertion_header: Header = (oauth_client.private_signing_key_data.clone())
368
368
.try_into()