forked from
baileytownsend.dev/pds-gatekeeper
Microservice to bring 2FA to self-hosted PDSes
1use crate::AppState;
2use crate::helpers::TokenCheckError::InvalidToken;
3use anyhow::anyhow;
4use axum::{
5 body::{Body, to_bytes},
6 extract::Request,
7 http::header::CONTENT_TYPE,
8 http::{HeaderMap, StatusCode, Uri},
9 response::{IntoResponse, Response},
10};
11use axum_template::TemplateEngine;
12use chrono::Utc;
13use jacquard_common::{
14 service_auth, service_auth::PublicKey, types::did::Did, types::did_doc::VerificationMethod,
15 types::nsid::Nsid,
16};
17use jacquard_identity::{PublicResolver, resolver::IdentityResolver};
18use josekit::jwe::alg::direct::DirectJweAlgorithm;
19use lettre::{
20 AsyncTransport, Message,
21 message::{MultiPart, SinglePart, header},
22};
23use rand::Rng;
24use serde::de::DeserializeOwned;
25use serde_json::{Map, Value};
26use sha2::{Digest, Sha256};
27use sqlx::SqlitePool;
28use std::sync::Arc;
29use tracing::{error, log};
30
/// Alphabet the email 2FA code is drawn from: the 32 symbols `A`-`Z` followed by `2`-`7`.
const UPPERCASE_BASE32_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
33
34/// The result of a proxied call that attempts to parse JSON.
35pub enum ProxiedResult<T> {
36 /// Successfully parsed JSON body along with original response headers.
37 Parsed { value: T, _headers: HeaderMap },
38 /// Could not or should not parse: return the original (or rebuilt) response as-is.
39 Passthrough(Response<Body>),
40}
41
42/// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse
43/// the successful response body as JSON into `T`.
44///
45pub async fn proxy_get_json<T>(
46 state: &AppState,
47 mut req: Request,
48 path: &str,
49) -> Result<ProxiedResult<T>, StatusCode>
50where
51 T: DeserializeOwned,
52{
53 let uri = format!("{}{}", state.app_config.pds_base_url, path);
54 *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?;
55
56 let result = state
57 .reverse_proxy_client
58 .request(req)
59 .await
60 .map_err(|_| StatusCode::BAD_REQUEST)?
61 .into_response();
62
63 if result.status() != StatusCode::OK {
64 return Ok(ProxiedResult::Passthrough(result));
65 }
66
67 let response_headers = result.headers().clone();
68 let body = result.into_body();
69 let body_bytes = to_bytes(body, usize::MAX)
70 .await
71 .map_err(|_| StatusCode::BAD_REQUEST)?;
72
73 match serde_json::from_slice::<T>(&body_bytes) {
74 Ok(value) => Ok(ProxiedResult::Parsed {
75 value,
76 _headers: response_headers,
77 }),
78 Err(err) => {
79 error!(%err, "failed to parse proxied JSON response; returning original body");
80 let mut builder = Response::builder().status(StatusCode::OK);
81 if let Some(headers) = builder.headers_mut() {
82 *headers = response_headers;
83 }
84 let resp = builder
85 .body(Body::from(body_bytes))
86 .map_err(|_| StatusCode::BAD_REQUEST)?;
87 Ok(ProxiedResult::Passthrough(resp))
88 }
89 }
90}
91
92/// Build a JSON error response with the required Content-Type header
93/// Content-Type: application/json;charset=utf-8
94/// Body shape: { "error": string, "message": string }
95pub fn json_error_response(
96 status: StatusCode,
97 error: impl Into<String>,
98 message: impl Into<String>,
99) -> Result<Response<Body>, StatusCode> {
100 let body_str = match serde_json::to_string(&serde_json::json!({
101 "error": error.into(),
102 "message": message.into(),
103 })) {
104 Ok(s) => s,
105 Err(_) => return Err(StatusCode::BAD_REQUEST),
106 };
107
108 Response::builder()
109 .status(status)
110 .header(CONTENT_TYPE, "application/json;charset=utf-8")
111 .body(Body::from(body_str))
112 .map_err(|_| StatusCode::BAD_REQUEST)
113}
114
115/// Build a JSON error response with the required Content-Type header
116/// Content-Type: application/json (oauth endpoint does not like utf ending)
117/// Body shape: { "error": string, "error_description": string }
118pub fn oauth_json_error_response(
119 status: StatusCode,
120 error: impl Into<String>,
121 message: impl Into<String>,
122) -> Result<Response<Body>, StatusCode> {
123 let body_str = match serde_json::to_string(&serde_json::json!({
124 "error": error.into(),
125 "error_description": message.into(),
126 })) {
127 Ok(s) => s,
128 Err(_) => return Err(StatusCode::BAD_REQUEST),
129 };
130
131 Response::builder()
132 .status(status)
133 .header(CONTENT_TYPE, "application/json")
134 .body(Body::from(body_str))
135 .map_err(|_| StatusCode::BAD_REQUEST)
136}
137
138/// Creates a random token of 10 characters for email 2FA
139pub fn get_random_token() -> String {
140 let mut rng = rand::rng();
141
142 let mut full_code = String::with_capacity(10);
143 for _ in 0..10 {
144 let idx = rng.random_range(0..UPPERCASE_BASE32_CHARS.len());
145 full_code.push(UPPERCASE_BASE32_CHARS[idx] as char);
146 }
147
148 let slice_one = &full_code[0..5];
149 let slice_two = &full_code[5..10];
150 format!("{slice_one}-{slice_two}")
151}
152
/// Why a supplied 2FA token was rejected.
pub enum TokenCheckError {
    /// No usable token was found for the supplied value.
    InvalidToken,
    /// A matching token exists but is too old to be accepted.
    ExpiredToken,
}
157
158pub enum AuthResult {
159 WrongIdentityOrPassword,
160 /// The string here is the email address to create a hint for oauth
161 TwoFactorRequired(String),
162 /// User does not have 2FA enabled, or using an app password, or passes it
163 ProxyThrough,
164 TokenCheckFailed(TokenCheckError),
165}
166
/// The kind of login identifier a user supplied.
pub enum IdentifierType {
    Email,
    Did,
    Handle,
}

