use crate::AppState; use crate::helpers::TokenCheckError::InvalidToken; use anyhow::anyhow; use axum::{ body::{Body, to_bytes}, extract::Request, http::header::CONTENT_TYPE, http::{HeaderMap, StatusCode, Uri}, response::{IntoResponse, Response}, }; use axum_template::TemplateEngine; use chrono::Utc; use jacquard_common::{ service_auth, service_auth::PublicKey, types::did::Did, types::did_doc::VerificationMethod, types::nsid::Nsid, }; use jacquard_identity::{PublicResolver, resolver::IdentityResolver}; use josekit::jwe::alg::direct::DirectJweAlgorithm; use lettre::{ AsyncTransport, Message, message::{MultiPart, SinglePart, header}, }; use rand::Rng; use serde::de::DeserializeOwned; use serde_json::{Map, Value}; use sha2::{Digest, Sha256}; use sqlx::SqlitePool; use std::sync::Arc; use tracing::{error, log}; ///Used to generate the email 2fa code const UPPERCASE_BASE32_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; /// The result of a proxied call that attempts to parse JSON. pub enum ProxiedResult { /// Successfully parsed JSON body along with original response headers. Parsed { value: T, _headers: HeaderMap }, /// Could not or should not parse: return the original (or rebuilt) response as-is. Passthrough(Response), } /// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse /// the successful response body as JSON into `T`. /// pub async fn proxy_get_json( state: &AppState, mut req: Request, path: &str, ) -> Result, StatusCode> where T: DeserializeOwned, { let uri = format!("{}{}", state.app_config.pds_base_url, path); *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?; let result = state .reverse_proxy_client .request(req) .await .map_err(|_| StatusCode::BAD_REQUEST)? .into_response(); if result.status() != StatusCode::OK { return Ok(ProxiedResult::Passthrough(result)); } let response_headers = result.headers().clone(); let body = result.into_body(); let body_bytes = to_bytes(body, usize::MAX) .await .map_err(|_| StatusCode::BAD_REQUEST)?; match serde_json::from_slice::(&body_bytes) { Ok(value) => Ok(ProxiedResult::Parsed { value, _headers: response_headers, }), Err(err) => { error!(%err, "failed to parse proxied JSON response; returning original body"); let mut builder = Response::builder().status(StatusCode::OK); if let Some(headers) = builder.headers_mut() { *headers = response_headers; } let resp = builder .body(Body::from(body_bytes)) .map_err(|_| StatusCode::BAD_REQUEST)?; Ok(ProxiedResult::Passthrough(resp)) } } } /// Build a JSON error response with the required Content-Type header /// Content-Type: application/json;charset=utf-8 /// Body shape: { "error": string, "message": string } pub fn json_error_response( status: StatusCode, error: impl Into, message: impl Into, ) -> Result, StatusCode> { let body_str = match serde_json::to_string(&serde_json::json!({ "error": error.into(), "message": message.into(), })) { Ok(s) => s, Err(_) => return Err(StatusCode::BAD_REQUEST), }; Response::builder() .status(status) .header(CONTENT_TYPE, "application/json;charset=utf-8") .body(Body::from(body_str)) .map_err(|_| StatusCode::BAD_REQUEST) } /// Build a JSON error response with the required Content-Type header /// Content-Type: application/json (oauth endpoint does not like utf ending) /// Body shape: { "error": string, "error_description": string } pub fn oauth_json_error_response( status: StatusCode, error: impl Into, message: impl Into, ) -> Result, StatusCode> { let body_str = match serde_json::to_string(&serde_json::json!({ "error": error.into(), "error_description": message.into(), })) { Ok(s) => s, Err(_) => return Err(StatusCode::BAD_REQUEST), }; Response::builder() .status(status) .header(CONTENT_TYPE, "application/json") .body(Body::from(body_str)) .map_err(|_| StatusCode::BAD_REQUEST) } /// Creates a random token of 10 characters for email 2FA pub fn get_random_token() -> String { let mut rng = rand::rng(); let mut full_code = String::with_capacity(10); for _ in 0..10 { let idx = rng.random_range(0..UPPERCASE_BASE32_CHARS.len()); full_code.push(UPPERCASE_BASE32_CHARS[idx] as char); } let slice_one = &full_code[0..5]; let slice_two = &full_code[5..10]; format!("{slice_one}-{slice_two}") } pub enum TokenCheckError { InvalidToken, ExpiredToken, } pub enum AuthResult { WrongIdentityOrPassword, /// The string here is the email address to create a hint for oauth TwoFactorRequired(String), /// User does not have 2FA enabled, or using an app password, or passes it ProxyThrough, TokenCheckFailed(TokenCheckError), } pub enum IdentifierType { Email, Did, Handle, } impl IdentifierType { fn what_is_it(identifier: String) -> Self { if identifier.contains("@") { IdentifierType::Email } else if identifier.contains("did:") { IdentifierType::Did } else { IdentifierType::Handle } } } /// Creates a hex string from the password and salt to find app passwords fn scrypt_hex(password: &str, salt: &str) -> anyhow::Result { let params = scrypt::Params::new(14, 8, 1, 64)?; let mut derived = [0u8; 64]; scrypt::scrypt(password.as_bytes(), salt.as_bytes(), ¶ms, &mut derived)?; Ok(hex::encode(derived)) } /// Hashes the app password. did is used as the salt. pub fn hash_app_password(did: &str, password: &str) -> anyhow::Result { let mut hasher = Sha256::new(); hasher.update(did.as_bytes()); let sha = hasher.finalize(); let salt = hex::encode(&sha[..16]); let hash_hex = scrypt_hex(password, &salt)?; Ok(format!("{salt}:{hash_hex}")) } async fn verify_password(password: &str, password_scrypt: &str) -> anyhow::Result { // Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes) let mut parts = password_scrypt.splitn(2, ':'); let salt = match parts.next() { Some(s) if !s.is_empty() => s, _ => return Ok(false), }; let stored_hash_hex = match parts.next() { Some(h) if !h.is_empty() => h, _ => return Ok(false), }; // Derive using the shared helper and compare let derived_hex = match scrypt_hex(password, salt) { Ok(h) => h, Err(_) => return Ok(false), }; Ok(derived_hex.as_str() == stored_hash_hex) } /// Handles the auth checks along with sending a 2fa email pub async fn preauth_check( state: &AppState, identifier: &str, password: &str, two_factor_code: Option, oauth: bool, ) -> anyhow::Result { // Determine identifier type let id_type = IdentifierType::what_is_it(identifier.to_string()); // Query account DB for did and passwordScrypt based on identifier type let account_row: Option<(String, String, String, String)> = match id_type { IdentifierType::Email => { sqlx::query_as::<_, (String, String, String, String)>( "SELECT account.did, account.passwordScrypt, account.email, actor.handle FROM actor LEFT JOIN account ON actor.did = account.did where account.email = ? LIMIT 1", ) .bind(identifier) .fetch_optional(&state.account_pool) .await? } IdentifierType::Handle => { sqlx::query_as::<_, (String, String, String, String)>( "SELECT account.did, account.passwordScrypt, account.email, actor.handle FROM actor LEFT JOIN account ON actor.did = account.did where actor.handle = ? LIMIT 1", ) .bind(identifier) .fetch_optional(&state.account_pool) .await? } IdentifierType::Did => { sqlx::query_as::<_, (String, String, String, String)>( "SELECT account.did, account.passwordScrypt, account.email, actor.handle FROM actor LEFT JOIN account ON actor.did = account.did where account.did = ? LIMIT 1", ) .bind(identifier) .fetch_optional(&state.account_pool) .await? } }; if let Some((did, password_scrypt, email, handle)) = account_row { // Verify password before proceeding to 2FA email step let verified = verify_password(password, &password_scrypt).await?; if !verified { if oauth { //OAuth does not allow app password logins so just go ahead and send it along it's way return Ok(AuthResult::WrongIdentityOrPassword); } //Theres a chance it could be an app password so check that as well return match verify_app_password(&state.account_pool, &did, password).await { Ok(valid) => { if valid { //Was a valid app password up to the PDS now Ok(AuthResult::ProxyThrough) } else { Ok(AuthResult::WrongIdentityOrPassword) } } Err(err) => { log::error!("Error checking the app password: {err}"); Err(err) } }; } // Check two-factor requirement for this DID in the gatekeeper DB let required_opt = sqlx::query_as::<_, (u8,)>( "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1", ) .bind(did.clone()) .fetch_optional(&state.pds_gatekeeper_pool) .await?; let two_factor_required = match required_opt { Some(row) => row.0 != 0, None => false, }; if two_factor_required { //Two factor is required and a taken was provided if let Some(two_factor_code) = two_factor_code { //if the two_factor_code is set need to see if we have a valid token if !two_factor_code.is_empty() { return match assert_valid_token( &state.account_pool, did.clone(), two_factor_code, ) .await { Ok(_) => { let result_of_cleanup = delete_all_email_tokens(&state.account_pool, did.clone()).await; if result_of_cleanup.is_err() { log::error!( "There was an error deleting the email tokens after login: {:?}", result_of_cleanup.err() ) } Ok(AuthResult::ProxyThrough) } Err(err) => Ok(AuthResult::TokenCheckFailed(err)), }; } } return match create_two_factor_token(&state.account_pool, did).await { Ok(code) => { let mut email_data = Map::new(); email_data.insert("token".to_string(), Value::from(code.clone())); email_data.insert("handle".to_string(), Value::from(handle.clone())); let email_body = state .template_engine .render("two_factor_code.hbs", email_data)?; let email_message = Message::builder() //TODO prob get the proper type in the state .from(state.app_config.mailer_from.parse()?) .to(email.parse()?) .subject(&state.app_config.email_subject) .multipart( MultiPart::alternative() // This is composed of two parts. .singlepart( SinglePart::builder() .header(header::ContentType::TEXT_PLAIN) .body(format!("We received a sign-in request for the account @{handle}. Use the code: {code} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.")), // Every message should have a plain text fallback. ) .singlepart( SinglePart::builder() .header(header::ContentType::TEXT_HTML) .body(email_body), ), )?; match state.mailer.send(email_message).await { Ok(_) => Ok(AuthResult::TwoFactorRequired(mask_email(email))), Err(err) => { log::error!("Error sending the 2FA email: {err}"); Err(anyhow!(err)) } } } Err(err) => { log::error!("error on creating a 2fa token: {err}"); Err(anyhow!(err)) } }; } } // No local 2FA requirement (or account not found) Ok(AuthResult::ProxyThrough) } pub async fn create_two_factor_token( account_db: &SqlitePool, did: String, ) -> anyhow::Result { let purpose = "2fa_code"; let token = get_random_token(); let right_now = Utc::now(); let res = sqlx::query( "INSERT INTO email_token (purpose, did, token, requestedAt) VALUES (?, ?, ?, ?) ON CONFLICT(purpose, did) DO UPDATE SET token=excluded.token, requestedAt=excluded.requestedAt WHERE did=excluded.did", ) .bind(purpose) .bind(&did) .bind(&token) .bind(right_now) .execute(account_db) .await; match res { Ok(_) => Ok(token), Err(err) => { log::error!("Error creating a two factor token: {err}"); Err(anyhow::anyhow!(err)) } } } pub async fn delete_all_email_tokens(account_db: &SqlitePool, did: String) -> anyhow::Result<()> { sqlx::query("DELETE FROM email_token WHERE did = ?") .bind(did) .execute(account_db) .await?; Ok(()) } pub async fn assert_valid_token( account_db: &SqlitePool, did: String, token: String, ) -> Result<(), TokenCheckError> { let token_upper = token.to_ascii_uppercase(); let purpose = "2fa_code"; let row: Option<(String,)> = sqlx::query_as( "SELECT requestedAt FROM email_token WHERE purpose = ? AND did = ? AND token = ? LIMIT 1", ) .bind(purpose) .bind(did) .bind(token_upper) .fetch_optional(account_db) .await .map_err(|err| { log::error!("Error getting the 2fa token: {err}"); InvalidToken })?; match row { None => Err(InvalidToken), Some(row) => { // Token lives for 15 minutes let expiration_ms = 15 * 60_000; let requested_at_utc = match chrono::DateTime::parse_from_rfc3339(&row.0) { Ok(dt) => dt.with_timezone(&Utc), Err(_) => { return Err(TokenCheckError::InvalidToken); } }; let now = Utc::now(); let age_ms = (now - requested_at_utc).num_milliseconds(); let expired = age_ms > expiration_ms; if expired { return Err(TokenCheckError::ExpiredToken); } Ok(()) } } } /// We just need to confirm if it's there or not. Will let the PDS do the actual figuring of permissions pub async fn verify_app_password( account_db: &SqlitePool, did: &str, password: &str, ) -> anyhow::Result { let password_scrypt = hash_app_password(did, password)?; let row: Option<(i64,)> = sqlx::query_as( "SELECT Count(*) FROM app_password WHERE did = ? AND passwordScrypt = ? LIMIT 1", ) .bind(did) .bind(password_scrypt) .fetch_optional(account_db) .await?; Ok(match row { None => false, Some((count,)) => count > 0, }) } /// Mask an email address into a hint like "2***0@p***m". pub fn mask_email(email: String) -> String { // Basic split on first '@' let mut parts = email.splitn(2, '@'); let local = match parts.next() { Some(l) => l, None => return email.to_string(), }; let domain_rest = match parts.next() { Some(d) if !d.is_empty() => d, _ => return email.to_string(), }; // Helper to mask a single label (keep first and last, middle becomes ***). fn mask_label(s: &str) -> String { let chars: Vec = s.chars().collect(); match chars.len() { 0 => String::new(), 1 => format!("{}***", chars[0]), 2 => format!("{}***{}", chars[0], chars[1]), _ => format!("{}***{}", chars[0], chars[chars.len() - 1]), } } // Mask local let masked_local = mask_label(local); // Mask first domain label only, keep the rest of the domain intact let mut dom_parts = domain_rest.splitn(2, '.'); let first_label = dom_parts.next().unwrap_or(""); let rest = dom_parts.next(); let masked_first = mask_label(first_label); let masked_domain = if let Some(rest) = rest { format!("{}.{rest}", masked_first) } else { masked_first }; format!("{masked_local}@{masked_domain}") } pub enum VerifyServiceAuthError { AuthFailed, Error(anyhow::Error), } /// Verifies the service auth token that is appended to an XRPC proxy request pub async fn verify_service_auth( jwt: &str, lxm: &Nsid<'static>, public_resolver: Arc, service_did: &Did<'static>, //The did of the user wanting to create an account requested_did: &Did<'static>, ) -> Result<(), VerifyServiceAuthError> { let parsed = service_auth::parse_jwt(jwt).map_err(|e| VerifyServiceAuthError::Error(e.into()))?; let claims = parsed.claims(); let did_doc = public_resolver .resolve_did_doc(&requested_did) .await .map_err(|err| { log::error!("Error resolving the service auth for: {}", claims.iss); return VerifyServiceAuthError::Error(err.into()); })?; // Parse the DID document response to get verification methods let doc = did_doc.parse().map_err(|err| { log::error!("Error parsing the service auth did doc: {}", claims.iss); VerifyServiceAuthError::Error(anyhow::anyhow!(err)) })?; let verification_methods = doc.verification_method.as_deref().ok_or_else(|| { VerifyServiceAuthError::Error(anyhow::anyhow!( "No verification methods in did doc: {}", &claims.iss )) })?; let signing_key = extract_signing_key(verification_methods).ok_or_else(|| { VerifyServiceAuthError::Error(anyhow::anyhow!( "No signing key found in did doc: {}", &claims.iss )) })?; service_auth::verify_signature(&parsed, &signing_key).map_err(|err| { log::error!("Error verifying service auth signature: {}", err); VerifyServiceAuthError::AuthFailed })?; // Now validate claims (audience, expiration, etc.) claims.validate(service_did).map_err(|e| { log::error!("Error validating service auth claims: {}", e); VerifyServiceAuthError::AuthFailed })?; if claims.aud != *service_did { log::error!("Invalid audience (did:web): {}", claims.aud); return Err(VerifyServiceAuthError::AuthFailed); } let lxm_from_claims = claims.lxm.as_ref().ok_or_else(|| { VerifyServiceAuthError::Error(anyhow::anyhow!("No lxm claim in service auth JWT")) })?; if lxm_from_claims != lxm { return Err(VerifyServiceAuthError::Error(anyhow::anyhow!( "Invalid XRPC endpoint requested" ))); } Ok(()) } /// Ripped from Jacquard /// /// Extract the signing key from a DID document's verification methods. /// /// This looks for a key with type "atproto" or the first available key /// if no atproto-specific key is found. fn extract_signing_key(methods: &[VerificationMethod]) -> Option { // First try to find an atproto-specific key let atproto_method = methods .iter() .find(|m| m.r#type.as_ref() == "Multikey" || m.r#type.as_ref() == "atproto"); let method = atproto_method.or_else(|| methods.first())?; // Parse the multikey let public_key_multibase = method.public_key_multibase.as_ref()?; // Decode multibase let (_, key_bytes) = multibase::decode(public_key_multibase.as_ref()).ok()?; // First two bytes are the multicodec prefix if key_bytes.len() < 2 { return None; } let codec = &key_bytes[..2]; let key_material = &key_bytes[2..]; match codec { // p256-pub (0x1200) [0x80, 0x24] => PublicKey::from_p256_bytes(key_material).ok(), // secp256k1-pub (0xe7) [0xe7, 0x01] => PublicKey::from_k256_bytes(key_material).ok(), _ => None, } } /// Payload for gate JWE tokens #[derive(serde::Serialize, serde::Deserialize, Debug)] pub struct GateTokenPayload { pub handle: String, pub created_at: String, } /// Generate a secure JWE token for gate verification pub fn generate_gate_token(handle: &str, encryption_key: &[u8]) -> Result { use josekit::jwe::{JweHeader, alg::direct::DirectJweAlgorithm}; let payload = GateTokenPayload { handle: handle.to_string(), created_at: Utc::now().to_rfc3339(), }; let payload_json = serde_json::to_string(&payload)?; let mut header = JweHeader::new(); header.set_token_type("JWT"); header.set_content_encryption("A128CBC-HS256"); let encrypter = DirectJweAlgorithm::Dir.encrypter_from_bytes(encryption_key)?; // Encrypt let jwe = josekit::jwe::serialize_compact(payload_json.as_bytes(), &header, &encrypter)?; Ok(jwe) } /// Verify and decrypt a gate JWE token, returning the payload if valid pub fn verify_gate_token( token: &str, encryption_key: &[u8], ) -> Result { let decrypter = DirectJweAlgorithm::Dir.decrypter_from_bytes(encryption_key)?; let (payload_bytes, _header) = josekit::jwe::deserialize_compact(token, &decrypter)?; let payload: GateTokenPayload = serde_json::from_slice(&payload_bytes)?; Ok(payload) }