forked from
baileytownsend.dev/pds-gatekeeper
Microservice to bring 2FA to self hosted PDSes
1use crate::AppState;
2use crate::helpers::{
3 AuthResult, ProxiedResult, TokenCheckError, VerifyServiceAuthError, json_error_response,
4 preauth_check, proxy_get_json, verify_gate_token, verify_service_auth,
5};
6use crate::middleware::Did;
7use axum::body::{Body, to_bytes};
8use axum::extract::State;
9use axum::http::{HeaderMap, StatusCode, header};
10use axum::response::{IntoResponse, Response};
11use axum::{Extension, Json, debug_handler, extract, extract::Request};
12use chrono::{Duration, Utc};
13use jacquard_common::types::did::Did as JacquardDid;
14use serde::{Deserialize, Serialize};
15use serde_json;
16use tracing::log;
17
18#[derive(Serialize, Deserialize, Debug, Clone)]
19#[serde(rename_all = "camelCase")]
20enum AccountStatus {
21 Takendown,
22 Suspended,
23 Deactivated,
24}
25
26#[derive(Serialize, Deserialize, Debug, Clone)]
27#[serde(rename_all = "camelCase")]
28struct GetSessionResponse {
29 handle: String,
30 did: String,
31 #[serde(skip_serializing_if = "Option::is_none")]
32 email: Option<String>,
33 #[serde(skip_serializing_if = "Option::is_none")]
34 email_confirmed: Option<bool>,
35 #[serde(skip_serializing_if = "Option::is_none")]
36 email_auth_factor: Option<bool>,
37 #[serde(skip_serializing_if = "Option::is_none")]
38 did_doc: Option<String>,
39 #[serde(skip_serializing_if = "Option::is_none")]
40 active: Option<bool>,
41 #[serde(skip_serializing_if = "Option::is_none")]
42 status: Option<AccountStatus>,
43}
44
45#[derive(Serialize, Deserialize, Debug, Clone)]
46#[serde(rename_all = "camelCase")]
47pub struct UpdateEmailResponse {
48 email: String,
49 #[serde(skip_serializing_if = "Option::is_none")]
50 email_auth_factor: Option<bool>,
51 #[serde(skip_serializing_if = "Option::is_none")]
52 token: Option<String>,
53}
54
55#[allow(dead_code)]
56#[derive(Deserialize, Serialize)]
57#[serde(rename_all = "camelCase")]
58pub struct CreateSessionRequest {
59 identifier: String,
60 password: String,
61 #[serde(skip_serializing_if = "Option::is_none")]
62 auth_factor_token: Option<String>,
63 #[serde(skip_serializing_if = "Option::is_none")]
64 allow_takendown: Option<bool>,
65}
66
67#[derive(Deserialize, Serialize, Debug)]
68#[serde(rename_all = "camelCase")]
69pub struct CreateAccountRequest {
70 handle: String,
71 #[serde(skip_serializing_if = "Option::is_none")]
72 email: Option<String>,
73 #[serde(skip_serializing_if = "Option::is_none")]
74 password: Option<String>,
75 #[serde(skip_serializing_if = "Option::is_none")]
76 did: Option<String>,
77 #[serde(skip_serializing_if = "Option::is_none")]
78 invite_code: Option<String>,
79 #[serde(skip_serializing_if = "Option::is_none")]
80 verification_code: Option<String>,
81 #[serde(skip_serializing_if = "Option::is_none")]
82 plc_op: Option<serde_json::Value>,
83}
84
85#[derive(Deserialize, Serialize, Debug, Clone)]
86#[serde(rename_all = "camelCase")]
87pub struct DescribeServerContact {
88 #[serde(skip_serializing_if = "Option::is_none")]
89 email: Option<String>,
90}
91
92#[derive(Deserialize, Serialize, Debug, Clone)]
93#[serde(rename_all = "camelCase")]
94pub struct DescribeServerLinks {
95 #[serde(skip_serializing_if = "Option::is_none")]
96 privacy_policy: Option<String>,
97 #[serde(skip_serializing_if = "Option::is_none")]
98 terms_of_service: Option<String>,
99}
100
101#[derive(Deserialize, Serialize, Debug, Clone)]
102#[serde(rename_all = "camelCase")]
103pub struct DescribeServerResponse {
104 #[serde(skip_serializing_if = "Option::is_none")]
105 invite_code_required: Option<bool>,
106 #[serde(skip_serializing_if = "Option::is_none")]
107 phone_verification_required: Option<bool>,
108 #[serde(skip_serializing_if = "Option::is_none")]
109 available_user_domains: Option<Vec<String>>,
110 #[serde(skip_serializing_if = "Option::is_none")]
111 links: Option<DescribeServerLinks>,
112 #[serde(skip_serializing_if = "Option::is_none")]
113 contact: Option<DescribeServerContact>,
114 #[serde(skip_serializing_if = "Option::is_none")]
115 did: Option<String>,
116}
117
118pub async fn create_session(
119 State(state): State<AppState>,
120 headers: HeaderMap,
121 Json(payload): extract::Json<CreateSessionRequest>,
122) -> Result<Response<Body>, StatusCode> {
123 let identifier = payload.identifier.clone();
124 let password = payload.password.clone();
125 let auth_factor_token = payload.auth_factor_token.clone();
126
127 // Run the shared pre-auth logic to validate and check 2FA requirement
128 match preauth_check(&state, &identifier, &password, auth_factor_token, false).await {
129 Ok(result) => match result {
130 AuthResult::WrongIdentityOrPassword => json_error_response(
131 StatusCode::UNAUTHORIZED,
132 "AuthenticationRequired",
133 "Invalid identifier or password",
134 ),
135 AuthResult::TwoFactorRequired(_) => {
136 // Email sending step can be handled here if needed in the future.
137 json_error_response(
138 StatusCode::UNAUTHORIZED,
139 "AuthFactorTokenRequired",
140 "A sign in code has been sent to your email address",
141 )
142 }
143 AuthResult::ProxyThrough => {
144 //No 2FA or already passed
145 let uri = format!(
146 "{}{}",
147 state.app_config.pds_base_url, "/xrpc/com.atproto.server.createSession"
148 );
149
150 let mut req = axum::http::Request::post(uri);
151 if let Some(req_headers) = req.headers_mut() {
152 req_headers.extend(headers.clone());
153 }
154
155 let payload_bytes =
156 serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
157 let req = req
158 .body(Body::from(payload_bytes))
159 .map_err(|_| StatusCode::BAD_REQUEST)?;
160
161 let proxied = state
162 .reverse_proxy_client
163 .request(req)
164 .await
165 .map_err(|_| StatusCode::BAD_REQUEST)?
166 .into_response();
167
168 Ok(proxied)
169 }
170 AuthResult::TokenCheckFailed(err) => match err {
171 TokenCheckError::InvalidToken => {
172 json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "Token is invalid")
173 }
174 TokenCheckError::ExpiredToken => {
175 json_error_response(StatusCode::BAD_REQUEST, "ExpiredToken", "Token is expired")
176 }
177 },
178 },
179 Err(err) => {
180 log::error!(
181 "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}"
182 );
183 json_error_response(
184 StatusCode::INTERNAL_SERVER_ERROR,
185 "InternalServerError",
186 "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.",
187 )
188 }
189 }
190}
191
192#[debug_handler]
193pub async fn update_email(
194 State(state): State<AppState>,
195 Extension(did): Extension<Did>,
196 headers: HeaderMap,
197 Json(payload): extract::Json<UpdateEmailResponse>,
198) -> Result<Response<Body>, StatusCode> {
199 //If email auth is not set at all it is a update email address
200 let email_auth_not_set = payload.email_auth_factor.is_none();
201 //If email auth is set it is to either turn on or off 2fa
202 let email_auth_update = payload.email_auth_factor.unwrap_or(false);
203
204 //This means the middleware successfully extracted a did from the request, if not it just needs to be forward to the PDS
205 //This is also empty if it is an oauth request, which is not supported by gatekeeper turning on 2fa since the dpop stuff needs to be implemented
206 let did_is_not_empty = did.0.is_some();
207
208 if did_is_not_empty {
209 // Email update asked for
210 if email_auth_update {
211 let email = payload.email.clone();
212 let email_confirmed = match sqlx::query_as::<_, (String,)>(
213 "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?",
214 )
215 .bind(&email)
216 .fetch_optional(&state.account_pool)
217 .await
218 {
219 Ok(row) => row,
220 Err(err) => {
221 log::error!("Error checking if email is confirmed: {err}");
222 return Err(StatusCode::BAD_REQUEST);
223 }
224 };
225
226 //Since the email is already confirmed we can enable 2fa
227 return match email_confirmed {
228 None => Err(StatusCode::BAD_REQUEST),
229 Some(did_row) => {
230 let _ = sqlx::query(
231 "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1",
232 )
233 .bind(&did_row.0)
234 .execute(&state.pds_gatekeeper_pool)
235 .await
236 .map_err(|err| {
237 log::error!("Error enabling 2FA: {err}");
238 StatusCode::BAD_REQUEST
239 })?;
240
241 Ok(StatusCode::OK.into_response())
242 }
243 };
244 }
245
246 // User wants auth turned off
247 if !email_auth_update && !email_auth_not_set {
248 //User wants auth turned off and has a token
249 if let Some(token) = &payload.token {
250 let token_found = match sqlx::query_as::<_, (String,)>(
251 "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'",
252 )
253 .bind(token)
254 .bind(&did.0)
255 .fetch_optional(&state.account_pool)
256 .await{
257 Ok(token) => token,
258 Err(err) => {
259 log::error!("Error checking if token is valid: {err}");
260 return Err(StatusCode::BAD_REQUEST);
261 }
262 };
263
264 return if token_found.is_some() {
265 //TODO I think there may be a bug here and need to do some retry logic
266 // First try was erroring, seconds was allowing
267 match sqlx::query(
268 "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0",
269 )
270 .bind(&did.0)
271 .execute(&state.pds_gatekeeper_pool)
272 .await {
273 Ok(_) => {}
274 Err(err) => {
275 log::error!("Error updating email auth: {err}");
276 return Err(StatusCode::BAD_REQUEST);
277 }
278 }
279
280 Ok(StatusCode::OK.into_response())
281 } else {
282 Err(StatusCode::BAD_REQUEST)
283 };
284 }
285 }
286 }
287 // Updating the actual email address by sending it on to the PDS
288 let uri = format!(
289 "{}{}",
290 state.app_config.pds_base_url, "/xrpc/com.atproto.server.updateEmail"
291 );
292 let mut req = axum::http::Request::post(uri);
293 if let Some(req_headers) = req.headers_mut() {
294 req_headers.extend(headers.clone());
295 }
296
297 let payload_bytes = serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
298 let req = req
299 .body(Body::from(payload_bytes))
300 .map_err(|_| StatusCode::BAD_REQUEST)?;
301
302 let proxied = state
303 .reverse_proxy_client
304 .request(req)
305 .await
306 .map_err(|_| StatusCode::BAD_REQUEST)?
307 .into_response();
308
309 Ok(proxied)
310}
311
312pub async fn get_session(
313 State(state): State<AppState>,
314 req: Request,
315) -> Result<Response<Body>, StatusCode> {
316 match proxy_get_json::<GetSessionResponse>(&state, req, "/xrpc/com.atproto.server.getSession")
317 .await?
318 {
319 ProxiedResult::Parsed {
320 value: mut session, ..
321 } => {
322 let did = session.did.clone();
323 let required_opt = sqlx::query_as::<_, (u8,)>(
324 "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1",
325 )
326 .bind(&did)
327 .fetch_optional(&state.pds_gatekeeper_pool)
328 .await
329 .map_err(|_| StatusCode::BAD_REQUEST)?;
330
331 let email_auth_factor = match required_opt {
332 Some(row) => row.0 != 0,
333 None => false,
334 };
335
336 session.email_auth_factor = Some(email_auth_factor);
337 Ok(Json(session).into_response())
338 }
339 ProxiedResult::Passthrough(resp) => Ok(resp),
340 }
341}
342
343pub async fn describe_server(
344 State(state): State<AppState>,
345 req: Request,
346) -> Result<Response<Body>, StatusCode> {
347 match proxy_get_json::<DescribeServerResponse>(
348 &state,
349 req,
350 "/xrpc/com.atproto.server.describeServer",
351 )
352 .await?
353 {
354 ProxiedResult::Parsed {
355 value: mut server_info,
356 ..
357 } => {
358 //This signifies the server is configured for captcha verification
359 server_info.phone_verification_required = Some(state.app_config.use_captcha);
360 Ok(Json(server_info).into_response())
361 }
362 ProxiedResult::Passthrough(resp) => Ok(resp),
363 }
364}
365
366/// Verify a gate code matches the handle and is not expired
367async fn verify_gate_code(
368 state: &AppState,
369 code: &str,
370 handle: &str,
371) -> Result<bool, anyhow::Error> {
372 // First, decrypt and verify the JWE token
373 let payload = match verify_gate_token(code, &state.app_config.gate_jwe_key) {
374 Ok(p) => p,
375 Err(e) => {
376 log::warn!("Failed to decrypt gate token: {}", e);
377 return Ok(false);
378 }
379 };
380
381 // Verify the handle matches
382 if payload.handle != handle {
383 log::warn!(
384 "Gate code handle mismatch: expected {}, got {}",
385 handle,
386 payload.handle
387 );
388 return Ok(false);
389 }
390
391 let created_at = chrono::DateTime::parse_from_rfc3339(&payload.created_at)
392 .map_err(|e| anyhow::anyhow!("Failed to parse created_at from token: {}", e))?
393 .with_timezone(&Utc);
394
395 let now = Utc::now();
396 let age = now - created_at;
397
398 // Check if the token is expired (5 minutes)
399 if age > Duration::minutes(5) {
400 log::warn!("Gate code expired for handle {}", handle);
401 return Ok(false);
402 }
403
404 // Verify the token exists in the database (to prevent reuse)
405 let row: Option<(String,)> =
406 sqlx::query_as("SELECT code FROM gate_codes WHERE code = ? and handle = ? LIMIT 1")
407 .bind(code)
408 .bind(handle)
409 .fetch_optional(&state.pds_gatekeeper_pool)
410 .await?;
411
412 if row.is_none() {
413 log::warn!("Gate code not found in database or already used");
414 return Ok(false);
415 }
416
417 // Token is valid, delete it so it can't be reused
418 //TODO probably also delete expired codes? Will need to do that at some point probably altho the where is on code and handle
419
420 sqlx::query("DELETE FROM gate_codes WHERE code = ?")
421 .bind(code)
422 .execute(&state.pds_gatekeeper_pool)
423 .await?;
424
425 Ok(true)
426}
427
428pub async fn create_account(
429 State(state): State<AppState>,
430 req: Request,
431) -> Result<Response<Body>, StatusCode> {
432 let headers = req.headers().clone();
433 let body_bytes = to_bytes(req.into_body(), usize::MAX)
434 .await
435 .map_err(|_| StatusCode::BAD_REQUEST)?;
436
437 // Parse the body to check for verification code
438 let account_request: CreateAccountRequest =
439 serde_json::from_slice(&body_bytes).map_err(|e| {
440 log::error!("Failed to parse create account request: {}", e);
441 StatusCode::BAD_REQUEST
442 })?;
443
444 // Check for service auth (migrations) if configured
445 if state.app_config.allow_only_migrations {
446 // Expect Authorization: Bearer <jwt>
447 let auth_header = headers
448 .get(header::AUTHORIZATION)
449 .and_then(|v| v.to_str().ok())
450 .map(str::to_string);
451
452 let Some(value) = auth_header else {
453 log::error!("No Authorization header found in the request");
454 return json_error_response(
455 StatusCode::UNAUTHORIZED,
456 "InvalidAuth",
457 "This PDS is configured to only allow accounts created by migrations via this endpoint.",
458 );
459 };
460
461 // Ensure Bearer prefix
462 let token = value.strip_prefix("Bearer ").unwrap_or("").trim();
463 if token.is_empty() {
464 log::error!("No Service Auth token found in the Authorization header");
465 return json_error_response(
466 StatusCode::UNAUTHORIZED,
467 "InvalidAuth",
468 "This PDS is configured to only allow accounts created by migrations via this endpoint.",
469 );
470 }
471
472 // Ensure a non-empty DID was provided when migrations are enabled
473 let requested_did_str = match account_request.did.as_deref() {
474 Some(s) if !s.trim().is_empty() => s,
475 _ => {
476 return json_error_response(
477 StatusCode::BAD_REQUEST,
478 "InvalidRequest",
479 "The 'did' field is required when migrations are enforced.",
480 );
481 }
482 };
483
484 // Parse the DID into the expected type for verification
485 let requested_did: JacquardDid<'static> = match requested_did_str.parse() {
486 Ok(d) => d,
487 Err(e) => {
488 log::error!(
489 "Invalid DID format provided in createAccount: {} | error: {}",
490 requested_did_str,
491 e
492 );
493 return json_error_response(
494 StatusCode::BAD_REQUEST,
495 "InvalidRequest",
496 "The 'did' field is not a valid DID.",
497 );
498 }
499 };
500
501 let nsid = "com.atproto.server.createAccount".parse().unwrap();
502 match verify_service_auth(
503 token,
504 &nsid,
505 state.resolver.clone(),
506 &state.app_config.pds_service_did,
507 &requested_did,
508 )
509 .await
510 {
511 //Just do nothing if it passes so it continues.
512 Ok(_) => {}
513 Err(err) => match err {
514 VerifyServiceAuthError::AuthFailed => {
515 return json_error_response(
516 StatusCode::UNAUTHORIZED,
517 "InvalidAuth",
518 "This PDS is configured to only allow accounts created by migrations via this endpoint.",
519 );
520 }
521 VerifyServiceAuthError::Error(err) => {
522 log::error!("Error verifying service auth token: {err}");
523 return json_error_response(
524 StatusCode::BAD_REQUEST,
525 "InvalidRequest",
526 "There has been an error, please contact your PDS administrator for help and for them to review the server logs.",
527 );
528 }
529 },
530 }
531 }
532
533 // Check for captcha verification if configured
534 if state.app_config.use_captcha {
535 if let Some(ref verification_code) = account_request.verification_code {
536 match verify_gate_code(&state, verification_code, &account_request.handle).await {
537 //TODO has a few errors to support
538
539 //expired token
540 // {
541 // "error": "ExpiredToken",
542 // "message": "Token has expired"
543 // }
544
545 //TODO ALSO add rate limits on the /gate endpoints so they can't be abused
546 Ok(true) => {
547 log::info!("Gate code verified for handle: {}", account_request.handle);
548 }
549 Ok(false) => {
550 log::warn!(
551 "Invalid or expired gate code for handle: {}",
552 account_request.handle
553 );
554 return json_error_response(
555 StatusCode::BAD_REQUEST,
556 "InvalidToken",
557 "Token could not be verified",
558 );
559 }
560 Err(e) => {
561 log::error!("Error verifying gate code: {}", e);
562 return json_error_response(
563 StatusCode::INTERNAL_SERVER_ERROR,
564 "InvalidToken",
565 "Token could not be verified",
566 );
567 }
568 }
569 } else {
570 // No verification code provided but captcha is required
571 log::warn!(
572 "No verification code provided for account creation: {}",
573 account_request.handle
574 );
575 return json_error_response(
576 StatusCode::BAD_REQUEST,
577 "InvalidRequest",
578 "Verification is now required on this server.",
579 );
580 }
581 }
582
583 // Rebuild the request with the same body and headers
584 let uri = format!(
585 "{}{}",
586 state.app_config.pds_base_url, "/xrpc/com.atproto.server.createAccount"
587 );
588
589 let mut new_req = axum::http::Request::post(&uri);
590 if let Some(req_headers) = new_req.headers_mut() {
591 *req_headers = headers;
592 }
593
594 let new_req = new_req
595 .body(Body::from(body_bytes))
596 .map_err(|_| StatusCode::BAD_REQUEST)?;
597
598 let proxied = state
599 .reverse_proxy_client
600 .request(new_req)
601 .await
602 .map_err(|_| StatusCode::BAD_REQUEST)?
603 .into_response();
604
605 Ok(proxied)
606}