use atproto_identity::key::{KeyData, KeyType, generate_key, identify_key, to_public}; use atproto_identity::traits::{DidDocumentStorage, IdentityResolver, KeyResolver}; use atproto_oauth::pkce; use atproto_oauth::resources::pds_resources; use atproto_oauth::storage::OAuthRequestStorage; use atproto_oauth::workflow::{OAuthClient, OAuthRequest, OAuthRequestState, oauth_complete, oauth_init}; use atproto_oauth_axum::errors::OAuthCallbackError; use atproto_oauth_axum::handle_complete::OAuthCallbackForm; use atproto_oauth_axum::handle_jwks::handle_oauth_jwks; use atproto_oauth_axum::handler_metadata::handle_oauth_metadata; use atproto_oauth_axum::state::OAuthClientConfig; use axum::extract::{Query, State}; use axum::http::{HeaderMap, HeaderValue, StatusCode}; use axum::response::{IntoResponse, Redirect}; use axum::routing::{get, post}; use axum::{Json, Router}; use anyhow::anyhow; use chrono::{Duration, Utc}; use rand::distr::{Alphanumeric, SampleString}; use serde::{Deserialize, Serialize}; use sqlx::PgPool; use std::sync::Arc; use uuid::Uuid; use crate::application::oauth::{ OAuthPolicyError, ensure_did_matches_authorization_server, ensure_granted_scopes_valid, scope_contains_atproto, }; use crate::config::Config; use crate::state::AppState; use crate::http::cookies::read_session_cookie; use crate::http::AuthenticatedUser; use crate::infrastructure::db::auth as auth_db; #[derive(Serialize)] struct ErrorResponse { error: ErrorDetail, } #[derive(Serialize)] struct ErrorDetail { message: String, } #[derive(Deserialize)] struct StartParams { subject: Option, } #[derive(Serialize)] struct SessionUserResponse { did: String, handle: Option, } pub fn router() -> Router { Router::::new() .route("/oauth/client-metadata.json", get(handle_oauth_metadata)) .route("/.well-known/jwks.json", get(handle_oauth_jwks)) .route("/oauth/callback", get(oauth_callback)) .route("/api/auth/atproto/start", get(start_oauth)) .route("/api/auth/me", get(current_user)) .route("/api/auth/logout", post(logout)) } async fn start_oauth( State(oauth_request_storage): State>, State(did_document_storage): State>, State(identity_resolver): State>, State(oauth_client_config): State, State(http_client): State, State(oauth_signing_key): State, Query(params): Query, ) -> Result)> { if !scope_contains_atproto(oauth_client_config.scope()) { return Err(error_response( StatusCode::INTERNAL_SERVER_ERROR, "OAuth scopes must include atproto.", )); } let subject = params .subject .as_deref() .filter(|value| !value.trim().is_empty()) .ok_or_else(|| error_response(StatusCode::BAD_REQUEST, "Missing subject."))?; let document = identity_resolver .resolve(subject) .await .map_err(|err| { error_response( StatusCode::BAD_REQUEST, &format!("Failed to resolve subject. {err}"), ) })?; did_document_storage .store_document(document.clone()) .await .map_err(|err| { error_response( StatusCode::INTERNAL_SERVER_ERROR, &format!("Failed to store DID document. {err}"), ) })?; let pds_endpoint = document.pds_endpoints().first().cloned().ok_or_else(|| { error_response( StatusCode::BAD_REQUEST, "No PDS endpoint found for subject.", ) })?; let (_, authorization_server) = pds_resources(&http_client, &pds_endpoint) .await .map_err(|err| { error_response( StatusCode::BAD_GATEWAY, &format!("Failed to discover authorization server. {err}"), ) })?; let (pkce_verifier, code_challenge) = pkce::generate(); let (oauth_state, nonce) = { let mut rng = rand::rng(); ( Alphanumeric.sample_string(&mut rng, 32), Alphanumeric.sample_string(&mut rng, 32), ) }; let dpop_key = generate_key(KeyType::P256Private).map_err(|err| { error_response( StatusCode::INTERNAL_SERVER_ERROR, &format!("Failed to generate DPoP key. {err}"), ) })?; let oauth_client = OAuthClient { redirect_uri: oauth_client_config.redirect_uris.clone(), client_id: oauth_client_config.client_id.clone(), private_signing_key_data: oauth_signing_key.clone(), }; let oauth_request_state = OAuthRequestState { state: oauth_state.clone(), nonce: nonce.clone(), code_challenge, scope: oauth_client_config.scope().to_string(), }; let par_response = oauth_init( &http_client, &oauth_client, &dpop_key, Some(subject), &authorization_server, &oauth_request_state, ) .await .map_err(|err| { error_response( StatusCode::BAD_GATEWAY, &format!("Failed to initialize OAuth flow. {err}"), ) })?; let public_signing_key = to_public(&oauth_signing_key).map_err(|err| { error_response( StatusCode::INTERNAL_SERVER_ERROR, &format!("Failed to derive public signing key. {err}"), ) })?; let now = Utc::now(); let oauth_request = OAuthRequest { oauth_state: oauth_state.clone(), issuer: authorization_server.issuer.clone(), authorization_server: pds_endpoint.to_string(), nonce, pkce_verifier, signing_public_key: public_signing_key.to_string(), dpop_private_key: dpop_key.to_string(), created_at: now, expires_at: now + Duration::hours(1), }; oauth_request_storage .insert_oauth_request(oauth_request) .await .map_err(|err| { error_response( StatusCode::INTERNAL_SERVER_ERROR, &format!("Failed to store OAuth request. {err}"), ) })?; let auth_url = format!( "{}?client_id={}&request_uri={}", authorization_server.authorization_endpoint, oauth_client.client_id, par_response.request_uri ); Ok(Redirect::to(&auth_url)) } async fn oauth_callback( State(oauth_request_storage): State>, State(did_document_storage): State>, State(identity_resolver): State>, State(oauth_client_config): State, State(key_resolver): State>, State(http_client): State, State(db_pool): State, State(config): State, Query(callback_form): Query, ) -> Result { let oauth_request = oauth_request_storage .get_oauth_request_by_state(&callback_form.state) .await? .ok_or(OAuthCallbackError::NoOAuthRequestFound)?; if oauth_request.issuer != callback_form.iss { return Err(OAuthCallbackError::InvalidIssuer { expected: oauth_request.issuer.clone(), actual: callback_form.iss.clone(), }); } let private_signing_key_data = key_resolver .resolve(&oauth_request.signing_public_key) .await .map_err(|_| OAuthCallbackError::NoSigningKeyFound)?; let private_dpop_key_data = identify_key(&oauth_request.dpop_private_key)?; let oauth_client = OAuthClient { redirect_uri: oauth_client_config.redirect_uris.clone(), client_id: oauth_client_config.client_id.clone(), private_signing_key_data, }; let (_, authorization_server) = pds_resources(&http_client, &oauth_request.authorization_server).await?; let token_response = oauth_complete( &http_client, &oauth_client, &private_dpop_key_data, &callback_form.code, &oauth_request, &authorization_server, ) .await?; ensure_granted_scopes_valid( oauth_client_config.scope(), &token_response.scope, ) .map_err(map_oauth_policy_error)?; let did = token_response .sub .clone() .ok_or(OAuthCallbackError::NoDIDDocumentFound)?; let document = identity_resolver.resolve(&did).await?; did_document_storage.store_document(document.clone()).await?; let did_pds_endpoint = document.pds_endpoints().first().cloned().ok_or( OAuthCallbackError::OperationFailed { error: anyhow!("No PDS endpoint found for token DID.").into(), }, )?; let (_, did_auth_server) = pds_resources(&http_client, &did_pds_endpoint).await?; ensure_did_matches_authorization_server( &oauth_request.issuer, &oauth_request.authorization_server, &did_pds_endpoint, &did_auth_server.issuer, ) .map_err(map_oauth_policy_error)?; let handle = document .also_known_as .iter() .find_map(|item| item.strip_prefix("at://")) .map(|value| value.to_string()); auth_db::upsert_user(&db_pool, &did, handle.as_deref()) .await .map_err(|err| OAuthCallbackError::OperationFailed { error: err.into() })?; auth_db::store_tokens( &db_pool, &did, &token_response.access_token, token_response.refresh_token.as_deref(), &token_response.token_type, &token_response.scope, token_response.expires_in, &oauth_request.issuer, &oauth_request.dpop_private_key, ) .await .map_err(|err| OAuthCallbackError::OperationFailed { error: err.into() })?; oauth_request_storage .delete_oauth_request_by_state(&oauth_request.oauth_state) .await?; let session_id = Uuid::new_v4(); let session_expires_at = Utc::now() + Duration::seconds(config.oauth_session_ttl_seconds); auth_db::create_session(&db_pool, session_id, &did, session_expires_at) .await .map_err(|err| OAuthCallbackError::OperationFailed { error: err.into() })?; let mut headers = HeaderMap::new(); headers.insert( axum::http::header::SET_COOKIE, build_session_cookie( &config.oauth_cookie_name, session_id, config.oauth_session_ttl_seconds, config.slipnote_env == "prod", )?, ); Ok((headers, Redirect::to(&config.oauth_post_auth_redirect))) } async fn current_user( AuthenticatedUser(user): AuthenticatedUser, ) -> Result)> { Ok(Json(SessionUserResponse { did: user.did, handle: user.handle, })) } async fn logout( State(db_pool): State, State(config): State, headers: HeaderMap, ) -> impl IntoResponse { if let Some(session_id) = read_session_cookie(&headers, &config.oauth_cookie_name) { let _ = auth_db::delete_session(&db_pool, session_id).await; } let mut response_headers = HeaderMap::new(); if let Ok(cookie) = clear_session_cookie( &config.oauth_cookie_name, config.slipnote_env == "prod", ) { response_headers.insert(axum::http::header::SET_COOKIE, cookie); } (response_headers, StatusCode::NO_CONTENT) } fn error_response( status: StatusCode, message: &str, ) -> (StatusCode, Json) { ( status, Json(ErrorResponse { error: ErrorDetail { message: message.to_string(), }, }), ) } fn build_session_cookie( name: &str, session_id: Uuid, max_age_seconds: i64, secure: bool, ) -> Result { let mut cookie = format!( "{}={}; Max-Age={}; Path=/; HttpOnly; SameSite=Lax", name, session_id, max_age_seconds ); if secure { cookie.push_str("; Secure"); } HeaderValue::from_str(&cookie).map_err(|err| { OAuthCallbackError::OperationFailed { error: err.into(), } }) } fn clear_session_cookie(name: &str, secure: bool) -> Result { let mut cookie = format!( "{}=; Max-Age=0; Path=/; HttpOnly; SameSite=Lax", name ); if secure { cookie.push_str("; Secure"); } HeaderValue::from_str(&cookie).map_err(|err| { OAuthCallbackError::OperationFailed { error: err.into(), } }) } fn map_oauth_policy_error(err: OAuthPolicyError) -> OAuthCallbackError { match err { OAuthPolicyError::MissingAtprotoScope => OAuthCallbackError::OperationFailed { error: anyhow!("OAuth scopes must include atproto.").into(), }, OAuthPolicyError::ScopeNotSubset => OAuthCallbackError::OperationFailed { error: anyhow!("OAuth scopes must be a subset of requested scopes.").into(), }, OAuthPolicyError::DidPdsMismatch => OAuthCallbackError::OperationFailed { error: anyhow!("Token DID does not match PDS used for authorization.").into(), }, OAuthPolicyError::IssuerMismatch { expected, actual } => { OAuthCallbackError::InvalidIssuer { expected, actual } } } } #[cfg(test)] mod tests { use super::*; use proptest::prelude::*; proptest! { #[test] fn build_session_cookie_sets_expected_flags( name in "[a-zA-Z0-9_]{1,24}", max_age in 1i64..86400i64, uuid in any::<[u8; 16]>() ) { let session_id = Uuid::from_bytes(uuid); let header_value = build_session_cookie(&name, session_id, max_age, true).unwrap(); let header_str = header_value.to_str().unwrap(); prop_assert!(header_str.contains("HttpOnly")); prop_assert!(header_str.contains("SameSite=Lax")); prop_assert!(header_str.contains("Secure")); prop_assert!( header_str.contains(&format!("{name}={session_id}")), "session cookie missing name/value" ); } } }