Microservice to bring 2FA to self-hosted PDSes
at main 24 kB view raw
1use crate::AppState; 2use crate::helpers::TokenCheckError::InvalidToken; 3use anyhow::anyhow; 4use axum::{ 5 body::{Body, to_bytes}, 6 extract::Request, 7 http::header::CONTENT_TYPE, 8 http::{HeaderMap, StatusCode, Uri}, 9 response::{IntoResponse, Response}, 10}; 11use axum_template::TemplateEngine; 12use chrono::Utc; 13use jacquard_common::{ 14 service_auth, service_auth::PublicKey, types::did::Did, types::did_doc::VerificationMethod, 15 types::nsid::Nsid, 16}; 17use jacquard_identity::{PublicResolver, resolver::IdentityResolver}; 18use josekit::jwe::alg::direct::DirectJweAlgorithm; 19use lettre::{ 20 AsyncTransport, Message, 21 message::{MultiPart, SinglePart, header}, 22}; 23use rand::Rng; 24use serde::de::DeserializeOwned; 25use serde_json::{Map, Value}; 26use sha2::{Digest, Sha256}; 27use sqlx::SqlitePool; 28use std::sync::Arc; 29use tracing::{error, log}; 30 31///Used to generate the email 2fa code 32const UPPERCASE_BASE32_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; 33 34/// The result of a proxied call that attempts to parse JSON. 35pub enum ProxiedResult<T> { 36 /// Successfully parsed JSON body along with original response headers. 37 Parsed { value: T, _headers: HeaderMap }, 38 /// Could not or should not parse: return the original (or rebuilt) response as-is. 39 Passthrough(Response<Body>), 40} 41 42/// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse 43/// the successful response body as JSON into `T`. 44/// 45pub async fn proxy_get_json<T>( 46 state: &AppState, 47 mut req: Request, 48 path: &str, 49) -> Result<ProxiedResult<T>, StatusCode> 50where 51 T: DeserializeOwned, 52{ 53 let uri = format!("{}{}", state.app_config.pds_base_url, path); 54 *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?; 55 56 let result = state 57 .reverse_proxy_client 58 .request(req) 59 .await 60 .map_err(|_| StatusCode::BAD_REQUEST)? 
61 .into_response(); 62 63 if result.status() != StatusCode::OK { 64 return Ok(ProxiedResult::Passthrough(result)); 65 } 66 67 let response_headers = result.headers().clone(); 68 let body = result.into_body(); 69 let body_bytes = to_bytes(body, usize::MAX) 70 .await 71 .map_err(|_| StatusCode::BAD_REQUEST)?; 72 73 match serde_json::from_slice::<T>(&body_bytes) { 74 Ok(value) => Ok(ProxiedResult::Parsed { 75 value, 76 _headers: response_headers, 77 }), 78 Err(err) => { 79 error!(%err, "failed to parse proxied JSON response; returning original body"); 80 let mut builder = Response::builder().status(StatusCode::OK); 81 if let Some(headers) = builder.headers_mut() { 82 *headers = response_headers; 83 } 84 let resp = builder 85 .body(Body::from(body_bytes)) 86 .map_err(|_| StatusCode::BAD_REQUEST)?; 87 Ok(ProxiedResult::Passthrough(resp)) 88 } 89 } 90} 91 92/// Build a JSON error response with the required Content-Type header 93/// Content-Type: application/json;charset=utf-8 94/// Body shape: { "error": string, "message": string } 95pub fn json_error_response( 96 status: StatusCode, 97 error: impl Into<String>, 98 message: impl Into<String>, 99) -> Result<Response<Body>, StatusCode> { 100 let body_str = match serde_json::to_string(&serde_json::json!({ 101 "error": error.into(), 102 "message": message.into(), 103 })) { 104 Ok(s) => s, 105 Err(_) => return Err(StatusCode::BAD_REQUEST), 106 }; 107 108 Response::builder() 109 .status(status) 110 .header(CONTENT_TYPE, "application/json;charset=utf-8") 111 .body(Body::from(body_str)) 112 .map_err(|_| StatusCode::BAD_REQUEST) 113} 114 115/// Build a JSON error response with the required Content-Type header 116/// Content-Type: application/json (oauth endpoint does not like utf ending) 117/// Body shape: { "error": string, "error_description": string } 118pub fn oauth_json_error_response( 119 status: StatusCode, 120 error: impl Into<String>, 121 message: impl Into<String>, 122) -> Result<Response<Body>, StatusCode> { 123 let 
body_str = match serde_json::to_string(&serde_json::json!({ 124 "error": error.into(), 125 "error_description": message.into(), 126 })) { 127 Ok(s) => s, 128 Err(_) => return Err(StatusCode::BAD_REQUEST), 129 }; 130 131 Response::builder() 132 .status(status) 133 .header(CONTENT_TYPE, "application/json") 134 .body(Body::from(body_str)) 135 .map_err(|_| StatusCode::BAD_REQUEST) 136} 137 138/// Creates a random token of 10 characters for email 2FA 139pub fn get_random_token() -> String { 140 let mut rng = rand::rng(); 141 142 let mut full_code = String::with_capacity(10); 143 for _ in 0..10 { 144 let idx = rng.random_range(0..UPPERCASE_BASE32_CHARS.len()); 145 full_code.push(UPPERCASE_BASE32_CHARS[idx] as char); 146 } 147 148 let slice_one = &full_code[0..5]; 149 let slice_two = &full_code[5..10]; 150 format!("{slice_one}-{slice_two}") 151} 152 153pub enum TokenCheckError { 154 InvalidToken, 155 ExpiredToken, 156} 157 158pub enum AuthResult { 159 WrongIdentityOrPassword, 160 /// The string here is the email address to create a hint for oauth 161 TwoFactorRequired(String), 162 /// User does not have 2FA enabled, or using an app password, or passes it 163 ProxyThrough, 164 TokenCheckFailed(TokenCheckError), 165} 166 167pub enum IdentifierType { 168 Email, 169 Did, 170 Handle, 171} 172 173impl IdentifierType { 174 fn what_is_it(identifier: String) -> Self { 175 if identifier.contains("@") { 176 IdentifierType::Email 177 } else if identifier.contains("did:") { 178 IdentifierType::Did 179 } else { 180 IdentifierType::Handle 181 } 182 } 183} 184 185/// Creates a hex string from the password and salt to find app passwords 186fn scrypt_hex(password: &str, salt: &str) -> anyhow::Result<String> { 187 let params = scrypt::Params::new(14, 8, 1, 64)?; 188 let mut derived = [0u8; 64]; 189 scrypt::scrypt(password.as_bytes(), salt.as_bytes(), &params, &mut derived)?; 190 Ok(hex::encode(derived)) 191} 192 193/// Hashes the app password. did is used as the salt. 
194pub fn hash_app_password(did: &str, password: &str) -> anyhow::Result<String> { 195 let mut hasher = Sha256::new(); 196 hasher.update(did.as_bytes()); 197 let sha = hasher.finalize(); 198 let salt = hex::encode(&sha[..16]); 199 let hash_hex = scrypt_hex(password, &salt)?; 200 Ok(format!("{salt}:{hash_hex}")) 201} 202 203async fn verify_password(password: &str, password_scrypt: &str) -> anyhow::Result<bool> { 204 // Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes) 205 let mut parts = password_scrypt.splitn(2, ':'); 206 let salt = match parts.next() { 207 Some(s) if !s.is_empty() => s, 208 _ => return Ok(false), 209 }; 210 let stored_hash_hex = match parts.next() { 211 Some(h) if !h.is_empty() => h, 212 _ => return Ok(false), 213 }; 214 215 // Derive using the shared helper and compare 216 let derived_hex = match scrypt_hex(password, salt) { 217 Ok(h) => h, 218 Err(_) => return Ok(false), 219 }; 220 221 Ok(derived_hex.as_str() == stored_hash_hex) 222} 223 224/// Handles the auth checks along with sending a 2fa email 225pub async fn preauth_check( 226 state: &AppState, 227 identifier: &str, 228 password: &str, 229 two_factor_code: Option<String>, 230 oauth: bool, 231) -> anyhow::Result<AuthResult> { 232 // Determine identifier type 233 let id_type = IdentifierType::what_is_it(identifier.to_string()); 234 235 // Query account DB for did and passwordScrypt based on identifier type 236 let account_row: Option<(String, String, String, String)> = match id_type { 237 IdentifierType::Email => { 238 sqlx::query_as::<_, (String, String, String, String)>( 239 "SELECT account.did, account.passwordScrypt, account.email, actor.handle 240 FROM actor 241 LEFT JOIN account ON actor.did = account.did 242 where account.email = ? LIMIT 1", 243 ) 244 .bind(identifier) 245 .fetch_optional(&state.account_pool) 246 .await? 
247 } 248 IdentifierType::Handle => { 249 sqlx::query_as::<_, (String, String, String, String)>( 250 "SELECT account.did, account.passwordScrypt, account.email, actor.handle 251 FROM actor 252 LEFT JOIN account ON actor.did = account.did 253 where actor.handle = ? LIMIT 1", 254 ) 255 .bind(identifier) 256 .fetch_optional(&state.account_pool) 257 .await? 258 } 259 IdentifierType::Did => { 260 sqlx::query_as::<_, (String, String, String, String)>( 261 "SELECT account.did, account.passwordScrypt, account.email, actor.handle 262 FROM actor 263 LEFT JOIN account ON actor.did = account.did 264 where account.did = ? LIMIT 1", 265 ) 266 .bind(identifier) 267 .fetch_optional(&state.account_pool) 268 .await? 269 } 270 }; 271 272 if let Some((did, password_scrypt, email, handle)) = account_row { 273 // Verify password before proceeding to 2FA email step 274 let verified = verify_password(password, &password_scrypt).await?; 275 if !verified { 276 if oauth { 277 //OAuth does not allow app password logins so just go ahead and send it along it's way 278 return Ok(AuthResult::WrongIdentityOrPassword); 279 } 280 //Theres a chance it could be an app password so check that as well 281 return match verify_app_password(&state.account_pool, &did, password).await { 282 Ok(valid) => { 283 if valid { 284 //Was a valid app password up to the PDS now 285 Ok(AuthResult::ProxyThrough) 286 } else { 287 Ok(AuthResult::WrongIdentityOrPassword) 288 } 289 } 290 Err(err) => { 291 log::error!("Error checking the app password: {err}"); 292 Err(err) 293 } 294 }; 295 } 296 297 // Check two-factor requirement for this DID in the gatekeeper DB 298 let required_opt = sqlx::query_as::<_, (u8,)>( 299 "SELECT required FROM two_factor_accounts WHERE did = ? 
LIMIT 1", 300 ) 301 .bind(did.clone()) 302 .fetch_optional(&state.pds_gatekeeper_pool) 303 .await?; 304 305 let two_factor_required = match required_opt { 306 Some(row) => row.0 != 0, 307 None => false, 308 }; 309 310 if two_factor_required { 311 //Two factor is required and a taken was provided 312 if let Some(two_factor_code) = two_factor_code { 313 //if the two_factor_code is set need to see if we have a valid token 314 if !two_factor_code.is_empty() { 315 return match assert_valid_token( 316 &state.account_pool, 317 did.clone(), 318 two_factor_code, 319 ) 320 .await 321 { 322 Ok(_) => { 323 let result_of_cleanup = 324 delete_all_email_tokens(&state.account_pool, did.clone()).await; 325 if result_of_cleanup.is_err() { 326 log::error!( 327 "There was an error deleting the email tokens after login: {:?}", 328 result_of_cleanup.err() 329 ) 330 } 331 Ok(AuthResult::ProxyThrough) 332 } 333 Err(err) => Ok(AuthResult::TokenCheckFailed(err)), 334 }; 335 } 336 } 337 338 return match create_two_factor_token(&state.account_pool, did).await { 339 Ok(code) => { 340 let mut email_data = Map::new(); 341 email_data.insert("token".to_string(), Value::from(code.clone())); 342 email_data.insert("handle".to_string(), Value::from(handle.clone())); 343 let email_body = state 344 .template_engine 345 .render("two_factor_code.hbs", email_data)?; 346 347 let email_message = Message::builder() 348 //TODO prob get the proper type in the state 349 .from(state.app_config.mailer_from.parse()?) 350 .to(email.parse()?) 351 .subject(&state.app_config.email_subject) 352 .multipart( 353 MultiPart::alternative() // This is composed of two parts. 354 .singlepart( 355 SinglePart::builder() 356 .header(header::ContentType::TEXT_PLAIN) 357 .body(format!("We received a sign-in request for the account @{handle}. Use the code: {code} to sign in. 
If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.")), // Every message should have a plain text fallback. 358 ) 359 .singlepart( 360 SinglePart::builder() 361 .header(header::ContentType::TEXT_HTML) 362 .body(email_body), 363 ), 364 )?; 365 match state.mailer.send(email_message).await { 366 Ok(_) => Ok(AuthResult::TwoFactorRequired(mask_email(email))), 367 Err(err) => { 368 log::error!("Error sending the 2FA email: {err}"); 369 Err(anyhow!(err)) 370 } 371 } 372 } 373 Err(err) => { 374 log::error!("error on creating a 2fa token: {err}"); 375 Err(anyhow!(err)) 376 } 377 }; 378 } 379 } 380 381 // No local 2FA requirement (or account not found) 382 Ok(AuthResult::ProxyThrough) 383} 384 385pub async fn create_two_factor_token( 386 account_db: &SqlitePool, 387 did: String, 388) -> anyhow::Result<String> { 389 let purpose = "2fa_code"; 390 391 let token = get_random_token(); 392 let right_now = Utc::now(); 393 394 let res = sqlx::query( 395 "INSERT INTO email_token (purpose, did, token, requestedAt) 396 VALUES (?, ?, ?, ?) 
397 ON CONFLICT(purpose, did) DO UPDATE SET 398 token=excluded.token, 399 requestedAt=excluded.requestedAt 400 WHERE did=excluded.did", 401 ) 402 .bind(purpose) 403 .bind(&did) 404 .bind(&token) 405 .bind(right_now) 406 .execute(account_db) 407 .await; 408 409 match res { 410 Ok(_) => Ok(token), 411 Err(err) => { 412 log::error!("Error creating a two factor token: {err}"); 413 Err(anyhow::anyhow!(err)) 414 } 415 } 416} 417 418pub async fn delete_all_email_tokens(account_db: &SqlitePool, did: String) -> anyhow::Result<()> { 419 sqlx::query("DELETE FROM email_token WHERE did = ?") 420 .bind(did) 421 .execute(account_db) 422 .await?; 423 Ok(()) 424} 425 426pub async fn assert_valid_token( 427 account_db: &SqlitePool, 428 did: String, 429 token: String, 430) -> Result<(), TokenCheckError> { 431 let token_upper = token.to_ascii_uppercase(); 432 let purpose = "2fa_code"; 433 434 let row: Option<(String,)> = sqlx::query_as( 435 "SELECT requestedAt FROM email_token WHERE purpose = ? AND did = ? AND token = ? LIMIT 1", 436 ) 437 .bind(purpose) 438 .bind(did) 439 .bind(token_upper) 440 .fetch_optional(account_db) 441 .await 442 .map_err(|err| { 443 log::error!("Error getting the 2fa token: {err}"); 444 InvalidToken 445 })?; 446 447 match row { 448 None => Err(InvalidToken), 449 Some(row) => { 450 // Token lives for 15 minutes 451 let expiration_ms = 15 * 60_000; 452 453 let requested_at_utc = match chrono::DateTime::parse_from_rfc3339(&row.0) { 454 Ok(dt) => dt.with_timezone(&Utc), 455 Err(_) => { 456 return Err(TokenCheckError::InvalidToken); 457 } 458 }; 459 460 let now = Utc::now(); 461 let age_ms = (now - requested_at_utc).num_milliseconds(); 462 let expired = age_ms > expiration_ms; 463 if expired { 464 return Err(TokenCheckError::ExpiredToken); 465 } 466 467 Ok(()) 468 } 469 } 470} 471 472/// We just need to confirm if it's there or not. 
Will let the PDS do the actual figuring of permissions 473pub async fn verify_app_password( 474 account_db: &SqlitePool, 475 did: &str, 476 password: &str, 477) -> anyhow::Result<bool> { 478 let password_scrypt = hash_app_password(did, password)?; 479 480 let row: Option<(i64,)> = sqlx::query_as( 481 "SELECT Count(*) FROM app_password WHERE did = ? AND passwordScrypt = ? LIMIT 1", 482 ) 483 .bind(did) 484 .bind(password_scrypt) 485 .fetch_optional(account_db) 486 .await?; 487 488 Ok(match row { 489 None => false, 490 Some((count,)) => count > 0, 491 }) 492} 493 494/// Mask an email address into a hint like "2***0@p***m". 495pub fn mask_email(email: String) -> String { 496 // Basic split on first '@' 497 let mut parts = email.splitn(2, '@'); 498 let local = match parts.next() { 499 Some(l) => l, 500 None => return email.to_string(), 501 }; 502 let domain_rest = match parts.next() { 503 Some(d) if !d.is_empty() => d, 504 _ => return email.to_string(), 505 }; 506 507 // Helper to mask a single label (keep first and last, middle becomes ***). 
508 fn mask_label(s: &str) -> String { 509 let chars: Vec<char> = s.chars().collect(); 510 match chars.len() { 511 0 => String::new(), 512 1 => format!("{}***", chars[0]), 513 2 => format!("{}***{}", chars[0], chars[1]), 514 _ => format!("{}***{}", chars[0], chars[chars.len() - 1]), 515 } 516 } 517 518 // Mask local 519 let masked_local = mask_label(local); 520 521 // Mask first domain label only, keep the rest of the domain intact 522 let mut dom_parts = domain_rest.splitn(2, '.'); 523 let first_label = dom_parts.next().unwrap_or(""); 524 let rest = dom_parts.next(); 525 let masked_first = mask_label(first_label); 526 let masked_domain = if let Some(rest) = rest { 527 format!("{}.{rest}", masked_first) 528 } else { 529 masked_first 530 }; 531 532 format!("{masked_local}@{masked_domain}") 533} 534 535pub enum VerifyServiceAuthError { 536 AuthFailed, 537 Error(anyhow::Error), 538} 539 540/// Verifies the service auth token that is appended to an XRPC proxy request 541pub async fn verify_service_auth( 542 jwt: &str, 543 lxm: &Nsid<'static>, 544 public_resolver: Arc<PublicResolver>, 545 service_did: &Did<'static>, 546 //The did of the user wanting to create an account 547 requested_did: &Did<'static>, 548) -> Result<(), VerifyServiceAuthError> { 549 let parsed = 550 service_auth::parse_jwt(jwt).map_err(|e| VerifyServiceAuthError::Error(e.into()))?; 551 552 let claims = parsed.claims(); 553 554 let did_doc = public_resolver 555 .resolve_did_doc(&requested_did) 556 .await 557 .map_err(|err| { 558 log::error!("Error resolving the service auth for: {}", claims.iss); 559 return VerifyServiceAuthError::Error(err.into()); 560 })?; 561 562 // Parse the DID document response to get verification methods 563 let doc = did_doc.parse().map_err(|err| { 564 log::error!("Error parsing the service auth did doc: {}", claims.iss); 565 VerifyServiceAuthError::Error(anyhow::anyhow!(err)) 566 })?; 567 568 let verification_methods = doc.verification_method.as_deref().ok_or_else(|| { 569 
VerifyServiceAuthError::Error(anyhow::anyhow!( 570 "No verification methods in did doc: {}", 571 &claims.iss 572 )) 573 })?; 574 575 let signing_key = extract_signing_key(verification_methods).ok_or_else(|| { 576 VerifyServiceAuthError::Error(anyhow::anyhow!( 577 "No signing key found in did doc: {}", 578 &claims.iss 579 )) 580 })?; 581 582 service_auth::verify_signature(&parsed, &signing_key).map_err(|err| { 583 log::error!("Error verifying service auth signature: {}", err); 584 VerifyServiceAuthError::AuthFailed 585 })?; 586 587 // Now validate claims (audience, expiration, etc.) 588 claims.validate(service_did).map_err(|e| { 589 log::error!("Error validating service auth claims: {}", e); 590 VerifyServiceAuthError::AuthFailed 591 })?; 592 593 if claims.aud != *service_did { 594 log::error!("Invalid audience (did:web): {}", claims.aud); 595 return Err(VerifyServiceAuthError::AuthFailed); 596 } 597 598 let lxm_from_claims = claims.lxm.as_ref().ok_or_else(|| { 599 VerifyServiceAuthError::Error(anyhow::anyhow!("No lxm claim in service auth JWT")) 600 })?; 601 602 if lxm_from_claims != lxm { 603 return Err(VerifyServiceAuthError::Error(anyhow::anyhow!( 604 "Invalid XRPC endpoint requested" 605 ))); 606 } 607 Ok(()) 608} 609 610/// Ripped from Jacquard 611/// 612/// Extract the signing key from a DID document's verification methods. 613/// 614/// This looks for a key with type "atproto" or the first available key 615/// if no atproto-specific key is found. 
616fn extract_signing_key(methods: &[VerificationMethod]) -> Option<PublicKey> { 617 // First try to find an atproto-specific key 618 let atproto_method = methods 619 .iter() 620 .find(|m| m.r#type.as_ref() == "Multikey" || m.r#type.as_ref() == "atproto"); 621 622 let method = atproto_method.or_else(|| methods.first())?; 623 624 // Parse the multikey 625 let public_key_multibase = method.public_key_multibase.as_ref()?; 626 627 // Decode multibase 628 let (_, key_bytes) = multibase::decode(public_key_multibase.as_ref()).ok()?; 629 630 // First two bytes are the multicodec prefix 631 if key_bytes.len() < 2 { 632 return None; 633 } 634 635 let codec = &key_bytes[..2]; 636 let key_material = &key_bytes[2..]; 637 638 match codec { 639 // p256-pub (0x1200) 640 [0x80, 0x24] => PublicKey::from_p256_bytes(key_material).ok(), 641 // secp256k1-pub (0xe7) 642 [0xe7, 0x01] => PublicKey::from_k256_bytes(key_material).ok(), 643 _ => None, 644 } 645} 646 647/// Payload for gate JWE tokens 648#[derive(serde::Serialize, serde::Deserialize, Debug)] 649pub struct GateTokenPayload { 650 pub handle: String, 651 pub created_at: String, 652} 653 654/// Generate a secure JWE token for gate verification 655pub fn generate_gate_token(handle: &str, encryption_key: &[u8]) -> Result<String, anyhow::Error> { 656 use josekit::jwe::{JweHeader, alg::direct::DirectJweAlgorithm}; 657 658 let payload = GateTokenPayload { 659 handle: handle.to_string(), 660 created_at: Utc::now().to_rfc3339(), 661 }; 662 663 let payload_json = serde_json::to_string(&payload)?; 664 665 let mut header = JweHeader::new(); 666 header.set_token_type("JWT"); 667 header.set_content_encryption("A128CBC-HS256"); 668 669 let encrypter = DirectJweAlgorithm::Dir.encrypter_from_bytes(encryption_key)?; 670 671 // Encrypt 672 let jwe = josekit::jwe::serialize_compact(payload_json.as_bytes(), &header, &encrypter)?; 673 674 Ok(jwe) 675} 676 677/// Verify and decrypt a gate JWE token, returning the payload if valid 678pub fn 
verify_gate_token( 679 token: &str, 680 encryption_key: &[u8], 681) -> Result<GateTokenPayload, anyhow::Error> { 682 let decrypter = DirectJweAlgorithm::Dir.decrypter_from_bytes(encryption_key)?; 683 let (payload_bytes, _header) = josekit::jwe::deserialize_compact(token, &decrypter)?; 684 let payload: GateTokenPayload = serde_json::from_slice(&payload_bytes)?; 685 686 Ok(payload) 687}