// Forked from smokesignal.events/smokesignal.
// Fork: i18n + search + filtering (v0.2).
1use dpop::DpopRetry;
2use p256::SecretKey;
3use rand::distributions::{Alphanumeric, DistString};
4use reqwest_chain::ChainMiddleware;
5use reqwest_middleware::ClientBuilder;
6use std::time::Duration;
7
8use crate::oauth_client_errors::OAuthClientError;
9use crate::oauth_errors::{AuthServerValidationError, ResourceValidationError};
10use model::{AuthorizationServer, OAuthProtectedResource, ParResponse, TokenResponse};
11
12use crate::{
13 jose::{
14 jwt::{Claims, Header, JoseClaims},
15 mint_token,
16 },
17 storage::{
18 handle::model::Handle,
19 oauth::model::{OAuthRequest, OAuthRequestState},
20 },
21};
22
23const HTTP_CLIENT_TIMEOUT_SECS: u64 = 8;
24
25pub async fn pds_resources(
26 http_client: &reqwest::Client,
27 pds: &str,
28) -> Result<(OAuthProtectedResource, AuthorizationServer), OAuthClientError> {
29 let protected_resource = oauth_protected_resource(http_client, pds).await?;
30
31 let first_authorization_server = protected_resource
32 .authorization_servers
33 .first()
34 .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
35
36 let authorization_server =
37 oauth_authorization_server(http_client, first_authorization_server).await?;
38 Ok((protected_resource, authorization_server))
39}
40
41pub async fn oauth_protected_resource(
42 http_client: &reqwest::Client,
43 pds: &str,
44) -> Result<OAuthProtectedResource, OAuthClientError> {
45 let destination = format!("{}/.well-known/oauth-protected-resource", pds);
46
47 let resource: OAuthProtectedResource = http_client
48 .get(destination)
49 .timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
50 .send()
51 .await
52 .map_err(OAuthClientError::OAuthProtectedResourceRequestFailed)?
53 .json()
54 .await
55 .map_err(OAuthClientError::MalformedOAuthProtectedResourceResponse)?;
56
57 if resource.resource != pds {
58 return Err(OAuthClientError::InvalidOAuthProtectedResourceResponse(
59 ResourceValidationError::ResourceMustMatchPds.into(),
60 ));
61 }
62
63 if resource.authorization_servers.is_empty() {
64 return Err(OAuthClientError::InvalidOAuthProtectedResourceResponse(
65 ResourceValidationError::AuthorizationServersMustNotBeEmpty.into(),
66 ));
67 }
68
69 Ok(resource)
70}
71
72#[tracing::instrument(skip(http_client), err)]
73pub async fn oauth_authorization_server(
74 http_client: &reqwest::Client,
75 pds: &str,
76) -> Result<AuthorizationServer, OAuthClientError> {
77 let destination = format!("{}/.well-known/oauth-authorization-server", pds);
78
79 let resource: AuthorizationServer = http_client
80 .get(destination)
81 .timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
82 .send()
83 .await
84 .map_err(OAuthClientError::AuthorizationServerRequestFailed)?
85 .json()
86 .await
87 .map_err(OAuthClientError::MalformedAuthorizationServerResponse)?;
88
89 // All of this is going to change.
90
91 if resource.issuer != pds {
92 return Err(OAuthClientError::InvalidAuthorizationServerResponse(
93 AuthServerValidationError::IssuerMustMatchPds.into(),
94 ));
95 }
96
97 resource
98 .response_types_supported
99 .iter()
100 .find(|&x| x == "code")
101 .ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
102 AuthServerValidationError::ResponseTypesSupportMustIncludeCode.into(),
103 ))?;
104
105 resource
106 .grant_types_supported
107 .iter()
108 .find(|&x| x == "authorization_code")
109 .ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
110 AuthServerValidationError::GrantTypesSupportMustIncludeAuthorizationCode.into(),
111 ))?;
112 resource
113 .grant_types_supported
114 .iter()
115 .find(|&x| x == "refresh_token")
116 .ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
117 AuthServerValidationError::GrantTypesSupportMustIncludeRefreshToken.into(),
118 ))?;
119 resource
120 .code_challenge_methods_supported
121 .iter()
122 .find(|&x| x == "S256")
123 .ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
124 AuthServerValidationError::CodeChallengeMethodsSupportedMustIncludeS256.into(),
125 ))?;
126 resource
127 .token_endpoint_auth_methods_supported
128 .iter()
129 .find(|&x| x == "none")
130 .ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
131 AuthServerValidationError::TokenEndpointAuthMethodsSupportedMustIncludeNone.into(),
132 ))?;
133 resource
134 .token_endpoint_auth_methods_supported
135 .iter()
136 .find(|&x| x == "private_key_jwt")
137 .ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
138 AuthServerValidationError::TokenEndpointAuthMethodsSupportedMustIncludePrivateKeyJwt
139 .into(),
140 ))?;
141 resource
142 .token_endpoint_auth_signing_alg_values_supported
143 .iter()
144 .find(|&x| x == "ES256")
145 .ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
146 AuthServerValidationError::TokenEndpointAuthSigningAlgValuesMustIncludeES256.into(),
147 ))?;
148 resource
149 .scopes_supported
150 .iter()
151 .find(|&x| x == "atproto")
152 .ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
153 AuthServerValidationError::ScopesSupportedMustIncludeAtProto.into(),
154 ))?;
155 resource
156 .scopes_supported
157 .iter()
158 .find(|&x| x == "transition:generic")
159 .ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
160 AuthServerValidationError::ScopesSupportedMustIncludeTransitionGeneric.into(),
161 ))?;
162 resource
163 .dpop_signing_alg_values_supported
164 .iter()
165 .find(|&x| x == "ES256")
166 .ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
167 AuthServerValidationError::DpopSigningAlgValuesSupportedMustIncludeES256.into(),
168 ))?;
169
170 if !(resource.authorization_response_iss_parameter_supported
171 && resource.require_pushed_authorization_requests
172 && resource.client_id_metadata_document_supported)
173 {
174 return Err(OAuthClientError::InvalidAuthorizationServerResponse(
175 AuthServerValidationError::RequiredServerFeaturesMustBeSupported.into(),
176 ));
177 }
178
179 Ok(resource)
180}
181
182pub async fn oauth_init(
183 http_client: &reqwest::Client,
184 external_url_base: &str,
185 (secret_key_id, secret_key): (&str, SecretKey),
186 dpop_secret_key: &SecretKey,
187 handle: &str,
188 authorization_server: &AuthorizationServer,
189 oauth_request_state: &OAuthRequestState,
190) -> Result<ParResponse, OAuthClientError> {
191 let par_url = authorization_server
192 .pushed_authorization_request_endpoint
193 .clone();
194
195 let redirect_uri = format!("https://{}/oauth/callback", external_url_base);
196 let client_id = format!("https://{}/oauth/client-metadata.json", external_url_base);
197
198 let scope = "atproto transition:generic".to_string();
199
200 let client_assertion_header = Header {
201 algorithm: Some("ES256".to_string()),
202 key_id: Some(secret_key_id.to_string()),
203 ..Default::default()
204 };
205
206 let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
207 let client_assertion_claims = Claims::new(JoseClaims {
208 issuer: Some(client_id.clone()),
209 subject: Some(client_id.clone()),
210 audience: Some(authorization_server.issuer.clone()),
211 json_web_token_id: Some(client_assertion_jti),
212 issued_at: Some(chrono::Utc::now().timestamp() as u64),
213 ..Default::default()
214 });
215 tracing::info!("client_assertion_claims: {:?}", client_assertion_claims);
216
217 let client_assertion_token = mint_token(
218 &secret_key,
219 &client_assertion_header,
220 &client_assertion_claims,
221 )
222 .map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?;
223
224 let now = chrono::Utc::now();
225 let public_key = dpop_secret_key.public_key();
226
227 let dpop_proof_header = Header {
228 type_: Some("dpop+jwt".to_string()),
229 algorithm: Some("ES256".to_string()),
230 json_web_key: Some(public_key.to_jwk()),
231 ..Default::default()
232 };
233 let dpop_proof_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
234
235 let dpop_proof_claim = Claims::new(JoseClaims {
236 json_web_token_id: Some(dpop_proof_jti),
237 http_method: Some("POST".to_string()),
238 http_uri: Some(par_url.clone()),
239 issued_at: Some(now.timestamp() as u64),
240 expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64),
241 ..Default::default()
242 });
243 let dpop_proof_token = mint_token(dpop_secret_key, &dpop_proof_header, &dpop_proof_claim)
244 .map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?;
245
246 let dpop_retry = DpopRetry::new(
247 dpop_proof_header.clone(),
248 dpop_proof_claim.clone(),
249 dpop_secret_key.clone(),
250 );
251
252 let dpop_retry_client = ClientBuilder::new(http_client.clone())
253 .with(ChainMiddleware::new(dpop_retry.clone()))
254 .build();
255
256 let params = [
257 ("response_type", "code"),
258 ("code_challenge", &oauth_request_state.code_challenge),
259 ("code_challenge_method", "S256"),
260 ("client_id", client_id.as_str()),
261 ("state", oauth_request_state.state.as_str()),
262 ("redirect_uri", redirect_uri.as_str()),
263 ("scope", scope.as_str()),
264 ("login_hint", handle),
265 (
266 "client_assertion_type",
267 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
268 ),
269 ("client_assertion", client_assertion_token.as_str()),
270 ];
271
272 tracing::warn!("params: {:?}", params);
273
274 dpop_retry_client
275 .post(par_url)
276 .header("DPoP", dpop_proof_token.as_str())
277 .form(¶ms)
278 .timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
279 .send()
280 .await
281 .map_err(OAuthClientError::PARMiddlewareRequestFailed)?
282 .json()
283 .await
284 .map_err(OAuthClientError::MalformedPARResponse)
285}
286
287pub async fn oauth_complete(
288 http_client: &reqwest::Client,
289 external_url_base: &str,
290 (secret_key_id, secret_key): (&str, SecretKey),
291 callback_code: &str,
292 oauth_request: &OAuthRequest,
293 handle: &Handle,
294 dpop_secret_key: &SecretKey,
295) -> Result<TokenResponse, OAuthClientError> {
296 let (_, authorization_server) = pds_resources(http_client, &handle.pds).await?;
297
298 let client_assertion_header = Header {
299 algorithm: Some("ES256".to_string()),
300 key_id: Some(secret_key_id.to_string()),
301 ..Default::default()
302 };
303
304 let client_id = format!("https://{}/oauth/client-metadata.json", external_url_base);
305 let redirect_uri = format!("https://{}/oauth/callback", external_url_base);
306
307 let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
308 let client_assertion_claims = Claims::new(JoseClaims {
309 issuer: Some(client_id.clone()),
310 subject: Some(client_id.clone()),
311 audience: Some(authorization_server.issuer.clone()),
312 json_web_token_id: Some(client_assertion_jti),
313 issued_at: Some(chrono::Utc::now().timestamp() as u64),
314 ..Default::default()
315 });
316
317 let client_assertion_token = mint_token(
318 &secret_key,
319 &client_assertion_header,
320 &client_assertion_claims,
321 )
322 .map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?;
323
324 let params = [
325 ("client_id", client_id.as_str()),
326 ("redirect_uri", redirect_uri.as_str()),
327 ("grant_type", "authorization_code"),
328 ("code", callback_code),
329 ("code_verifier", &oauth_request.pkce_verifier),
330 (
331 "client_assertion_type",
332 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
333 ),
334 ("client_assertion", client_assertion_token.as_str()),
335 ];
336
337 let public_key = dpop_secret_key.public_key();
338
339 let token_endpoint = authorization_server.token_endpoint.clone();
340
341 let now = chrono::Utc::now();
342
343 let dpop_proof_header = Header {
344 type_: Some("dpop+jwt".to_string()),
345 algorithm: Some("ES256".to_string()),
346 json_web_key: Some(public_key.to_jwk()),
347 ..Default::default()
348 };
349 let dpop_proof_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
350 let dpop_proof_claim = Claims::new(JoseClaims {
351 json_web_token_id: Some(dpop_proof_jti),
352 http_method: Some("POST".to_string()),
353 http_uri: Some(authorization_server.token_endpoint.clone()),
354 issued_at: Some(now.timestamp() as u64),
355 expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64),
356 ..Default::default()
357 });
358 let dpop_proof_token = mint_token(dpop_secret_key, &dpop_proof_header, &dpop_proof_claim)
359 .map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?;
360
361 let dpop_retry = DpopRetry::new(
362 dpop_proof_header.clone(),
363 dpop_proof_claim.clone(),
364 dpop_secret_key.clone(),
365 );
366
367 let dpop_retry_client = ClientBuilder::new(http_client.clone())
368 .with(ChainMiddleware::new(dpop_retry.clone()))
369 .build();
370
371 dpop_retry_client
372 .post(token_endpoint)
373 .header("DPoP", dpop_proof_token.as_str())
374 .form(¶ms)
375 .timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
376 .send()
377 .await
378 .map_err(OAuthClientError::TokenMiddlewareRequestFailed)?
379 .json()
380 .await
381 .map_err(OAuthClientError::MalformedTokenResponse)
382}
383
384pub async fn client_oauth_refresh(
385 http_client: &reqwest::Client,
386 external_url_base: &str,
387 (secret_key_id, secret_key): (&str, SecretKey),
388 refresh_token: &str,
389 handle: &Handle,
390 dpop_secret_key: &SecretKey,
391) -> Result<TokenResponse, OAuthClientError> {
392 let (_, authorization_server) = pds_resources(http_client, &handle.pds).await?;
393
394 let client_assertion_header = Header {
395 algorithm: Some("ES256".to_string()),
396 key_id: Some(secret_key_id.to_string()),
397 ..Default::default()
398 };
399
400 let client_id = format!("https://{}/oauth/client-metadata.json", external_url_base);
401 let redirect_uri = format!("https://{}/oauth/callback", external_url_base);
402
403 let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
404 let client_assertion_claims = Claims::new(JoseClaims {
405 issuer: Some(client_id.clone()),
406 subject: Some(client_id.clone()),
407 audience: Some(authorization_server.issuer.clone()),
408 json_web_token_id: Some(client_assertion_jti),
409 issued_at: Some(chrono::Utc::now().timestamp() as u64),
410 ..Default::default()
411 });
412
413 let client_assertion_token = mint_token(
414 &secret_key,
415 &client_assertion_header,
416 &client_assertion_claims,
417 )
418 .map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?;
419
420 let params = [
421 ("client_id", client_id.as_str()),
422 ("redirect_uri", redirect_uri.as_str()),
423 ("grant_type", "refresh_token"),
424 ("refresh_token", refresh_token),
425 (
426 "client_assertion_type",
427 "urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
428 ),
429 ("client_assertion", client_assertion_token.as_str()),
430 ];
431
432 tracing::info!("params: {:?}", params);
433
434 let public_key = dpop_secret_key.public_key();
435
436 let token_endpoint = authorization_server.token_endpoint.clone();
437
438 let now = chrono::Utc::now();
439
440 let dpop_proof_header = Header {
441 type_: Some("dpop+jwt".to_string()),
442 algorithm: Some("ES256".to_string()),
443 json_web_key: Some(public_key.to_jwk()),
444 ..Default::default()
445 };
446 let dpop_proof_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
447 let dpop_proof_claim = Claims::new(JoseClaims {
448 json_web_token_id: Some(dpop_proof_jti),
449 http_method: Some("POST".to_string()),
450 http_uri: Some(authorization_server.token_endpoint.clone()),
451 issued_at: Some(now.timestamp() as u64),
452 expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64),
453 ..Default::default()
454 });
455 let dpop_proof_token = mint_token(dpop_secret_key, &dpop_proof_header, &dpop_proof_claim)
456 .map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?;
457
458 let dpop_retry = DpopRetry::new(
459 dpop_proof_header.clone(),
460 dpop_proof_claim.clone(),
461 dpop_secret_key.clone(),
462 );
463
464 let dpop_retry_client = ClientBuilder::new(http_client.clone())
465 .with(ChainMiddleware::new(dpop_retry.clone()))
466 .build();
467
468 dpop_retry_client
469 .post(token_endpoint)
470 .header("DPoP", dpop_proof_token.as_str())
471 .form(¶ms)
472 .timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
473 .send()
474 .await
475 .map_err(OAuthClientError::TokenMiddlewareRequestFailed)?
476 .json()
477 .await
478 .map_err(OAuthClientError::MalformedTokenResponse)
479}
480
481pub mod dpop {
482 use p256::SecretKey;
483 use reqwest::header::HeaderValue;
484 use reqwest_chain::Chainer;
485 use serde::Deserialize;
486
487 use crate::{
488 jose::{
489 jwt::{Claims, Header},
490 mint_token,
491 },
492 jose_errors::JoseError,
493 };
494
495 #[derive(Clone, Debug, Deserialize)]
496 pub struct SimpleError {
497 pub error: Option<String>,
498 pub error_description: Option<String>,
499 pub message: Option<String>,
500 }
501
502 impl std::fmt::Display for SimpleError {
503 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504 if let Some(value) = &self.error {
505 write!(f, "{}", value)
506 } else if let Some(value) = &self.message {
507 write!(f, "{}", value)
508 } else if let Some(value) = &self.error_description {
509 write!(f, "{}", value)
510 } else {
511 write!(f, "unknown")
512 }
513 }
514 }
515
516 #[derive(Clone)]
517 pub struct DpopRetry {
518 pub header: Header,
519 pub claims: Claims,
520 pub secret: SecretKey,
521 }
522
523 impl DpopRetry {
524 pub fn new(header: Header, claims: Claims, secret: SecretKey) -> Self {
525 DpopRetry {
526 header,
527 claims,
528 secret,
529 }
530 }
531 }
532
533 #[async_trait::async_trait]
534 impl Chainer for DpopRetry {
535 type State = ();
536
537 async fn chain(
538 &self,
539 result: Result<reqwest::Response, reqwest_middleware::Error>,
540 _state: &mut Self::State,
541 request: &mut reqwest::Request,
542 ) -> Result<Option<reqwest::Response>, reqwest_middleware::Error> {
543 let response = result?;
544
545 let status_code = response.status();
546
547 if status_code != 400 && status_code != 401 {
548 return Ok(Some(response));
549 };
550
551 let headers = response.headers().clone();
552
553 let simple_error = response.json::<SimpleError>().await;
554 if simple_error.is_err() {
555 return Err(reqwest_middleware::Error::Middleware(
556 JoseError::UnableToParseSimpleError.into(),
557 ));
558 }
559
560 let simple_error = simple_error.unwrap();
561
562 tracing::error!("dpop error: {:?}", simple_error);
563
564 let is_use_dpop_nonce_error = simple_error
565 .clone()
566 .error
567 .is_some_and(|error_value| error_value == "use_dpop_nonce");
568
569 if !is_use_dpop_nonce_error {
570 return Err(reqwest_middleware::Error::Middleware(
571 JoseError::UnexpectedError(simple_error.to_string()).into(),
572 ));
573 }
574
575 let dpop_header = headers.get("DPoP-Nonce");
576
577 if dpop_header.is_none() {
578 return Err(reqwest_middleware::Error::Middleware(
579 JoseError::MissingDpopHeader.into(),
580 ));
581 }
582
583 let new_dpop_header = dpop_header.unwrap().to_str().map_err(|dpop_header_err| {
584 reqwest_middleware::Error::Middleware(
585 JoseError::UnableToParseDpopHeader(dpop_header_err.to_string()).into(),
586 )
587 })?;
588
589 let dpop_proof_header = self.header.clone();
590 let mut dpop_proof_claim = self.claims.clone();
591 dpop_proof_claim
592 .private
593 .insert("nonce".to_string(), new_dpop_header.to_string().into());
594
595 let dpop_proof_token = mint_token(&self.secret, &dpop_proof_header, &dpop_proof_claim)
596 .map_err(|dpop_proof_token_err| {
597 reqwest_middleware::Error::Middleware(
598 JoseError::UnableToMintDpopProofToken(dpop_proof_token_err.to_string())
599 .into(),
600 )
601 })?;
602
603 request.headers_mut().insert(
604 "DPoP",
605 HeaderValue::from_str(&dpop_proof_token).expect("invalid header value"),
606 );
607 Ok(None)
608 }
609 }
610}
611
612pub mod model {
613 use serde::Deserialize;
614
615 #[derive(Clone, Deserialize)]
616 pub struct OAuthProtectedResource {
617 pub resource: String,
618 pub authorization_servers: Vec<String>,
619 pub scopes_supported: Vec<String>,
620 pub bearer_methods_supported: Vec<String>,
621 }
622
623 #[derive(Clone, Deserialize, Default, Debug)]
624 pub struct AuthorizationServer {
625 pub introspection_endpoint: String,
626 pub authorization_endpoint: String,
627 pub authorization_response_iss_parameter_supported: bool,
628 pub client_id_metadata_document_supported: bool,
629 pub code_challenge_methods_supported: Vec<String>,
630 pub dpop_signing_alg_values_supported: Vec<String>,
631 pub grant_types_supported: Vec<String>,
632 pub issuer: String,
633 pub pushed_authorization_request_endpoint: String,
634 pub request_parameter_supported: bool,
635 pub require_pushed_authorization_requests: bool,
636 pub response_types_supported: Vec<String>,
637 pub scopes_supported: Vec<String>,
638 pub token_endpoint_auth_methods_supported: Vec<String>,
639 pub token_endpoint_auth_signing_alg_values_supported: Vec<String>,
640 pub token_endpoint: String,
641 }
642
643 #[derive(Clone, Deserialize)]
644 pub struct ParResponse {
645 pub request_uri: String,
646 pub expires_in: u64,
647 }
648
649 #[derive(Clone, Deserialize)]
650 pub struct TokenResponse {
651 pub access_token: String,
652 pub token_type: String,
653 pub refresh_token: String,
654 pub scope: String,
655 pub expires_in: u32,
656 pub sub: String,
657 }
658}
659
660// This errors module is now deprecated.
661// Use crate::oauth_client_errors::OAuthClientError instead.
662pub mod errors {
663 pub use crate::oauth_client_errors::OAuthClientError;
664}