Built for people who think better out loud.
1use atproto_identity::key::{KeyData, KeyType, generate_key, identify_key, to_public};
2use atproto_identity::traits::{DidDocumentStorage, IdentityResolver, KeyResolver};
3use atproto_oauth::pkce;
4use atproto_oauth::resources::pds_resources;
5use atproto_oauth::storage::OAuthRequestStorage;
6use atproto_oauth::workflow::{OAuthClient, OAuthRequest, OAuthRequestState, oauth_complete, oauth_init};
7use atproto_oauth_axum::errors::OAuthCallbackError;
8use atproto_oauth_axum::handle_complete::OAuthCallbackForm;
9use atproto_oauth_axum::handle_jwks::handle_oauth_jwks;
10use atproto_oauth_axum::handler_metadata::handle_oauth_metadata;
11use atproto_oauth_axum::state::OAuthClientConfig;
12use axum::extract::{Query, State};
13use axum::http::{HeaderMap, HeaderValue, StatusCode};
14use axum::response::{IntoResponse, Redirect};
15use axum::routing::{get, post};
16use axum::{Json, Router};
17use anyhow::anyhow;
18use chrono::{Duration, Utc};
19use rand::distr::{Alphanumeric, SampleString};
20use serde::{Deserialize, Serialize};
21use sqlx::PgPool;
22use std::sync::Arc;
23use uuid::Uuid;
24
25use crate::application::oauth::{
26 OAuthPolicyError,
27 ensure_did_matches_authorization_server,
28 ensure_granted_scopes_valid,
29 scope_contains_atproto,
30};
31use crate::config::Config;
32use crate::state::AppState;
33use crate::http::cookies::read_session_cookie;
34use crate::http::AuthenticatedUser;
35use crate::infrastructure::db::auth as auth_db;
36
37#[derive(Serialize)]
38struct ErrorResponse {
39 error: ErrorDetail,
40}
41
42#[derive(Serialize)]
43struct ErrorDetail {
44 message: String,
45}
46
47#[derive(Deserialize)]
48struct StartParams {
49 subject: Option<String>,
50}
51
52#[derive(Serialize)]
53struct SessionUserResponse {
54 did: String,
55 handle: Option<String>,
56}
57
58pub fn router() -> Router<AppState> {
59 Router::<AppState>::new()
60 .route("/oauth/client-metadata.json", get(handle_oauth_metadata))
61 .route("/.well-known/jwks.json", get(handle_oauth_jwks))
62 .route("/oauth/callback", get(oauth_callback))
63 .route("/api/auth/atproto/start", get(start_oauth))
64 .route("/api/auth/me", get(current_user))
65 .route("/api/auth/logout", post(logout))
66}
67
68async fn start_oauth(
69 State(oauth_request_storage): State<Arc<dyn OAuthRequestStorage>>,
70 State(did_document_storage): State<Arc<dyn DidDocumentStorage>>,
71 State(identity_resolver): State<Arc<dyn IdentityResolver>>,
72 State(oauth_client_config): State<OAuthClientConfig>,
73 State(http_client): State<reqwest::Client>,
74 State(oauth_signing_key): State<KeyData>,
75 Query(params): Query<StartParams>,
76) -> Result<Redirect, (StatusCode, Json<ErrorResponse>)> {
77 if !scope_contains_atproto(oauth_client_config.scope()) {
78 return Err(error_response(
79 StatusCode::INTERNAL_SERVER_ERROR,
80 "OAuth scopes must include atproto.",
81 ));
82 }
83
84 let subject = params
85 .subject
86 .as_deref()
87 .filter(|value| !value.trim().is_empty())
88 .ok_or_else(|| error_response(StatusCode::BAD_REQUEST, "Missing subject."))?;
89
90 let document = identity_resolver
91 .resolve(subject)
92 .await
93 .map_err(|err| {
94 error_response(
95 StatusCode::BAD_REQUEST,
96 &format!("Failed to resolve subject. {err}"),
97 )
98 })?;
99
100 did_document_storage
101 .store_document(document.clone())
102 .await
103 .map_err(|err| {
104 error_response(
105 StatusCode::INTERNAL_SERVER_ERROR,
106 &format!("Failed to store DID document. {err}"),
107 )
108 })?;
109
110 let pds_endpoint = document.pds_endpoints().first().cloned().ok_or_else(|| {
111 error_response(
112 StatusCode::BAD_REQUEST,
113 "No PDS endpoint found for subject.",
114 )
115 })?;
116
117 let (_, authorization_server) = pds_resources(&http_client, &pds_endpoint)
118 .await
119 .map_err(|err| {
120 error_response(
121 StatusCode::BAD_GATEWAY,
122 &format!("Failed to discover authorization server. {err}"),
123 )
124 })?;
125
126 let (pkce_verifier, code_challenge) = pkce::generate();
127 let (oauth_state, nonce) = {
128 let mut rng = rand::rng();
129 (
130 Alphanumeric.sample_string(&mut rng, 32),
131 Alphanumeric.sample_string(&mut rng, 32),
132 )
133 };
134 let dpop_key = generate_key(KeyType::P256Private).map_err(|err| {
135 error_response(
136 StatusCode::INTERNAL_SERVER_ERROR,
137 &format!("Failed to generate DPoP key. {err}"),
138 )
139 })?;
140
141 let oauth_client = OAuthClient {
142 redirect_uri: oauth_client_config.redirect_uris.clone(),
143 client_id: oauth_client_config.client_id.clone(),
144 private_signing_key_data: oauth_signing_key.clone(),
145 };
146
147 let oauth_request_state = OAuthRequestState {
148 state: oauth_state.clone(),
149 nonce: nonce.clone(),
150 code_challenge,
151 scope: oauth_client_config.scope().to_string(),
152 };
153
154 let par_response = oauth_init(
155 &http_client,
156 &oauth_client,
157 &dpop_key,
158 Some(subject),
159 &authorization_server,
160 &oauth_request_state,
161 )
162 .await
163 .map_err(|err| {
164 error_response(
165 StatusCode::BAD_GATEWAY,
166 &format!("Failed to initialize OAuth flow. {err}"),
167 )
168 })?;
169
170 let public_signing_key = to_public(&oauth_signing_key).map_err(|err| {
171 error_response(
172 StatusCode::INTERNAL_SERVER_ERROR,
173 &format!("Failed to derive public signing key. {err}"),
174 )
175 })?;
176
177 let now = Utc::now();
178 let oauth_request = OAuthRequest {
179 oauth_state: oauth_state.clone(),
180 issuer: authorization_server.issuer.clone(),
181 authorization_server: pds_endpoint.to_string(),
182 nonce,
183 pkce_verifier,
184 signing_public_key: public_signing_key.to_string(),
185 dpop_private_key: dpop_key.to_string(),
186 created_at: now,
187 expires_at: now + Duration::hours(1),
188 };
189
190 oauth_request_storage
191 .insert_oauth_request(oauth_request)
192 .await
193 .map_err(|err| {
194 error_response(
195 StatusCode::INTERNAL_SERVER_ERROR,
196 &format!("Failed to store OAuth request. {err}"),
197 )
198 })?;
199
200 let auth_url = format!(
201 "{}?client_id={}&request_uri={}",
202 authorization_server.authorization_endpoint,
203 oauth_client.client_id,
204 par_response.request_uri
205 );
206
207 Ok(Redirect::to(&auth_url))
208}
209
210async fn oauth_callback(
211 State(oauth_request_storage): State<Arc<dyn OAuthRequestStorage>>,
212 State(did_document_storage): State<Arc<dyn DidDocumentStorage>>,
213 State(identity_resolver): State<Arc<dyn IdentityResolver>>,
214 State(oauth_client_config): State<OAuthClientConfig>,
215 State(key_resolver): State<Arc<dyn KeyResolver>>,
216 State(http_client): State<reqwest::Client>,
217 State(db_pool): State<PgPool>,
218 State(config): State<Config>,
219 Query(callback_form): Query<OAuthCallbackForm>,
220) -> Result<impl IntoResponse, OAuthCallbackError> {
221 let oauth_request = oauth_request_storage
222 .get_oauth_request_by_state(&callback_form.state)
223 .await?
224 .ok_or(OAuthCallbackError::NoOAuthRequestFound)?;
225
226 if oauth_request.issuer != callback_form.iss {
227 return Err(OAuthCallbackError::InvalidIssuer {
228 expected: oauth_request.issuer.clone(),
229 actual: callback_form.iss.clone(),
230 });
231 }
232
233 let private_signing_key_data = key_resolver
234 .resolve(&oauth_request.signing_public_key)
235 .await
236 .map_err(|_| OAuthCallbackError::NoSigningKeyFound)?;
237
238 let private_dpop_key_data = identify_key(&oauth_request.dpop_private_key)?;
239
240 let oauth_client = OAuthClient {
241 redirect_uri: oauth_client_config.redirect_uris.clone(),
242 client_id: oauth_client_config.client_id.clone(),
243 private_signing_key_data,
244 };
245
246 let (_, authorization_server) =
247 pds_resources(&http_client, &oauth_request.authorization_server).await?;
248
249 let token_response = oauth_complete(
250 &http_client,
251 &oauth_client,
252 &private_dpop_key_data,
253 &callback_form.code,
254 &oauth_request,
255 &authorization_server,
256 )
257 .await?;
258
259 ensure_granted_scopes_valid(
260 oauth_client_config.scope(),
261 &token_response.scope,
262 )
263 .map_err(map_oauth_policy_error)?;
264
265 let did = token_response
266 .sub
267 .clone()
268 .ok_or(OAuthCallbackError::NoDIDDocumentFound)?;
269
270 let document = identity_resolver.resolve(&did).await?;
271 did_document_storage.store_document(document.clone()).await?;
272
273 let did_pds_endpoint = document.pds_endpoints().first().cloned().ok_or(
274 OAuthCallbackError::OperationFailed {
275 error: anyhow!("No PDS endpoint found for token DID.").into(),
276 },
277 )?;
278 let (_, did_auth_server) =
279 pds_resources(&http_client, &did_pds_endpoint).await?;
280 ensure_did_matches_authorization_server(
281 &oauth_request.issuer,
282 &oauth_request.authorization_server,
283 &did_pds_endpoint,
284 &did_auth_server.issuer,
285 )
286 .map_err(map_oauth_policy_error)?;
287
288 let handle = document
289 .also_known_as
290 .iter()
291 .find_map(|item| item.strip_prefix("at://"))
292 .map(|value| value.to_string());
293
294 auth_db::upsert_user(&db_pool, &did, handle.as_deref())
295 .await
296 .map_err(|err| OAuthCallbackError::OperationFailed { error: err.into() })?;
297 auth_db::store_tokens(
298 &db_pool,
299 &did,
300 &token_response.access_token,
301 token_response.refresh_token.as_deref(),
302 &token_response.token_type,
303 &token_response.scope,
304 token_response.expires_in,
305 &oauth_request.issuer,
306 &oauth_request.dpop_private_key,
307 )
308 .await
309 .map_err(|err| OAuthCallbackError::OperationFailed { error: err.into() })?;
310
311 oauth_request_storage
312 .delete_oauth_request_by_state(&oauth_request.oauth_state)
313 .await?;
314
315 let session_id = Uuid::new_v4();
316 let session_expires_at =
317 Utc::now() + Duration::seconds(config.oauth_session_ttl_seconds);
318 auth_db::create_session(&db_pool, session_id, &did, session_expires_at)
319 .await
320 .map_err(|err| OAuthCallbackError::OperationFailed { error: err.into() })?;
321
322 let mut headers = HeaderMap::new();
323 headers.insert(
324 axum::http::header::SET_COOKIE,
325 build_session_cookie(
326 &config.oauth_cookie_name,
327 session_id,
328 config.oauth_session_ttl_seconds,
329 config.slipnote_env == "prod",
330 )?,
331 );
332
333 Ok((headers, Redirect::to(&config.oauth_post_auth_redirect)))
334}
335
336async fn current_user(
337 AuthenticatedUser(user): AuthenticatedUser,
338) -> Result<impl IntoResponse, (StatusCode, Json<ErrorResponse>)> {
339 Ok(Json(SessionUserResponse {
340 did: user.did,
341 handle: user.handle,
342 }))
343}
344
345async fn logout(
346 State(db_pool): State<PgPool>,
347 State(config): State<Config>,
348 headers: HeaderMap,
349) -> impl IntoResponse {
350 if let Some(session_id) = read_session_cookie(&headers, &config.oauth_cookie_name) {
351 let _ = auth_db::delete_session(&db_pool, session_id).await;
352 }
353
354 let mut response_headers = HeaderMap::new();
355 if let Ok(cookie) = clear_session_cookie(
356 &config.oauth_cookie_name,
357 config.slipnote_env == "prod",
358 ) {
359 response_headers.insert(axum::http::header::SET_COOKIE, cookie);
360 }
361
362 (response_headers, StatusCode::NO_CONTENT)
363}
364
365fn error_response(
366 status: StatusCode,
367 message: &str,
368) -> (StatusCode, Json<ErrorResponse>) {
369 (
370 status,
371 Json(ErrorResponse {
372 error: ErrorDetail {
373 message: message.to_string(),
374 },
375 }),
376 )
377}
378
379fn build_session_cookie(
380 name: &str,
381 session_id: Uuid,
382 max_age_seconds: i64,
383 secure: bool,
384) -> Result<HeaderValue, OAuthCallbackError> {
385 let mut cookie = format!(
386 "{}={}; Max-Age={}; Path=/; HttpOnly; SameSite=Lax",
387 name, session_id, max_age_seconds
388 );
389 if secure {
390 cookie.push_str("; Secure");
391 }
392 HeaderValue::from_str(&cookie).map_err(|err| {
393 OAuthCallbackError::OperationFailed {
394 error: err.into(),
395 }
396 })
397}
398
399fn clear_session_cookie(name: &str, secure: bool) -> Result<HeaderValue, OAuthCallbackError> {
400 let mut cookie = format!(
401 "{}=; Max-Age=0; Path=/; HttpOnly; SameSite=Lax",
402 name
403 );
404 if secure {
405 cookie.push_str("; Secure");
406 }
407 HeaderValue::from_str(&cookie).map_err(|err| {
408 OAuthCallbackError::OperationFailed {
409 error: err.into(),
410 }
411 })
412}
413
414
415fn map_oauth_policy_error(err: OAuthPolicyError) -> OAuthCallbackError {
416 match err {
417 OAuthPolicyError::MissingAtprotoScope => OAuthCallbackError::OperationFailed {
418 error: anyhow!("OAuth scopes must include atproto.").into(),
419 },
420 OAuthPolicyError::ScopeNotSubset => OAuthCallbackError::OperationFailed {
421 error: anyhow!("OAuth scopes must be a subset of requested scopes.").into(),
422 },
423 OAuthPolicyError::DidPdsMismatch => OAuthCallbackError::OperationFailed {
424 error: anyhow!("Token DID does not match PDS used for authorization.").into(),
425 },
426 OAuthPolicyError::IssuerMismatch { expected, actual } => {
427 OAuthCallbackError::InvalidIssuer { expected, actual }
428 }
429 }
430}
431
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // For any valid cookie name, TTL, and session id, the secure session cookie must carry
        // its hardening attributes and the `name=value` pair.
        #[test]
        fn build_session_cookie_sets_expected_flags(
            cookie_name in "[a-zA-Z0-9_]{1,24}",
            ttl_seconds in 1i64..86400i64,
            raw_id in any::<[u8; 16]>()
        ) {
            let session = Uuid::from_bytes(raw_id);
            let cookie = build_session_cookie(&cookie_name, session, ttl_seconds, true).unwrap();
            let rendered = cookie.to_str().unwrap();

            for attribute in ["HttpOnly", "SameSite=Lax", "Secure"] {
                prop_assert!(rendered.contains(attribute));
            }

            let expected_pair = format!("{cookie_name}={session}");
            prop_assert!(
                rendered.contains(&expected_pair),
                "session cookie missing name/value"
            );
        }
    }
}