impl IdentifierType {
    /// Classifies `identifier` by substring checks, in priority order: anything containing
    /// `@` is an email; otherwise anything containing `did:` is a DID; everything else is
    /// treated as a handle.
    fn what_is_it(identifier: String) -> Self {
        if identifier.contains("@") {
            return Self::Email;
        }
        if identifier.contains("did:") {
            return Self::Did;
        }
        Self::Handle
    }
}
184
185/// Creates a hex string from the password and salt to find app passwords
186fn scrypt_hex(password: &str, salt: &str) -> anyhow::Result<String> {
187 let params = scrypt::Params::new(14, 8, 1, 64)?;
188 let mut derived = [0u8; 64];
189 scrypt::scrypt(password.as_bytes(), salt.as_bytes(), ¶ms, &mut derived)?;
190 Ok(hex::encode(derived))
191}
192
193/// Hashes the app password. did is used as the salt.
194pub fn hash_app_password(did: &str, password: &str) -> anyhow::Result<String> {
195 let mut hasher = Sha256::new();
196 hasher.update(did.as_bytes());
197 let sha = hasher.finalize();
198 let salt = hex::encode(&sha[..16]);
199 let hash_hex = scrypt_hex(password, &salt)?;
200 Ok(format!("{salt}:{hash_hex}"))
201}
202
203async fn verify_password(password: &str, password_scrypt: &str) -> anyhow::Result<bool> {
204 // Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes)
205 let mut parts = password_scrypt.splitn(2, ':');
206 let salt = match parts.next() {
207 Some(s) if !s.is_empty() => s,
208 _ => return Ok(false),
209 };
210 let stored_hash_hex = match parts.next() {
211 Some(h) if !h.is_empty() => h,
212 _ => return Ok(false),
213 };
214
215 // Derive using the shared helper and compare
216 let derived_hex = match scrypt_hex(password, salt) {
217 Ok(h) => h,
218 Err(_) => return Ok(false),
219 };
220
221 Ok(derived_hex.as_str() == stored_hash_hex)
222}
223
224/// Handles the auth checks along with sending a 2fa email
225pub async fn preauth_check(
226 state: &AppState,
227 identifier: &str,
228 password: &str,
229 two_factor_code: Option<String>,
230 oauth: bool,
231) -> anyhow::Result<AuthResult> {
232 // Determine identifier type
233 let id_type = IdentifierType::what_is_it(identifier.to_string());
234
235 // Query account DB for did and passwordScrypt based on identifier type
236 let account_row: Option<(String, String, String, String)> = match id_type {
237 IdentifierType::Email => {
238 sqlx::query_as::<_, (String, String, String, String)>(
239 "SELECT account.did, account.passwordScrypt, account.email, actor.handle
240 FROM actor
241 LEFT JOIN account ON actor.did = account.did
242 where account.email = ? LIMIT 1",
243 )
244 .bind(identifier)
245 .fetch_optional(&state.account_pool)
246 .await?
247 }
248 IdentifierType::Handle => {
249 sqlx::query_as::<_, (String, String, String, String)>(
250 "SELECT account.did, account.passwordScrypt, account.email, actor.handle
251 FROM actor
252 LEFT JOIN account ON actor.did = account.did
253 where actor.handle = ? LIMIT 1",
254 )
255 .bind(identifier)
256 .fetch_optional(&state.account_pool)
257 .await?
258 }
259 IdentifierType::Did => {
260 sqlx::query_as::<_, (String, String, String, String)>(
261 "SELECT account.did, account.passwordScrypt, account.email, actor.handle
262 FROM actor
263 LEFT JOIN account ON actor.did = account.did
264 where account.did = ? LIMIT 1",
265 )
266 .bind(identifier)
267 .fetch_optional(&state.account_pool)
268 .await?
269 }
270 };
271
272 if let Some((did, password_scrypt, email, handle)) = account_row {
273 // Verify password before proceeding to 2FA email step
274 let verified = verify_password(password, &password_scrypt).await?;
275 if !verified {
276 if oauth {
277 //OAuth does not allow app password logins so just go ahead and send it along it's way
278 return Ok(AuthResult::WrongIdentityOrPassword);
279 }
280 //Theres a chance it could be an app password so check that as well
281 return match verify_app_password(&state.account_pool, &did, password).await {
282 Ok(valid) => {
283 if valid {
284 //Was a valid app password up to the PDS now
285 Ok(AuthResult::ProxyThrough)
286 } else {
287 Ok(AuthResult::WrongIdentityOrPassword)
288 }
289 }
290 Err(err) => {
291 log::error!("Error checking the app password: {err}");
292 Err(err)
293 }
294 };
295 }
296
297 // Check two-factor requirement for this DID in the gatekeeper DB
298 let required_opt = sqlx::query_as::<_, (u8,)>(
299 "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1",
300 )
301 .bind(did.clone())
302 .fetch_optional(&state.pds_gatekeeper_pool)
303 .await?;
304
305 let two_factor_required = match required_opt {
306 Some(row) => row.0 != 0,
307 None => false,
308 };
309
310 if two_factor_required {
311 //Two factor is required and a taken was provided
312 if let Some(two_factor_code) = two_factor_code {
313 //if the two_factor_code is set need to see if we have a valid token
314 if !two_factor_code.is_empty() {
315 return match assert_valid_token(
316 &state.account_pool,
317 did.clone(),
318 two_factor_code,
319 )
320 .await
321 {
322 Ok(_) => {
323 let result_of_cleanup =
324 delete_all_email_tokens(&state.account_pool, did.clone()).await;
325 if result_of_cleanup.is_err() {
326 log::error!(
327 "There was an error deleting the email tokens after login: {:?}",
328 result_of_cleanup.err()
329 )
330 }
331 Ok(AuthResult::ProxyThrough)
332 }
333 Err(err) => Ok(AuthResult::TokenCheckFailed(err)),
334 };
335 }
336 }
337
338 return match create_two_factor_token(&state.account_pool, did).await {
339 Ok(code) => {
340 let mut email_data = Map::new();
341 email_data.insert("token".to_string(), Value::from(code.clone()));
342 email_data.insert("handle".to_string(), Value::from(handle.clone()));
343 let email_body = state
344 .template_engine
345 .render("two_factor_code.hbs", email_data)?;
346
347 let email_message = Message::builder()
348 //TODO prob get the proper type in the state
349 .from(state.app_config.mailer_from.parse()?)
350 .to(email.parse()?)
351 .subject(&state.app_config.email_subject)
352 .multipart(
353 MultiPart::alternative() // This is composed of two parts.
354 .singlepart(
355 SinglePart::builder()
356 .header(header::ContentType::TEXT_PLAIN)
357 .body(format!("We received a sign-in request for the account @{handle}. Use the code: {code} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.")), // Every message should have a plain text fallback.
358 )
359 .singlepart(
360 SinglePart::builder()
361 .header(header::ContentType::TEXT_HTML)
362 .body(email_body),
363 ),
364 )?;
365 match state.mailer.send(email_message).await {
366 Ok(_) => Ok(AuthResult::TwoFactorRequired(mask_email(email))),
367 Err(err) => {
368 log::error!("Error sending the 2FA email: {err}");
369 Err(anyhow!(err))
370 }
371 }
372 }
373 Err(err) => {
374 log::error!("error on creating a 2fa token: {err}");
375 Err(anyhow!(err))
376 }
377 };
378 }
379 }
380
381 // No local 2FA requirement (or account not found)
382 Ok(AuthResult::ProxyThrough)
383}
384
385pub async fn create_two_factor_token(
386 account_db: &SqlitePool,
387 did: String,
388) -> anyhow::Result<String> {
389 let purpose = "2fa_code";
390
391 let token = get_random_token();
392 let right_now = Utc::now();
393
394 let res = sqlx::query(
395 "INSERT INTO email_token (purpose, did, token, requestedAt)
396 VALUES (?, ?, ?, ?)
397 ON CONFLICT(purpose, did) DO UPDATE SET
398 token=excluded.token,
399 requestedAt=excluded.requestedAt
400 WHERE did=excluded.did",
401 )
402 .bind(purpose)
403 .bind(&did)
404 .bind(&token)
405 .bind(right_now)
406 .execute(account_db)
407 .await;
408
409 match res {
410 Ok(_) => Ok(token),
411 Err(err) => {
412 log::error!("Error creating a two factor token: {err}");
413 Err(anyhow::anyhow!(err))
414 }
415 }
416}
417
418pub async fn delete_all_email_tokens(account_db: &SqlitePool, did: String) -> anyhow::Result<()> {
419 sqlx::query("DELETE FROM email_token WHERE did = ?")
420 .bind(did)
421 .execute(account_db)
422 .await?;
423 Ok(())
424}
425
426pub async fn assert_valid_token(
427 account_db: &SqlitePool,
428 did: String,
429 token: String,
430) -> Result<(), TokenCheckError> {
431 let token_upper = token.to_ascii_uppercase();
432 let purpose = "2fa_code";
433
434 let row: Option<(String,)> = sqlx::query_as(
435 "SELECT requestedAt FROM email_token WHERE purpose = ? AND did = ? AND token = ? LIMIT 1",
436 )
437 .bind(purpose)
438 .bind(did)
439 .bind(token_upper)
440 .fetch_optional(account_db)
441 .await
442 .map_err(|err| {
443 log::error!("Error getting the 2fa token: {err}");
444 InvalidToken
445 })?;
446
447 match row {
448 None => Err(InvalidToken),
449 Some(row) => {
450 // Token lives for 15 minutes
451 let expiration_ms = 15 * 60_000;
452
453 let requested_at_utc = match chrono::DateTime::parse_from_rfc3339(&row.0) {
454 Ok(dt) => dt.with_timezone(&Utc),
455 Err(_) => {
456 return Err(TokenCheckError::InvalidToken);
457 }
458 };
459
460 let now = Utc::now();
461 let age_ms = (now - requested_at_utc).num_milliseconds();
462 let expired = age_ms > expiration_ms;
463 if expired {
464 return Err(TokenCheckError::ExpiredToken);
465 }
466
467 Ok(())
468 }
469 }
470}
471
472/// We just need to confirm if it's there or not. Will let the PDS do the actual figuring of permissions
473pub async fn verify_app_password(
474 account_db: &SqlitePool,
475 did: &str,
476 password: &str,
477) -> anyhow::Result<bool> {
478 let password_scrypt = hash_app_password(did, password)?;
479
480 let row: Option<(i64,)> = sqlx::query_as(
481 "SELECT Count(*) FROM app_password WHERE did = ? AND passwordScrypt = ? LIMIT 1",
482 )
483 .bind(did)
484 .bind(password_scrypt)
485 .fetch_optional(account_db)
486 .await?;
487
488 Ok(match row {
489 None => false,
490 Some((count,)) => count > 0,
491 })
492}
493
/// Masks an email address into a hint such as `a***e@e***e.com`.
///
/// The local part and the first label of the domain are each reduced to their first and
/// last character around `***` (a single-character label becomes `x***`); the remainder of
/// the domain after the first `.` is left intact. Input with no `@`, or with nothing after
/// the first `@`, is returned unchanged.
pub fn mask_email(email: String) -> String {
    // Keep the first and last character of a label; the middle becomes ***.
    fn mask_label(label: &str) -> String {
        let mut chars = label.chars();
        match (chars.next(), chars.next_back()) {
            (None, _) => String::new(),
            (Some(only), None) => format!("{only}***"),
            (Some(first), Some(last)) => format!("{first}***{last}"),
        }
    }

    // Split on the first '@' only.
    let Some((local, domain)) = email.split_once('@') else {
        return email;
    };
    if domain.is_empty() {
        return email;
    }

    // Only the first domain label is masked.
    let masked_domain = match domain.split_once('.') {
        Some((first_label, rest)) => format!("{}.{rest}", mask_label(first_label)),
        None => mask_label(domain),
    };

    format!("{}@{masked_domain}", mask_label(local))
}
534
535pub enum VerifyServiceAuthError {
536 AuthFailed,
537 Error(anyhow::Error),
538}
539
540/// Verifies the service auth token that is appended to an XRPC proxy request
541pub async fn verify_service_auth(
542 jwt: &str,
543 lxm: &Nsid<'static>,
544 public_resolver: Arc<PublicResolver>,
545 service_did: &Did<'static>,
546 //The did of the user wanting to create an account
547 requested_did: &Did<'static>,
548) -> Result<(), VerifyServiceAuthError> {
549 let parsed =
550 service_auth::parse_jwt(jwt).map_err(|e| VerifyServiceAuthError::Error(e.into()))?;
551
552 let claims = parsed.claims();
553
554 let did_doc = public_resolver
555 .resolve_did_doc(&requested_did)
556 .await
557 .map_err(|err| {
558 log::error!("Error resolving the service auth for: {}", claims.iss);
559 return VerifyServiceAuthError::Error(err.into());
560 })?;
561
562 // Parse the DID document response to get verification methods
563 let doc = did_doc.parse().map_err(|err| {
564 log::error!("Error parsing the service auth did doc: {}", claims.iss);
565 VerifyServiceAuthError::Error(anyhow::anyhow!(err))
566 })?;
567
568 let verification_methods = doc.verification_method.as_deref().ok_or_else(|| {
569 VerifyServiceAuthError::Error(anyhow::anyhow!(
570 "No verification methods in did doc: {}",
571 &claims.iss
572 ))
573 })?;
574
575 let signing_key = extract_signing_key(verification_methods).ok_or_else(|| {
576 VerifyServiceAuthError::Error(anyhow::anyhow!(
577 "No signing key found in did doc: {}",
578 &claims.iss
579 ))
580 })?;
581
582 service_auth::verify_signature(&parsed, &signing_key).map_err(|err| {
583 log::error!("Error verifying service auth signature: {}", err);
584 VerifyServiceAuthError::AuthFailed
585 })?;
586
587 // Now validate claims (audience, expiration, etc.)
588 claims.validate(service_did).map_err(|e| {
589 log::error!("Error validating service auth claims: {}", e);
590 VerifyServiceAuthError::AuthFailed
591 })?;
592
593 if claims.aud != *service_did {
594 log::error!("Invalid audience (did:web): {}", claims.aud);
595 return Err(VerifyServiceAuthError::AuthFailed);
596 }
597
598 let lxm_from_claims = claims.lxm.as_ref().ok_or_else(|| {
599 VerifyServiceAuthError::Error(anyhow::anyhow!("No lxm claim in service auth JWT"))
600 })?;
601
602 if lxm_from_claims != lxm {
603 return Err(VerifyServiceAuthError::Error(anyhow::anyhow!(
604 "Invalid XRPC endpoint requested"
605 )));
606 }
607 Ok(())
608}
609
610/// Ripped from Jacquard
611///
612/// Extract the signing key from a DID document's verification methods.
613///
614/// This looks for a key with type "atproto" or the first available key
615/// if no atproto-specific key is found.
616fn extract_signing_key(methods: &[VerificationMethod]) -> Option<PublicKey> {
617 // First try to find an atproto-specific key
618 let atproto_method = methods
619 .iter()
620 .find(|m| m.r#type.as_ref() == "Multikey" || m.r#type.as_ref() == "atproto");
621
622 let method = atproto_method.or_else(|| methods.first())?;
623
624 // Parse the multikey
625 let public_key_multibase = method.public_key_multibase.as_ref()?;
626
627 // Decode multibase
628 let (_, key_bytes) = multibase::decode(public_key_multibase.as_ref()).ok()?;
629
630 // First two bytes are the multicodec prefix
631 if key_bytes.len() < 2 {
632 return None;
633 }
634
635 let codec = &key_bytes[..2];
636 let key_material = &key_bytes[2..];
637
638 match codec {
639 // p256-pub (0x1200)
640 [0x80, 0x24] => PublicKey::from_p256_bytes(key_material).ok(),
641 // secp256k1-pub (0xe7)
642 [0xe7, 0x01] => PublicKey::from_k256_bytes(key_material).ok(),
643 _ => None,
644 }
645}
646
647/// Payload for gate JWE tokens
648#[derive(serde::Serialize, serde::Deserialize, Debug)]
649pub struct GateTokenPayload {
650 pub handle: String,
651 pub created_at: String,
652}
653
654/// Generate a secure JWE token for gate verification
655pub fn generate_gate_token(handle: &str, encryption_key: &[u8]) -> Result<String, anyhow::Error> {
656 use josekit::jwe::{JweHeader, alg::direct::DirectJweAlgorithm};
657
658 let payload = GateTokenPayload {
659 handle: handle.to_string(),
660 created_at: Utc::now().to_rfc3339(),
661 };
662
663 let payload_json = serde_json::to_string(&payload)?;
664
665 let mut header = JweHeader::new();
666 header.set_token_type("JWT");
667 header.set_content_encryption("A128CBC-HS256");
668
669 let encrypter = DirectJweAlgorithm::Dir.encrypter_from_bytes(encryption_key)?;
670
671 // Encrypt
672 let jwe = josekit::jwe::serialize_compact(payload_json.as_bytes(), &header, &encrypter)?;
673
674 Ok(jwe)
675}
676
677/// Verify and decrypt a gate JWE token, returning the payload if valid
678pub fn verify_gate_token(
679 token: &str,
680 encryption_key: &[u8],
681) -> Result<GateTokenPayload, anyhow::Error> {
682 let decrypter = DirectJweAlgorithm::Dir.decrypter_from_bytes(encryption_key)?;
683 let (payload_bytes, _header) = josekit::jwe::deserialize_compact(token, &decrypter)?;
684 let payload: GateTokenPayload = serde_json::from_slice(&payload_bytes)?;
685
686 Ok(payload)
687}