Built for people who think better out loud.
at main 457 lines 14 kB view raw
1use atproto_identity::key::{KeyData, KeyType, generate_key, identify_key, to_public}; 2use atproto_identity::traits::{DidDocumentStorage, IdentityResolver, KeyResolver}; 3use atproto_oauth::pkce; 4use atproto_oauth::resources::pds_resources; 5use atproto_oauth::storage::OAuthRequestStorage; 6use atproto_oauth::workflow::{OAuthClient, OAuthRequest, OAuthRequestState, oauth_complete, oauth_init}; 7use atproto_oauth_axum::errors::OAuthCallbackError; 8use atproto_oauth_axum::handle_complete::OAuthCallbackForm; 9use atproto_oauth_axum::handle_jwks::handle_oauth_jwks; 10use atproto_oauth_axum::handler_metadata::handle_oauth_metadata; 11use atproto_oauth_axum::state::OAuthClientConfig; 12use axum::extract::{Query, State}; 13use axum::http::{HeaderMap, HeaderValue, StatusCode}; 14use axum::response::{IntoResponse, Redirect}; 15use axum::routing::{get, post}; 16use axum::{Json, Router}; 17use anyhow::anyhow; 18use chrono::{Duration, Utc}; 19use rand::distr::{Alphanumeric, SampleString}; 20use serde::{Deserialize, Serialize}; 21use sqlx::PgPool; 22use std::sync::Arc; 23use uuid::Uuid; 24 25use crate::application::oauth::{ 26 OAuthPolicyError, 27 ensure_did_matches_authorization_server, 28 ensure_granted_scopes_valid, 29 scope_contains_atproto, 30}; 31use crate::config::Config; 32use crate::state::AppState; 33use crate::http::cookies::read_session_cookie; 34use crate::http::AuthenticatedUser; 35use crate::infrastructure::db::auth as auth_db; 36 37#[derive(Serialize)] 38struct ErrorResponse { 39 error: ErrorDetail, 40} 41 42#[derive(Serialize)] 43struct ErrorDetail { 44 message: String, 45} 46 47#[derive(Deserialize)] 48struct StartParams { 49 subject: Option<String>, 50} 51 52#[derive(Serialize)] 53struct SessionUserResponse { 54 did: String, 55 handle: Option<String>, 56} 57 58pub fn router() -> Router<AppState> { 59 Router::<AppState>::new() 60 .route("/oauth/client-metadata.json", get(handle_oauth_metadata)) 61 .route("/.well-known/jwks.json", get(handle_oauth_jwks)) 62 .route("/oauth/callback", get(oauth_callback)) 63 .route("/api/auth/atproto/start", get(start_oauth)) 64 .route("/api/auth/me", get(current_user)) 65 .route("/api/auth/logout", post(logout)) 66} 67 68async fn start_oauth( 69 State(oauth_request_storage): State<Arc<dyn OAuthRequestStorage>>, 70 State(did_document_storage): State<Arc<dyn DidDocumentStorage>>, 71 State(identity_resolver): State<Arc<dyn IdentityResolver>>, 72 State(oauth_client_config): State<OAuthClientConfig>, 73 State(http_client): State<reqwest::Client>, 74 State(oauth_signing_key): State<KeyData>, 75 Query(params): Query<StartParams>, 76) -> Result<Redirect, (StatusCode, Json<ErrorResponse>)> { 77 if !scope_contains_atproto(oauth_client_config.scope()) { 78 return Err(error_response( 79 StatusCode::INTERNAL_SERVER_ERROR, 80 "OAuth scopes must include atproto.", 81 )); 82 } 83 84 let subject = params 85 .subject 86 .as_deref() 87 .filter(|value| !value.trim().is_empty()) 88 .ok_or_else(|| error_response(StatusCode::BAD_REQUEST, "Missing subject."))?; 89 90 let document = identity_resolver 91 .resolve(subject) 92 .await 93 .map_err(|err| { 94 error_response( 95 StatusCode::BAD_REQUEST, 96 &format!("Failed to resolve subject. {err}"), 97 ) 98 })?; 99 100 did_document_storage 101 .store_document(document.clone()) 102 .await 103 .map_err(|err| { 104 error_response( 105 StatusCode::INTERNAL_SERVER_ERROR, 106 &format!("Failed to store DID document. {err}"), 107 ) 108 })?; 109 110 let pds_endpoint = document.pds_endpoints().first().cloned().ok_or_else(|| { 111 error_response( 112 StatusCode::BAD_REQUEST, 113 "No PDS endpoint found for subject.", 114 ) 115 })?; 116 117 let (_, authorization_server) = pds_resources(&http_client, &pds_endpoint) 118 .await 119 .map_err(|err| { 120 error_response( 121 StatusCode::BAD_GATEWAY, 122 &format!("Failed to discover authorization server. {err}"), 123 ) 124 })?; 125 126 let (pkce_verifier, code_challenge) = pkce::generate(); 127 let (oauth_state, nonce) = { 128 let mut rng = rand::rng(); 129 ( 130 Alphanumeric.sample_string(&mut rng, 32), 131 Alphanumeric.sample_string(&mut rng, 32), 132 ) 133 }; 134 let dpop_key = generate_key(KeyType::P256Private).map_err(|err| { 135 error_response( 136 StatusCode::INTERNAL_SERVER_ERROR, 137 &format!("Failed to generate DPoP key. {err}"), 138 ) 139 })?; 140 141 let oauth_client = OAuthClient { 142 redirect_uri: oauth_client_config.redirect_uris.clone(), 143 client_id: oauth_client_config.client_id.clone(), 144 private_signing_key_data: oauth_signing_key.clone(), 145 }; 146 147 let oauth_request_state = OAuthRequestState { 148 state: oauth_state.clone(), 149 nonce: nonce.clone(), 150 code_challenge, 151 scope: oauth_client_config.scope().to_string(), 152 }; 153 154 let par_response = oauth_init( 155 &http_client, 156 &oauth_client, 157 &dpop_key, 158 Some(subject), 159 &authorization_server, 160 &oauth_request_state, 161 ) 162 .await 163 .map_err(|err| { 164 error_response( 165 StatusCode::BAD_GATEWAY, 166 &format!("Failed to initialize OAuth flow. {err}"), 167 ) 168 })?; 169 170 let public_signing_key = to_public(&oauth_signing_key).map_err(|err| { 171 error_response( 172 StatusCode::INTERNAL_SERVER_ERROR, 173 &format!("Failed to derive public signing key. {err}"), 174 ) 175 })?; 176 177 let now = Utc::now(); 178 let oauth_request = OAuthRequest { 179 oauth_state: oauth_state.clone(), 180 issuer: authorization_server.issuer.clone(), 181 authorization_server: pds_endpoint.to_string(), 182 nonce, 183 pkce_verifier, 184 signing_public_key: public_signing_key.to_string(), 185 dpop_private_key: dpop_key.to_string(), 186 created_at: now, 187 expires_at: now + Duration::hours(1), 188 }; 189 190 oauth_request_storage 191 .insert_oauth_request(oauth_request) 192 .await 193 .map_err(|err| { 194 error_response( 195 StatusCode::INTERNAL_SERVER_ERROR, 196 &format!("Failed to store OAuth request. {err}"), 197 ) 198 })?; 199 200 let auth_url = format!( 201 "{}?client_id={}&request_uri={}", 202 authorization_server.authorization_endpoint, 203 oauth_client.client_id, 204 par_response.request_uri 205 ); 206 207 Ok(Redirect::to(&auth_url)) 208} 209 210async fn oauth_callback( 211 State(oauth_request_storage): State<Arc<dyn OAuthRequestStorage>>, 212 State(did_document_storage): State<Arc<dyn DidDocumentStorage>>, 213 State(identity_resolver): State<Arc<dyn IdentityResolver>>, 214 State(oauth_client_config): State<OAuthClientConfig>, 215 State(key_resolver): State<Arc<dyn KeyResolver>>, 216 State(http_client): State<reqwest::Client>, 217 State(db_pool): State<PgPool>, 218 State(config): State<Config>, 219 Query(callback_form): Query<OAuthCallbackForm>, 220) -> Result<impl IntoResponse, OAuthCallbackError> { 221 let oauth_request = oauth_request_storage 222 .get_oauth_request_by_state(&callback_form.state) 223 .await? 224 .ok_or(OAuthCallbackError::NoOAuthRequestFound)?; 225 226 if oauth_request.issuer != callback_form.iss { 227 return Err(OAuthCallbackError::InvalidIssuer { 228 expected: oauth_request.issuer.clone(), 229 actual: callback_form.iss.clone(), 230 }); 231 } 232 233 let private_signing_key_data = key_resolver 234 .resolve(&oauth_request.signing_public_key) 235 .await 236 .map_err(|_| OAuthCallbackError::NoSigningKeyFound)?; 237 238 let private_dpop_key_data = identify_key(&oauth_request.dpop_private_key)?; 239 240 let oauth_client = OAuthClient { 241 redirect_uri: oauth_client_config.redirect_uris.clone(), 242 client_id: oauth_client_config.client_id.clone(), 243 private_signing_key_data, 244 }; 245 246 let (_, authorization_server) = 247 pds_resources(&http_client, &oauth_request.authorization_server).await?; 248 249 let token_response = oauth_complete( 250 &http_client, 251 &oauth_client, 252 &private_dpop_key_data, 253 &callback_form.code, 254 &oauth_request, 255 &authorization_server, 256 ) 257 .await?; 258 259 ensure_granted_scopes_valid( 260 oauth_client_config.scope(), 261 &token_response.scope, 262 ) 263 .map_err(map_oauth_policy_error)?; 264 265 let did = token_response 266 .sub 267 .clone() 268 .ok_or(OAuthCallbackError::NoDIDDocumentFound)?; 269 270 let document = identity_resolver.resolve(&did).await?; 271 did_document_storage.store_document(document.clone()).await?; 272 273 let did_pds_endpoint = document.pds_endpoints().first().cloned().ok_or( 274 OAuthCallbackError::OperationFailed { 275 error: anyhow!("No PDS endpoint found for token DID.").into(), 276 }, 277 )?; 278 let (_, did_auth_server) = 279 pds_resources(&http_client, &did_pds_endpoint).await?; 280 ensure_did_matches_authorization_server( 281 &oauth_request.issuer, 282 &oauth_request.authorization_server, 283 &did_pds_endpoint, 284 &did_auth_server.issuer, 285 ) 286 .map_err(map_oauth_policy_error)?; 287 288 let handle = document 289 .also_known_as 290 .iter() 291 .find_map(|item| item.strip_prefix("at://")) 292 .map(|value| value.to_string()); 293 294 auth_db::upsert_user(&db_pool, &did, handle.as_deref()) 295 .await 296 .map_err(|err| OAuthCallbackError::OperationFailed { error: err.into() })?; 297 auth_db::store_tokens( 298 &db_pool, 299 &did, 300 &token_response.access_token, 301 token_response.refresh_token.as_deref(), 302 &token_response.token_type, 303 &token_response.scope, 304 token_response.expires_in, 305 &oauth_request.issuer, 306 &oauth_request.dpop_private_key, 307 ) 308 .await 309 .map_err(|err| OAuthCallbackError::OperationFailed { error: err.into() })?; 310 311 oauth_request_storage 312 .delete_oauth_request_by_state(&oauth_request.oauth_state) 313 .await?; 314 315 let session_id = Uuid::new_v4(); 316 let session_expires_at = 317 Utc::now() + Duration::seconds(config.oauth_session_ttl_seconds); 318 auth_db::create_session(&db_pool, session_id, &did, session_expires_at) 319 .await 320 .map_err(|err| OAuthCallbackError::OperationFailed { error: err.into() })?; 321 322 let mut headers = HeaderMap::new(); 323 headers.insert( 324 axum::http::header::SET_COOKIE, 325 build_session_cookie( 326 &config.oauth_cookie_name, 327 session_id, 328 config.oauth_session_ttl_seconds, 329 config.slipnote_env == "prod", 330 )?, 331 ); 332 333 Ok((headers, Redirect::to(&config.oauth_post_auth_redirect))) 334} 335 336async fn current_user( 337 AuthenticatedUser(user): AuthenticatedUser, 338) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> { 339 Ok(Json(SessionUserResponse { 340 did: user.did, 341 handle: user.handle, 342 })) 343} 344 345async fn logout( 346 State(db_pool): State<PgPool>, 347 State(config): State<Config>, 348 headers: HeaderMap, 349) -> impl IntoResponse { 350 if let Some(session_id) = read_session_cookie(&headers, &config.oauth_cookie_name) { 351 let _ = auth_db::delete_session(&db_pool, session_id).await; 352 } 353 354 let mut response_headers = HeaderMap::new(); 355 if let Ok(cookie) = clear_session_cookie( 356 &config.oauth_cookie_name, 357 config.slipnote_env == "prod", 358 ) { 359 response_headers.insert(axum::http::header::SET_COOKIE, cookie); 360 } 361 362 (response_headers, StatusCode::NO_CONTENT) 363} 364 365fn error_response( 366 status: StatusCode, 367 message: &str, 368) -> (StatusCode, Json<ErrorResponse>) { 369 ( 370 status, 371 Json(ErrorResponse { 372 error: ErrorDetail { 373 message: message.to_string(), 374 }, 375 }), 376 ) 377} 378 379fn build_session_cookie( 380 name: &str, 381 session_id: Uuid, 382 max_age_seconds: i64, 383 secure: bool, 384) -> Result<HeaderValue, OAuthCallbackError> { 385 let mut cookie = format!( 386 "{}={}; Max-Age={}; Path=/; HttpOnly; SameSite=Lax", 387 name, session_id, max_age_seconds 388 ); 389 if secure { 390 cookie.push_str("; Secure"); 391 } 392 HeaderValue::from_str(&cookie).map_err(|err| { 393 OAuthCallbackError::OperationFailed { 394 error: err.into(), 395 } 396 }) 397} 398 399fn clear_session_cookie(name: &str, secure: bool) -> Result<HeaderValue, OAuthCallbackError> { 400 let mut cookie = format!( 401 "{}=; Max-Age=0; Path=/; HttpOnly; SameSite=Lax", 402 name 403 ); 404 if secure { 405 cookie.push_str("; Secure"); 406 } 407 HeaderValue::from_str(&cookie).map_err(|err| { 408 OAuthCallbackError::OperationFailed { 409 error: err.into(), 410 } 411 }) 412} 413 414 415fn map_oauth_policy_error(err: OAuthPolicyError) -> OAuthCallbackError { 416 match err { 417 OAuthPolicyError::MissingAtprotoScope => OAuthCallbackError::OperationFailed { 418 error: anyhow!("OAuth scopes must include atproto.").into(), 419 }, 420 OAuthPolicyError::ScopeNotSubset => OAuthCallbackError::OperationFailed { 421 error: anyhow!("OAuth scopes must be a subset of requested scopes.").into(), 422 }, 423 OAuthPolicyError::DidPdsMismatch => OAuthCallbackError::OperationFailed { 424 error: anyhow!("Token DID does not match PDS used for authorization.").into(), 425 }, 426 OAuthPolicyError::IssuerMismatch { expected, actual } => { 427 OAuthCallbackError::InvalidIssuer { expected, actual } 428 } 429 } 430} 431 432#[cfg(test)] 433mod tests { 434 use super::*; 435 use proptest::prelude::*; 436 437 proptest! { 438 #[test] 439 fn build_session_cookie_sets_expected_flags( 440 name in "[a-zA-Z0-9_]{1,24}", 441 max_age in 1i64..86400i64, 442 uuid in any::<[u8; 16]>() 443 ) { 444 let session_id = Uuid::from_bytes(uuid); 445 let header_value = build_session_cookie(&name, session_id, max_age, true).unwrap(); 446 let header_str = header_value.to_str().unwrap(); 447 prop_assert!(header_str.contains("HttpOnly")); 448 prop_assert!(header_str.contains("SameSite=Lax")); 449 prop_assert!(header_str.contains("Secure")); 450 prop_assert!( 451 header_str.contains(&format!("{name}={session_id}")), 452 "session cookie missing name/value" 453 ); 454 } 455 } 456 457}