Microservice to bring 2FA to self hosted PDSes

wip

+59 -9
src/helpers.rs
··· 10 10 use chrono::Utc; 11 11 use lettre::message::{MultiPart, SinglePart, header}; 12 12 use lettre::{AsyncTransport, Message}; 13 - use rand::distr::{Alphanumeric, SampleString}; 13 + use rand::distr::{Alphabetic, Alphanumeric, SampleString}; 14 14 use serde::de::DeserializeOwned; 15 15 use serde_json::{Map, Value}; 16 16 use sqlx::SqlitePool; ··· 118 118 } 119 119 120 120 /// Creates a random token of 10 characters for email 2FA 121 - pub fn get_random_token() -> String { 122 - let full_code = Alphanumeric.sample_string(&mut rand::rng(), 10); 121 + pub fn get_random_token(oauth: bool) -> String { 122 + let full_code = match oauth { 123 + true => Alphabetic.sample_string(&mut rand::rng(), 10), 124 + false => Alphanumeric.sample_string(&mut rand::rng(), 10), 125 + }; 123 126 //The PDS implementation creates in lowercase, then converts to uppercase. 124 127 //Just going a head and doing uppercase here. 125 128 let slice_one = &full_code[0..5].to_ascii_uppercase(); ··· 203 206 identifier: &str, 204 207 password: &str, 205 208 two_factor_code: Option<String>, 209 + oauth: bool, 206 210 ) -> anyhow::Result<AuthResult> { 207 211 // Determine identifier type 208 212 let id_type = IdentifierType::what_is_it(identifier.to_string()); ··· 248 252 // Verify password before proceeding to 2FA email step 249 253 let verified = verify_password(password, &password_scrypt).await?; 250 254 if !verified { 255 + if oauth { 256 + return Ok(AuthResult::WrongIdentityOrPassword); 257 + } 251 258 //Theres a chance it could be an app password so check that as well 252 259 return match verify_app_password(&state.account_pool, &did, password).await { 253 260 Ok(valid) => { ··· 306 313 } 307 314 } 308 315 309 - return match create_two_factor_token(&state.account_pool, did).await { 316 + return match create_two_factor_token(&state.account_pool, did, oauth).await { 310 317 Ok(code) => { 311 318 let mut email_data = Map::new(); 312 319 email_data.insert("token".to_string(), Value::from(code.clone())); ··· 315 322 .template_engine 316 323 .render("two_factor_code.hbs", email_data)?; 317 324 318 - let email = Message::builder() 325 + let email_message = Message::builder() 319 326 //TODO prob get the proper type in the state 320 327 .from(state.mailer_from.parse()?) 321 328 .to(email.parse()?) ··· 333 340 .body(email_body), 334 341 ), 335 342 )?; 336 - match state.mailer.send(email).await { 337 - Ok(_) => Ok(AuthResult::TwoFactorRequired), 343 + match state.mailer.send(email_message).await { 344 + Ok(_) => Ok(AuthResult::TwoFactorRequired(mask_email(email))), 338 345 Err(err) => { 339 346 log::error!("Error sending the 2FA email: {err}"); 340 347 Err(anyhow!(err)) ··· 356 363 pub async fn create_two_factor_token( 357 364 account_db: &SqlitePool, 358 365 did: String, 366 + oauth: bool, 359 367 ) -> anyhow::Result<String> { 360 368 let purpose = "2fa_code"; 361 369 362 - let token = get_random_token(); 370 + let token = get_random_token(oauth); 363 371 let right_now = Utc::now(); 364 372 365 373 let res = sqlx::query( ··· 367 375 VALUES (?, ?, ?, ?) 368 376 ON CONFLICT(purpose, did) DO UPDATE SET 369 377 token=excluded.token, 370 - requestedAt=excluded.requestedAt", 378 + requestedAt=excluded.requestedAt 379 + WHERE did=excluded.did", 371 380 ) 372 381 .bind(purpose) 373 382 .bind(&did) ··· 460 469 Some((count,)) => count > 0, 461 470 }) 462 471 } 472 + 473 + /// Mask an email address into a hint like "2***0@p***m". 474 + pub fn mask_email(email: String) -> String { 475 + // Basic split on first '@' 476 + let mut parts = email.splitn(2, '@'); 477 + let local = match parts.next() { 478 + Some(l) => l, 479 + None => return email.to_string(), 480 + }; 481 + let domain_rest = match parts.next() { 482 + Some(d) if !d.is_empty() => d, 483 + _ => return email.to_string(), 484 + }; 485 + 486 + // Helper to mask a single label (keep first and last, middle becomes ***). 487 + fn mask_label(s: &str) -> String { 488 + let chars: Vec<char> = s.chars().collect(); 489 + match chars.len() { 490 + 0 => String::new(), 491 + 1 => format!("{}***", chars[0]), 492 + 2 => format!("{}***{}", chars[0], chars[1]), 493 + _ => format!("{}***{}", chars[0], chars[chars.len() - 1]), 494 + } 495 + } 496 + 497 + // Mask local 498 + let masked_local = mask_label(local); 499 + 500 + // Mask first domain label only, keep the rest of the domain intact 501 + let mut dom_parts = domain_rest.splitn(2, '.'); 502 + let first_label = dom_parts.next().unwrap_or(""); 503 + let rest = dom_parts.next(); 504 + let masked_first = mask_label(first_label); 505 + let masked_domain = if let Some(rest) = rest { 506 + format!("{}.{rest}", masked_first) 507 + } else { 508 + masked_first 509 + }; 510 + 511 + format!("{masked_local}@{masked_domain}") 512 + }
+24 -17
src/oauth_provider.rs
··· 4 4 }; 5 5 use axum::body::Body; 6 6 use axum::extract::State; 7 + use axum::http::header::CONTENT_TYPE; 7 8 use axum::http::{HeaderMap, StatusCode}; 8 9 use axum::response::{IntoResponse, Response}; 9 10 use axum::{Json, extract}; 10 11 use serde::{Deserialize, Serialize}; 11 12 use tracing::log; 12 13 13 - #[derive(Serialize, Deserialize)] 14 - struct Root { 15 - #[serde(rename = "CamelCaseJson")] 16 - pub camel_case_json: i64, 17 - #[serde(rename = "woahThisIsNeat")] 18 - pub woah_this_is_neat: String, 19 - } 20 - 21 - #[derive(Serialize, Deserialize)] 14 + #[derive(Serialize, Deserialize, Clone)] 22 15 pub struct SignInRequest { 23 16 pub username: String, 24 17 pub password: String, ··· 31 24 pub async fn sign_in( 32 25 State(state): State<AppState>, 33 26 headers: HeaderMap, 34 - Json(payload): extract::Json<SignInRequest>, 27 + Json(mut payload): extract::Json<SignInRequest>, 35 28 ) -> Result<Response<Body>, StatusCode> { 36 29 let identifier = payload.username.clone(); 37 30 let password = payload.password.clone(); ··· 39 32 40 33 //TODO need to pass in a flag to ignore app passwords for Oauth 41 34 // Run the shared pre-auth logic to validate and check 2FA requirement 42 - match preauth_check(&state, &identifier, &password, auth_factor_token).await { 35 + match preauth_check(&state, &identifier, &password, auth_factor_token, true).await { 43 36 Ok(result) => match result { 44 37 AuthResult::WrongIdentityOrPassword => oauth_json_error_response( 45 38 StatusCode::BAD_REQUEST, 46 39 "invalid_request", 47 40 "Invalid identifier or password", 48 41 ), 49 - AuthResult::TwoFactorRequired => { 42 + AuthResult::TwoFactorRequired(masked_email) => { 50 43 // Email sending step can be handled here if needed in the future. 51 44 52 45 // {"error":"second_authentication_factor_required","error_description":"emailOtp authentication factor required (hint: 2***0@p***m)","type":"emailOtp","hint":"2***0@p***m"} 53 - oauth_json_error_response( 54 - StatusCode::UNAUTHORIZED, 55 - "AuthFactorTokenRequired", 56 - "A sign in code has been sent to your email address", 57 - ) 46 + let body_str = match serde_json::to_string(&serde_json::json!({ 47 + "error": "second_authentication_factor_required", 48 + "error_description": format!("emailOtp authentication factor required (hint: {})", masked_email), 49 + "type": "emailOtp", 50 + "hint": masked_email, 51 + })) { 52 + Ok(s) => s, 53 + Err(_) => return Err(StatusCode::BAD_REQUEST), 54 + }; 55 + 56 + Response::builder() 57 + .status(StatusCode::BAD_REQUEST) 58 + .header(CONTENT_TYPE, "application/json") 59 + .body(Body::from(body_str)) 60 + .map_err(|_| StatusCode::BAD_REQUEST) 58 61 } 59 62 AuthResult::ProxyThrough => { 60 63 //No 2FA or already passed 64 + //I don't think it likes localhost. Maybe do 61 65 let uri = format!( 62 66 "{}{}", 63 67 state.pds_base_url, "/@atproto/oauth-provider/~api/sign-in" ··· 68 72 req_headers.extend(headers.clone()); 69 73 } 70 74 75 + payload.email_otp = None; 71 76 let payload_bytes = 72 77 serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 78 + let body = serde_json::to_string(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 79 + 73 80 let req = req 74 81 .body(Body::from(payload_bytes)) 75 82 .map_err(|_| StatusCode::BAD_REQUEST)?;
+7 -3
src/xrpc/com_atproto_server.rs
··· 11 11 use serde::{Deserialize, Serialize}; 12 12 use serde_json; 13 13 use tracing::log; 14 + use tracing::log::log; 14 15 15 16 #[derive(Serialize, Deserialize, Debug, Clone)] 16 17 #[serde(rename_all = "camelCase")] ··· 55 56 pub struct CreateSessionRequest { 56 57 identifier: String, 57 58 password: String, 59 + #[serde(skip_serializing_if = "Option::is_none")] 58 60 auth_factor_token: Option<String>, 61 + #[serde(skip_serializing_if = "Option::is_none")] 59 62 allow_takendown: Option<bool>, 60 63 } 61 64 62 65 pub async fn create_session( 63 66 State(state): State<AppState>, 64 67 headers: HeaderMap, 65 - Json(payload): extract::Json<CreateSessionRequest>, 68 + Json(mut payload): extract::Json<CreateSessionRequest>, 66 69 ) -> Result<Response<Body>, StatusCode> { 67 70 let identifier = payload.identifier.clone(); 68 71 let password = payload.password.clone(); 69 72 let auth_factor_token = payload.auth_factor_token.clone(); 70 73 71 74 // Run the shared pre-auth logic to validate and check 2FA requirement 72 - match preauth_check(&state, &identifier, &password, auth_factor_token).await { 75 + match preauth_check(&state, &identifier, &password, auth_factor_token, false).await { 73 76 Ok(result) => match result { 74 77 AuthResult::WrongIdentityOrPassword => json_error_response( 75 78 StatusCode::UNAUTHORIZED, 76 79 "AuthenticationRequired", 77 80 "Invalid identifier or password", 78 81 ), 79 - AuthResult::TwoFactorRequired => { 82 + AuthResult::TwoFactorRequired(_) => { 80 83 // Email sending step can be handled here if needed in the future. 81 84 json_error_response( 82 85 StatusCode::UNAUTHORIZED, ··· 85 88 ) 86 89 } 87 90 AuthResult::ProxyThrough => { 91 + log::info!("Proxying through"); 88 92 //No 2FA or already passed 89 93 let uri = format!( 90 94 "{}{}",