Microservice to bring 2FA to self hosted PDSes

2FA gatekeeping

authored by baileytownsend.dev and committed by Tangled 0ca2fe62 f2ff2884

+10
Cargo.lock
··· 287 287 dependencies = [ 288 288 "android-tzdata", 289 289 "iana-time-zone", 290 + "js-sys", 290 291 "num-traits", 292 + "wasm-bindgen", 291 293 "windows-link", 292 294 ] 293 295 ··· 1652 1654 name = "pds_gatekeeper" 1653 1655 version = "0.1.0" 1654 1656 dependencies = [ 1657 + "anyhow", 1655 1658 "axum", 1656 1659 "axum-template", 1660 + "chrono", 1657 1661 "dotenvy", 1658 1662 "handlebars", 1659 1663 "hex", 1660 1664 "hyper-util", 1661 1665 "jwt-compact", 1662 1666 "lettre", 1667 + "rand 0.9.2", 1663 1668 "rust-embed", 1664 1669 "scrypt", 1665 1670 "serde", 1666 1671 "serde_json", 1672 + "sha2", 1667 1673 "sqlx", 1668 1674 "tokio", 1669 1675 "tower-http", ··· 2393 2399 dependencies = [ 2394 2400 "base64", 2395 2401 "bytes", 2402 + "chrono", 2396 2403 "crc", 2397 2404 "crossbeam-queue", 2398 2405 "either", ··· 2470 2477 "bitflags", 2471 2478 "byteorder", 2472 2479 "bytes", 2480 + "chrono", 2473 2481 "crc", 2474 2482 "digest", 2475 2483 "dotenvy", ··· 2511 2519 "base64", 2512 2520 "bitflags", 2513 2521 "byteorder", 2522 + "chrono", 2514 2523 "crc", 2515 2524 "dotenvy", 2516 2525 "etcetera", ··· 2545 2554 checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea" 2546 2555 dependencies = [ 2547 2556 "atoi", 2557 + "chrono", 2548 2558 "flume", 2549 2559 "futures-channel", 2550 2560 "futures-core",
+5 -1
Cargo.toml
··· 6 6 [dependencies] 7 7 axum = { version = "0.8.4", features = ["macros", "json"] } 8 8 tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros", "signal"] } 9 - sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "migrate"] } 9 + sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "migrate", "chrono"] } 10 10 dotenvy = "0.15.7" 11 11 serde = { version = "1.0", features = ["derive"] } 12 12 serde_json = "1.0" ··· 22 22 handlebars = { version = "6.3.2", features = ["rust-embed"] } 23 23 rust-embed = "8.7.2" 24 24 axum-template = { version = "3.0.0", features = ["handlebars"] } 25 + rand = "0.9.2" 26 + anyhow = "1.0.99" 27 + chrono = "0.4.41" 28 + sha2 = "0.10"
+5 -6
README.md
··· 12 12 13 13 ## 2FA 14 14 15 - - [x] Ability to turn on/off 2FA 16 - - [x] getSession overwrite to set the `emailAuthFactor` flag if the user has 2FA turned on 17 - - [x] send an email using the `PDS_EMAIL_SMTP_URL` with a handlebar email template like Bluesky's 2FA sign in email. 18 - - [ ] generate a 2FA code 19 - - [ ] createSession gatekeeping (It does stop logins, just eh, doesn't actually send a real code or check it yet) 20 - - [ ] oauth endpoint gatekeeping 15 + - Overrides The login endpoint to add 2FA for both Bluesky client logged in and OAuth logins 16 + - Overrides the settings endpoints as well. As long as you have a confirmed email you can turn on 2FA 21 17 22 18 ## Captcha on Create Account 23 19 24 20 Future feature? 25 21 26 22 # Setup 23 + 24 + We are getting close! Testing now 27 25 28 26 Nothing here yet! If you are brave enough to try before full release, let me know and I'll help you set it up. 29 27 But I want to run it locally on my own PDS first to test run it a bit. ··· 37 35 path /xrpc/com.atproto.server.getSession 38 36 path /xrpc/com.atproto.server.updateEmail 39 37 path /xrpc/com.atproto.server.createSession 38 + path /@atproto/oauth-provider/~api/sign-in 40 39 } 41 40 42 41 handle @gatekeeper {
-3
migrations_bells_and_whistles/.keep
··· 1 - # This directory holds SQLx migrations for the bells_and_whistles.sqlite database. 2 - # It is intentionally empty for now; running `sqlx::migrate!` will still ensure the 3 - # migrations table exists and succeed with zero migrations.
+524
src/helpers.rs
··· 1 + use crate::AppState; 2 + use crate::helpers::TokenCheckError::InvalidToken; 3 + use anyhow::anyhow; 4 + use axum::body::{Body, to_bytes}; 5 + use axum::extract::Request; 6 + use axum::http::header::CONTENT_TYPE; 7 + use axum::http::{HeaderMap, StatusCode, Uri}; 8 + use axum::response::{IntoResponse, Response}; 9 + use axum_template::TemplateEngine; 10 + use chrono::Utc; 11 + use lettre::message::{MultiPart, SinglePart, header}; 12 + use lettre::{AsyncTransport, Message}; 13 + use rand::Rng; 14 + use serde::de::DeserializeOwned; 15 + use serde_json::{Map, Value}; 16 + use sha2::{Digest, Sha256}; 17 + use sqlx::SqlitePool; 18 + use tracing::{error, log}; 19 + 20 + ///Used to generate the email 2fa code 21 + const UPPERCASE_BASE32_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; 22 + 23 + /// The result of a proxied call that attempts to parse JSON. 24 + pub enum ProxiedResult<T> { 25 + /// Successfully parsed JSON body along with original response headers. 26 + Parsed { value: T, _headers: HeaderMap }, 27 + /// Could not or should not parse: return the original (or rebuilt) response as-is. 28 + Passthrough(Response<Body>), 29 + } 30 + 31 + /// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse 32 + /// the successful response body as JSON into `T`. 33 + /// 34 + pub async fn proxy_get_json<T>( 35 + state: &AppState, 36 + mut req: Request, 37 + path: &str, 38 + ) -> Result<ProxiedResult<T>, StatusCode> 39 + where 40 + T: DeserializeOwned, 41 + { 42 + let uri = format!("{}{}", state.pds_base_url, path); 43 + *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?; 44 + 45 + let result = state 46 + .reverse_proxy_client 47 + .request(req) 48 + .await 49 + .map_err(|_| StatusCode::BAD_REQUEST)? 50 + .into_response(); 51 + 52 + if result.status() != StatusCode::OK { 53 + return Ok(ProxiedResult::Passthrough(result)); 54 + } 55 + 56 + let response_headers = result.headers().clone(); 57 + let body = result.into_body(); 58 + let body_bytes = to_bytes(body, usize::MAX) 59 + .await 60 + .map_err(|_| StatusCode::BAD_REQUEST)?; 61 + 62 + match serde_json::from_slice::<T>(&body_bytes) { 63 + Ok(value) => Ok(ProxiedResult::Parsed { 64 + value, 65 + _headers: response_headers, 66 + }), 67 + Err(err) => { 68 + error!(%err, "failed to parse proxied JSON response; returning original body"); 69 + let mut builder = Response::builder().status(StatusCode::OK); 70 + if let Some(headers) = builder.headers_mut() { 71 + *headers = response_headers; 72 + } 73 + let resp = builder 74 + .body(Body::from(body_bytes)) 75 + .map_err(|_| StatusCode::BAD_REQUEST)?; 76 + Ok(ProxiedResult::Passthrough(resp)) 77 + } 78 + } 79 + } 80 + 81 + /// Build a JSON error response with the required Content-Type header 82 + /// Content-Type: application/json;charset=utf-8 83 + /// Body shape: { "error": string, "message": string } 84 + pub fn json_error_response( 85 + status: StatusCode, 86 + error: impl Into<String>, 87 + message: impl Into<String>, 88 + ) -> Result<Response<Body>, StatusCode> { 89 + let body_str = match serde_json::to_string(&serde_json::json!({ 90 + "error": error.into(), 91 + "message": message.into(), 92 + })) { 93 + Ok(s) => s, 94 + Err(_) => return Err(StatusCode::BAD_REQUEST), 95 + }; 96 + 97 + Response::builder() 98 + .status(status) 99 + .header(CONTENT_TYPE, "application/json;charset=utf-8") 100 + .body(Body::from(body_str)) 101 + .map_err(|_| StatusCode::BAD_REQUEST) 102 + } 103 + 104 + /// Build a JSON error response with the required Content-Type header 105 + /// Content-Type: application/json (oauth endpoint does not like utf ending) 106 + /// Body shape: { "error": string, "error_description": string } 107 + pub fn oauth_json_error_response( 108 + status: StatusCode, 109 + error: impl Into<String>, 110 + message: impl Into<String>, 111 + ) -> Result<Response<Body>, StatusCode> { 112 + let body_str = match serde_json::to_string(&serde_json::json!({ 113 + "error": error.into(), 114 + "error_description": message.into(), 115 + })) { 116 + Ok(s) => s, 117 + Err(_) => return Err(StatusCode::BAD_REQUEST), 118 + }; 119 + 120 + Response::builder() 121 + .status(status) 122 + .header(CONTENT_TYPE, "application/json") 123 + .body(Body::from(body_str)) 124 + .map_err(|_| StatusCode::BAD_REQUEST) 125 + } 126 + 127 + /// Creates a random token of 10 characters for email 2FA 128 + pub fn get_random_token() -> String { 129 + let mut rng = rand::rng(); 130 + 131 + let mut full_code = String::with_capacity(10); 132 + for _ in 0..10 { 133 + let idx = rng.random_range(0..UPPERCASE_BASE32_CHARS.len()); 134 + full_code.push(UPPERCASE_BASE32_CHARS[idx] as char); 135 + } 136 + 137 + //The PDS implementation creates in lowercase, then converts to uppercase. 138 + //Just going a head and doing uppercase here. 139 + let slice_one = &full_code[0..5].to_ascii_uppercase(); 140 + let slice_two = &full_code[5..10].to_ascii_uppercase(); 141 + format!("{slice_one}-{slice_two}") 142 + } 143 + 144 + pub enum TokenCheckError { 145 + InvalidToken, 146 + ExpiredToken, 147 + } 148 + 149 + pub enum AuthResult { 150 + WrongIdentityOrPassword, 151 + /// The string here is the email address to create a hint for oauth 152 + TwoFactorRequired(String), 153 + /// User does not have 2FA enabled, or using an app password, or passes it 154 + ProxyThrough, 155 + TokenCheckFailed(TokenCheckError), 156 + } 157 + 158 + pub enum IdentifierType { 159 + Email, 160 + Did, 161 + Handle, 162 + } 163 + 164 + impl IdentifierType { 165 + fn what_is_it(identifier: String) -> Self { 166 + if identifier.contains("@") { 167 + IdentifierType::Email 168 + } else if identifier.contains("did:") { 169 + IdentifierType::Did 170 + } else { 171 + IdentifierType::Handle 172 + } 173 + } 174 + } 175 + 176 + /// Creates a hex string from the password and salt to find app passwords 177 + fn scrypt_hex(password: &str, salt: &str) -> anyhow::Result<String> { 178 + let params = scrypt::Params::new(14, 8, 1, 64)?; 179 + let mut derived = [0u8; 64]; 180 + scrypt::scrypt(password.as_bytes(), salt.as_bytes(), &params, &mut derived)?; 181 + Ok(hex::encode(derived)) 182 + } 183 + 184 + /// Hashes the app password. did is used as the salt. 185 + pub fn hash_app_password(did: &str, password: &str) -> anyhow::Result<String> { 186 + let mut hasher = Sha256::new(); 187 + hasher.update(did.as_bytes()); 188 + let sha = hasher.finalize(); 189 + let salt = hex::encode(&sha[..16]); 190 + let hash_hex = scrypt_hex(password, &salt)?; 191 + Ok(format!("{salt}:{hash_hex}")) 192 + } 193 + 194 + async fn verify_password(password: &str, password_scrypt: &str) -> anyhow::Result<bool> { 195 + // Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes) 196 + let mut parts = password_scrypt.splitn(2, ':'); 197 + let salt = match parts.next() { 198 + Some(s) if !s.is_empty() => s, 199 + _ => return Ok(false), 200 + }; 201 + let stored_hash_hex = match parts.next() { 202 + Some(h) if !h.is_empty() => h, 203 + _ => return Ok(false), 204 + }; 205 + 206 + // Derive using the shared helper and compare 207 + let derived_hex = match scrypt_hex(password, salt) { 208 + Ok(h) => h, 209 + Err(_) => return Ok(false), 210 + }; 211 + 212 + Ok(derived_hex.as_str() == stored_hash_hex) 213 + } 214 + 215 + /// Handles the auth checks along with sending a 2fa email 216 + pub async fn preauth_check( 217 + state: &AppState, 218 + identifier: &str, 219 + password: &str, 220 + two_factor_code: Option<String>, 221 + oauth: bool, 222 + ) -> anyhow::Result<AuthResult> { 223 + // Determine identifier type 224 + let id_type = IdentifierType::what_is_it(identifier.to_string()); 225 + 226 + // Query account DB for did and passwordScrypt based on identifier type 227 + let account_row: Option<(String, String, String, String)> = match id_type { 228 + IdentifierType::Email => { 229 + sqlx::query_as::<_, (String, String, String, String)>( 230 + "SELECT account.did, account.passwordScrypt, account.email, actor.handle 231 + FROM actor 232 + LEFT JOIN account ON actor.did = account.did 233 + where account.email = ? LIMIT 1", 234 + ) 235 + .bind(identifier) 236 + .fetch_optional(&state.account_pool) 237 + .await? 238 + } 239 + IdentifierType::Handle => { 240 + sqlx::query_as::<_, (String, String, String, String)>( 241 + "SELECT account.did, account.passwordScrypt, account.email, actor.handle 242 + FROM actor 243 + LEFT JOIN account ON actor.did = account.did 244 + where actor.handle = ? LIMIT 1", 245 + ) 246 + .bind(identifier) 247 + .fetch_optional(&state.account_pool) 248 + .await? 249 + } 250 + IdentifierType::Did => { 251 + sqlx::query_as::<_, (String, String, String, String)>( 252 + "SELECT account.did, account.passwordScrypt, account.email, actor.handle 253 + FROM actor 254 + LEFT JOIN account ON actor.did = account.did 255 + where account.did = ? LIMIT 1", 256 + ) 257 + .bind(identifier) 258 + .fetch_optional(&state.account_pool) 259 + .await? 260 + } 261 + }; 262 + 263 + if let Some((did, password_scrypt, email, handle)) = account_row { 264 + // Verify password before proceeding to 2FA email step 265 + let verified = verify_password(password, &password_scrypt).await?; 266 + if !verified { 267 + if oauth { 268 + //OAuth does not allow app password logins so just go ahead and send it along it's way 269 + return Ok(AuthResult::WrongIdentityOrPassword); 270 + } 271 + //Theres a chance it could be an app password so check that as well 272 + return match verify_app_password(&state.account_pool, &did, password).await { 273 + Ok(valid) => { 274 + if valid { 275 + //Was a valid app password up to the PDS now 276 + Ok(AuthResult::ProxyThrough) 277 + } else { 278 + Ok(AuthResult::WrongIdentityOrPassword) 279 + } 280 + } 281 + Err(err) => { 282 + log::error!("Error checking the app password: {err}"); 283 + Err(err) 284 + } 285 + }; 286 + } 287 + 288 + // Check two-factor requirement for this DID in the gatekeeper DB 289 + let required_opt = sqlx::query_as::<_, (u8,)>( 290 + "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1", 291 + ) 292 + .bind(did.clone()) 293 + .fetch_optional(&state.pds_gatekeeper_pool) 294 + .await?; 295 + 296 + let two_factor_required = match required_opt { 297 + Some(row) => row.0 != 0, 298 + None => false, 299 + }; 300 + 301 + if two_factor_required { 302 + //Two factor is required and a taken was provided 303 + if let Some(two_factor_code) = two_factor_code { 304 + //if the two_factor_code is set need to see if we have a valid token 305 + if !two_factor_code.is_empty() { 306 + return match assert_valid_token( 307 + &state.account_pool, 308 + did.clone(), 309 + two_factor_code, 310 + ) 311 + .await 312 + { 313 + Ok(_) => { 314 + let result_of_cleanup = 315 + delete_all_email_tokens(&state.account_pool, did.clone()).await; 316 + if result_of_cleanup.is_err() { 317 + log::error!( 318 + "There was an error deleting the email tokens after login: {:?}", 319 + result_of_cleanup.err() 320 + ) 321 + } 322 + Ok(AuthResult::ProxyThrough) 323 + } 324 + Err(err) => Ok(AuthResult::TokenCheckFailed(err)), 325 + }; 326 + } 327 + } 328 + 329 + return match create_two_factor_token(&state.account_pool, did).await { 330 + Ok(code) => { 331 + let mut email_data = Map::new(); 332 + email_data.insert("token".to_string(), Value::from(code.clone())); 333 + email_data.insert("handle".to_string(), Value::from(handle.clone())); 334 + let email_body = state 335 + .template_engine 336 + .render("two_factor_code.hbs", email_data)?; 337 + 338 + let email_message = Message::builder() 339 + //TODO prob get the proper type in the state 340 + .from(state.mailer_from.parse()?) 341 + .to(email.parse()?) 342 + .subject("Sign in to Bluesky") 343 + .multipart( 344 + MultiPart::alternative() // This is composed of two parts. 345 + .singlepart( 346 + SinglePart::builder() 347 + .header(header::ContentType::TEXT_PLAIN) 348 + .body(format!("We received a sign-in request for the account @{handle}. Use the code: {code} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.")), // Every message should have a plain text fallback. 349 + ) 350 + .singlepart( 351 + SinglePart::builder() 352 + .header(header::ContentType::TEXT_HTML) 353 + .body(email_body), 354 + ), 355 + )?; 356 + match state.mailer.send(email_message).await { 357 + Ok(_) => Ok(AuthResult::TwoFactorRequired(mask_email(email))), 358 + Err(err) => { 359 + log::error!("Error sending the 2FA email: {err}"); 360 + Err(anyhow!(err)) 361 + } 362 + } 363 + } 364 + Err(err) => { 365 + log::error!("error on creating a 2fa token: {err}"); 366 + Err(anyhow!(err)) 367 + } 368 + }; 369 + } 370 + } 371 + 372 + // No local 2FA requirement (or account not found) 373 + Ok(AuthResult::ProxyThrough) 374 + } 375 + 376 + pub async fn create_two_factor_token( 377 + account_db: &SqlitePool, 378 + did: String, 379 + ) -> anyhow::Result<String> { 380 + let purpose = "2fa_code"; 381 + 382 + let token = get_random_token(); 383 + let right_now = Utc::now(); 384 + 385 + let res = sqlx::query( 386 + "INSERT INTO email_token (purpose, did, token, requestedAt) 387 + VALUES (?, ?, ?, ?) 388 + ON CONFLICT(purpose, did) DO UPDATE SET 389 + token=excluded.token, 390 + requestedAt=excluded.requestedAt 391 + WHERE did=excluded.did", 392 + ) 393 + .bind(purpose) 394 + .bind(&did) 395 + .bind(&token) 396 + .bind(right_now) 397 + .execute(account_db) 398 + .await; 399 + 400 + match res { 401 + Ok(_) => Ok(token), 402 + Err(err) => { 403 + log::error!("Error creating a two factor token: {err}"); 404 + Err(anyhow::anyhow!(err)) 405 + } 406 + } 407 + } 408 + 409 + pub async fn delete_all_email_tokens(account_db: &SqlitePool, did: String) -> anyhow::Result<()> { 410 + sqlx::query("DELETE FROM email_token WHERE did = ?") 411 + .bind(did) 412 + .execute(account_db) 413 + .await?; 414 + Ok(()) 415 + } 416 + 417 + pub async fn assert_valid_token( 418 + account_db: &SqlitePool, 419 + did: String, 420 + token: String, 421 + ) -> Result<(), TokenCheckError> { 422 + let token_upper = token.to_ascii_uppercase(); 423 + let purpose = "2fa_code"; 424 + 425 + let row: Option<(String,)> = sqlx::query_as( 426 + "SELECT requestedAt FROM email_token WHERE purpose = ? AND did = ? AND token = ? LIMIT 1", 427 + ) 428 + .bind(purpose) 429 + .bind(did) 430 + .bind(token_upper) 431 + .fetch_optional(account_db) 432 + .await 433 + .map_err(|err| { 434 + log::error!("Error getting the 2fa token: {err}"); 435 + InvalidToken 436 + })?; 437 + 438 + match row { 439 + None => Err(InvalidToken), 440 + Some(row) => { 441 + // Token lives for 15 minutes 442 + let expiration_ms = 15 * 60_000; 443 + 444 + let requested_at_utc = match chrono::DateTime::parse_from_rfc3339(&row.0) { 445 + Ok(dt) => dt.with_timezone(&Utc), 446 + Err(_) => { 447 + return Err(TokenCheckError::InvalidToken); 448 + } 449 + }; 450 + 451 + let now = Utc::now(); 452 + let age_ms = (now - requested_at_utc).num_milliseconds(); 453 + let expired = age_ms > expiration_ms; 454 + if expired { 455 + return Err(TokenCheckError::ExpiredToken); 456 + } 457 + 458 + Ok(()) 459 + } 460 + } 461 + } 462 + 463 + /// We just need to confirm if it's there or not. Will let the PDS do the actual figuring of permissions 464 + pub async fn verify_app_password( 465 + account_db: &SqlitePool, 466 + did: &str, 467 + password: &str, 468 + ) -> anyhow::Result<bool> { 469 + let password_scrypt = hash_app_password(did, password)?; 470 + 471 + let row: Option<(i64,)> = sqlx::query_as( 472 + "SELECT Count(*) FROM app_password WHERE did = ? AND passwordScrypt = ? LIMIT 1", 473 + ) 474 + .bind(did) 475 + .bind(password_scrypt) 476 + .fetch_optional(account_db) 477 + .await?; 478 + 479 + Ok(match row { 480 + None => false, 481 + Some((count,)) => count > 0, 482 + }) 483 + } 484 + 485 + /// Mask an email address into a hint like "2***0@p***m". 486 + pub fn mask_email(email: String) -> String { 487 + // Basic split on first '@' 488 + let mut parts = email.splitn(2, '@'); 489 + let local = match parts.next() { 490 + Some(l) => l, 491 + None => return email.to_string(), 492 + }; 493 + let domain_rest = match parts.next() { 494 + Some(d) if !d.is_empty() => d, 495 + _ => return email.to_string(), 496 + }; 497 + 498 + // Helper to mask a single label (keep first and last, middle becomes ***). 499 + fn mask_label(s: &str) -> String { 500 + let chars: Vec<char> = s.chars().collect(); 501 + match chars.len() { 502 + 0 => String::new(), 503 + 1 => format!("{}***", chars[0]), 504 + 2 => format!("{}***{}", chars[0], chars[1]), 505 + _ => format!("{}***{}", chars[0], chars[chars.len() - 1]), 506 + } 507 + } 508 + 509 + // Mask local 510 + let masked_local = mask_label(local); 511 + 512 + // Mask first domain label only, keep the rest of the domain intact 513 + let mut dom_parts = domain_rest.splitn(2, '.'); 514 + let first_label = dom_parts.next().unwrap_or(""); 515 + let rest = dom_parts.next(); 516 + let masked_first = mask_label(first_label); 517 + let masked_domain = if let Some(rest) = rest { 518 + format!("{}.{rest}", masked_first) 519 + } else { 520 + masked_first 521 + }; 522 + 523 + format!("{masked_local}@{masked_domain}") 524 + }
+53 -26
src/main.rs
··· 1 + #![warn(clippy::unwrap_used)] 2 + use crate::oauth_provider::sign_in; 1 3 use crate::xrpc::com_atproto_server::{create_session, get_session, update_email}; 2 - use axum::middleware as ax_middleware; 3 - mod middleware; 4 4 use axum::body::Body; 5 5 use axum::handler::Handler; 6 6 use axum::http::{Method, header}; 7 + use axum::middleware as ax_middleware; 7 8 use axum::routing::post; 8 9 use axum::{Router, routing::get}; 9 10 use axum_template::engine::Engine; ··· 21 22 use tower_governor::governor::GovernorConfigBuilder; 22 23 use tower_http::compression::CompressionLayer; 23 24 use tower_http::cors::{Any, CorsLayer}; 24 - use tracing::{error, log}; 25 + use tracing::log; 25 26 use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 26 27 28 + pub mod helpers; 29 + mod middleware; 30 + mod oauth_provider; 27 31 mod xrpc; 28 32 29 33 type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>; ··· 34 38 struct EmailTemplates; 35 39 36 40 #[derive(Clone)] 37 - struct AppState { 41 + pub struct AppState { 38 42 account_pool: SqlitePool, 39 43 pds_gatekeeper_pool: SqlitePool, 40 44 reverse_proxy_client: HyperUtilClient, ··· 73 77 74 78 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; 75 79 76 - let banner = format!(" {}\n{}", body, intro); 80 + let banner = format!(" {body}\n{intro}"); 77 81 78 82 ( 79 83 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], ··· 84 88 #[tokio::main] 85 89 async fn main() -> Result<(), Box<dyn std::error::Error>> { 86 90 setup_tracing(); 87 - //TODO prod 91 + //TODO may need to change where this reads from? Like an env variable for it's location? Or arg? 88 92 dotenvy::from_path(Path::new("./pds.env"))?; 89 93 let pds_root = env::var("PDS_DATA_DIRECTORY")?; 90 - // let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data"; 91 - let account_db_url = format!("{}/account.sqlite", pds_root); 92 - log::info!("accounts_db_url: {}", account_db_url); 94 + let account_db_url = format!("{pds_root}/account.sqlite"); 93 95 94 96 let account_options = SqliteConnectOptions::new() 95 - .journal_mode(SqliteJournalMode::Wal) 96 - .filename(account_db_url); 97 + .filename(account_db_url) 98 + .busy_timeout(Duration::from_secs(5)); 97 99 98 100 let account_pool = SqlitePoolOptions::new() 99 101 .max_connections(5) 100 102 .connect_with(account_options) 101 103 .await?; 102 104 103 - let bells_db_url = format!("{}/pds_gatekeeper.sqlite", pds_root); 105 + let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); 104 106 let options = SqliteConnectOptions::new() 105 107 .journal_mode(SqliteJournalMode::Wal) 106 108 .filename(bells_db_url) 107 - .create_if_missing(true); 109 + .create_if_missing(true) 110 + .busy_timeout(Duration::from_secs(5)); 108 111 let pds_gatekeeper_pool = SqlitePoolOptions::new() 109 112 .max_connections(5) 110 113 .connect_with(options) 111 114 .await?; 112 115 113 - // Run migrations for the bells_and_whistles database 116 + // Run migrations for the extra database 114 117 // Note: the migrations are embedded at compile time from the given directory 115 118 // sqlx 116 119 sqlx::migrate!("./migrations") ··· 130 133 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build(); 131 134 //Email templates setup 132 135 let mut hbs = Handlebars::new(); 133 - let _ = hbs.register_embed_templates::<EmailTemplates>(); 136 + 137 + let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); 138 + if let Ok(users_email_directory) = users_email_directory { 139 + hbs.register_template_file( 140 + "two_factor_code.hbs", 141 + format!("{users_email_directory}/two_factor_code.hbs"), 142 + )?; 143 + } else { 144 + let _ = hbs.register_embed_templates::<EmailTemplates>(); 145 + } 146 + 147 + let pds_base_url = 148 + env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); 134 149 135 150 let state = AppState { 136 151 account_pool, 137 152 pds_gatekeeper_pool, 138 153 reverse_proxy_client: client, 139 - //TODO should be env prob 140 - pds_base_url: "http://localhost:3000".to_string(), 154 + pds_base_url, 141 155 mailer, 142 156 mailer_from: sent_from, 143 157 template_engine: Engine::from(hbs), ··· 145 159 146 160 // Rate limiting 147 161 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 148 - let governor_conf = GovernorConfigBuilder::default() 162 + let create_session_governor_conf = GovernorConfigBuilder::default() 149 163 .per_second(60) 150 164 .burst_size(5) 151 165 .finish() 152 - .unwrap(); 153 - let governor_limiter = governor_conf.limiter().clone(); 166 + .expect("failed to create governor config. this should not happen and is a bug"); 167 + 168 + // Create a second config with the same settings for the other endpoint 169 + let sign_in_governor_conf = GovernorConfigBuilder::default() 170 + .per_second(60) 171 + .burst_size(5) 172 + .finish() 173 + .expect("failed to create governor config. this should not happen and is a bug"); 174 + 175 + let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); 176 + let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 154 177 let interval = Duration::from_secs(60); 155 178 // a separate background task to clean up 156 179 std::thread::spawn(move || { 157 180 loop { 158 181 std::thread::sleep(interval); 159 - tracing::info!("rate limiting storage size: {}", governor_limiter.len()); 160 - governor_limiter.retain_recent(); 182 + create_session_governor_limiter.retain_recent(); 183 + sign_in_governor_limiter.retain_recent(); 161 184 } 162 185 }); 163 186 ··· 177 200 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), 178 201 ) 179 202 .route( 203 + "/@atproto/oauth-provider/~api/sign-in", 204 + post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)), 205 + ) 206 + .route( 180 207 "/xrpc/com.atproto.server.createSession", 181 - post(create_session.layer(GovernorLayer::new(governor_conf))), 208 + post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), 182 209 ) 183 210 .layer(CompressionLayer::new()) 184 211 .layer(cors) 185 212 .with_state(state); 186 213 187 - let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 188 - let port: u16 = env::var("PORT") 214 + let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 215 + let port: u16 = env::var("GATEKEEPER_PORT") 189 216 .ok() 190 217 .and_then(|s| s.parse().ok()) 191 218 .unwrap_or(8080); ··· 202 229 .with_graceful_shutdown(shutdown_signal()); 203 230 204 231 if let Err(err) = server.await { 205 - error!(error = %err, "server error"); 232 + log::error!("server error:{err}"); 206 233 } 207 234 208 235 Ok(())
+19 -34
src/middleware.rs
··· 1 - use crate::xrpc::helpers::json_error_response; 1 + use crate::helpers::json_error_response; 2 2 use axum::extract::Request; 3 3 use axum::http::{HeaderMap, StatusCode}; 4 4 use axum::middleware::Next; ··· 7 7 use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError}; 8 8 use serde::{Deserialize, Serialize}; 9 9 use std::env; 10 + use tracing::log; 10 11 11 12 #[derive(Clone, Debug)] 12 13 pub struct Did(pub Option<String>); ··· 22 23 match token { 23 24 Ok(token) => { 24 25 match token { 25 - None => { 26 - return json_error_response( 27 - StatusCode::BAD_REQUEST, 28 - "TokenRequired", 29 - "", 30 - ).unwrap(); 31 - } 26 + None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 27 + .expect("Error creating an error response"), 32 28 Some(token) => { 33 29 let token = UntrustedToken::new(&token); 34 - //Doing weird unwraps cause I can't do Result for middleware? 35 30 if token.is_err() { 36 - return json_error_response( 37 - StatusCode::BAD_REQUEST, 38 - "TokenRequired", 39 - "", 40 - ).unwrap(); 31 + return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 32 + .expect("Error creating an error response"); 41 33 } 42 - let parsed_token = token.unwrap(); 34 + let parsed_token = token.expect("Already checked for error"); 43 35 let claims: Result<Claims<TokenClaims>, ValidationError> = 44 36 parsed_token.deserialize_claims_unchecked(); 45 37 if claims.is_err() { 46 - return json_error_response( 47 - StatusCode::BAD_REQUEST, 48 - "TokenRequired", 49 - "", 50 - ).unwrap(); 38 + return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 39 + .expect("Error creating an error response"); 51 40 } 52 41 53 - let key = Hs256Key::new(env::var("PDS_JWT_SECRET").unwrap()); 42 + let key = Hs256Key::new( 43 + env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"), 44 + ); 54 45 let token: Result<Token<TokenClaims>, ValidationError> = 55 46 Hs256.validator(&key).validate(&parsed_token); 56 47 if token.is_err() { 57 - return json_error_response( 58 - StatusCode::BAD_REQUEST, 59 - "InvalidToken", 60 - "", 61 - ).unwrap(); 48 + return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "") 49 + .expect("Error creating an error response"); 62 50 } 63 - let token = token.unwrap(); 51 + let token = token.expect("Already checked for error,"); 64 52 //Not going to worry about expiration since it still goes to the PDS 65 - 66 53 req.extensions_mut() 67 54 .insert(Did(Some(token.claims().custom.sub.clone()))); 68 55 next.run(req).await 69 56 } 70 57 } 71 58 } 72 - Err(_) => { 73 - return json_error_response( 74 - StatusCode::BAD_REQUEST, 75 - "InvalidToken", 76 - "", 77 - ).unwrap(); 59 + Err(err) => { 60 + log::error!("Error extracting token: {err}"); 61 + json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "") 62 + .expect("Error creating an error response") 78 63 } 79 64 } 80 65 }
+141
src/oauth_provider.rs
··· 1 + use crate::AppState; 2 + use crate::helpers::{AuthResult, oauth_json_error_response, preauth_check}; 3 + use axum::body::Body; 4 + use axum::extract::State; 5 + use axum::http::header::CONTENT_TYPE; 6 + use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode}; 7 + use axum::response::{IntoResponse, Response}; 8 + use axum::{Json, extract}; 9 + use serde::{Deserialize, Serialize}; 10 + use tracing::log; 11 + 12 + #[derive(Serialize, Deserialize, Clone)] 13 + pub struct SignInRequest { 14 + pub username: String, 15 + pub password: String, 16 + pub remember: bool, 17 + pub locale: String, 18 + #[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")] 19 + pub email_otp: Option<String>, 20 + } 21 + 22 + pub async fn sign_in( 23 + State(state): State<AppState>, 24 + headers: HeaderMap, 25 + Json(mut payload): extract::Json<SignInRequest>, 26 + ) -> Result<Response<Body>, StatusCode> { 27 + let identifier = payload.username.clone(); 28 + let password = payload.password.clone(); 29 + let auth_factor_token = payload.email_otp.clone(); 30 + 31 + match preauth_check(&state, &identifier, &password, auth_factor_token, true).await { 32 + Ok(result) => match result { 33 + AuthResult::WrongIdentityOrPassword => oauth_json_error_response( 34 + StatusCode::BAD_REQUEST, 35 + "invalid_request", 36 + "Invalid identifier or password", 37 + ), 38 + AuthResult::TwoFactorRequired(masked_email) => { 39 + // Email sending step can be handled here if needed in the future. 40 + 41 + // {"error":"second_authentication_factor_required","error_description":"emailOtp authentication factor required (hint: 2***0@p***m)","type":"emailOtp","hint":"2***0@p***m"} 42 + let body_str = match serde_json::to_string(&serde_json::json!({ 43 + "error": "second_authentication_factor_required", 44 + "error_description": format!("emailOtp authentication factor required (hint: {})", masked_email), 45 + "type": "emailOtp", 46 + "hint": masked_email, 47 + })) { 48 + Ok(s) => s, 49 + Err(_) => return Err(StatusCode::BAD_REQUEST), 50 + }; 51 + 52 + Response::builder() 53 + .status(StatusCode::BAD_REQUEST) 54 + .header(CONTENT_TYPE, "application/json") 55 + .body(Body::from(body_str)) 56 + .map_err(|_| StatusCode::BAD_REQUEST) 57 + } 58 + AuthResult::ProxyThrough => { 59 + //No 2FA or already passed 60 + let uri = format!( 61 + "{}{}", 62 + state.pds_base_url, "/@atproto/oauth-provider/~api/sign-in" 63 + ); 64 + 65 + let mut req = axum::http::Request::post(uri); 66 + if let Some(req_headers) = req.headers_mut() { 67 + // Copy headers but remove problematic ones. There was an issue with the PDS not parsing the body fully if i forwarded all headers 68 + copy_filtered_headers(&headers, req_headers); 69 + //Setting the content type to application/json manually 70 + req_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); 71 + } 72 + 73 + //Clears the email_otp because the pds will reject a request with it. 74 + payload.email_otp = None; 75 + let payload_bytes = 76 + serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 77 + 78 + let req = req 79 + .body(Body::from(payload_bytes)) 80 + .map_err(|_| StatusCode::BAD_REQUEST)?; 81 + 82 + let proxied = state 83 + .reverse_proxy_client 84 + .request(req) 85 + .await 86 + .map_err(|_| StatusCode::BAD_REQUEST)? 87 + .into_response(); 88 + 89 + Ok(proxied) 90 + } 91 + //Ignoring the type of token check failure. Looks like oauth on the entry treads them the same. 92 + AuthResult::TokenCheckFailed(_) => oauth_json_error_response( 93 + StatusCode::BAD_REQUEST, 94 + "invalid_request", 95 + "Unable to sign-in due to an unexpected server error", 96 + ), 97 + }, 98 + Err(err) => { 99 + log::error!( 100 + "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}" 101 + ); 102 + oauth_json_error_response( 103 + StatusCode::BAD_REQUEST, 104 + "pds_gatekeeper_error", 105 + "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.", 106 + ) 107 + } 108 + } 109 + } 110 + 111 + fn is_disallowed_header(name: &HeaderName) -> bool { 112 + // possible problematic headers with proxying 113 + matches!( 114 + name.as_str(), 115 + "connection" 116 + | "keep-alive" 117 + | "proxy-authenticate" 118 + | "proxy-authorization" 119 + | "te" 120 + | "trailer" 121 + | "transfer-encoding" 122 + | "upgrade" 123 + | "host" 124 + | "content-length" 125 + | "content-encoding" 126 + | "expect" 127 + | "accept-encoding" 128 + ) 129 + } 130 + 131 + fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) { 132 + for (name, value) in src.iter() { 133 + if is_disallowed_header(name) { 134 + continue; 135 + } 136 + // Only copy valid headers 137 + if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) { 138 + dst.insert(name.clone(), hv); 139 + } 140 + } 141 + }
+66 -211
src/xrpc/com_atproto_server.rs
··· 1 1 use crate::AppState; 2 + use crate::helpers::{ 3 + AuthResult, ProxiedResult, TokenCheckError, json_error_response, preauth_check, proxy_get_json, 4 + }; 2 5 use crate::middleware::Did; 3 - use crate::xrpc::helpers::{ProxiedResult, json_error_response, proxy_get_json}; 4 6 use axum::body::Body; 5 7 use axum::extract::State; 6 8 use axum::http::{HeaderMap, StatusCode}; 7 9 use axum::response::{IntoResponse, Response}; 8 10 use axum::{Extension, Json, debug_handler, extract, extract::Request}; 9 - use axum_template::TemplateEngine; 10 - use lettre::message::{MultiPart, SinglePart, header}; 11 - use lettre::{AsyncTransport, Message}; 12 11 use serde::{Deserialize, Serialize}; 13 12 use serde_json; 14 - use serde_json::Value; 15 - use serde_json::value::Map; 16 13 use tracing::log; 17 14 18 15 #[derive(Serialize, Deserialize, Debug, Clone)] ··· 58 55 pub struct CreateSessionRequest { 59 56 identifier: String, 60 57 password: String, 61 - auth_factor_token: String, 62 - allow_takendown: bool, 63 - } 64 - 65 - pub enum AuthResult { 66 - WrongIdentityOrPassword, 67 - TwoFactorRequired, 68 - TwoFactorFailed, 69 - /// User does not have 2FA enabled, or passes it 70 - ProxyThrough, 71 - } 72 - 73 - pub enum IdentifierType { 74 - Email, 75 - DID, 76 - Handle, 77 - } 78 - 79 - impl IdentifierType { 80 - fn what_is_it(identifier: String) -> Self { 81 - if identifier.contains("@") { 82 - IdentifierType::Email 83 - } else if identifier.contains("did:") { 84 - IdentifierType::DID 85 - } else { 86 - IdentifierType::Handle 87 - } 88 - } 89 - } 90 - 91 - async fn verify_password(password: &str, password_scrypt: &str) -> Result<bool, StatusCode> { 92 - // Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes) 93 - let mut parts = password_scrypt.splitn(2, ':'); 94 - let salt = match parts.next() { 95 - Some(s) if !s.is_empty() => s, 96 - _ => return Ok(false), 97 - }; 98 - let stored_hash_hex = match parts.next() { 99 - Some(h) if !h.is_empty() => h, 100 - _ => return Ok(false), 101 - }; 102 - 103 - //Sets up scrypt to mimic node's scrypt 104 - let params = match scrypt::Params::new(14, 8, 1, 64) { 105 - Ok(p) => p, 106 - Err(_) => return Ok(false), 107 - }; 108 - let mut derived = [0u8; 64]; 109 - if scrypt::scrypt(password.as_bytes(), salt.as_bytes(), &params, &mut derived).is_err() { 110 - return Ok(false); 111 - } 112 - 113 - let stored_bytes = match hex::decode(stored_hash_hex) { 114 - Ok(b) => b, 115 - Err(e) => { 116 - log::error!("Error decoding stored hash: {}", e); 117 - return Ok(false); 118 - } 119 - }; 120 - 121 - Ok(derived.as_slice() == stored_bytes.as_slice()) 122 - } 123 - 124 - async fn preauth_check( 125 - state: &AppState, 126 - identifier: &str, 127 - password: &str, 128 - ) -> Result<AuthResult, StatusCode> { 129 - // Determine identifier type 130 - let id_type = IdentifierType::what_is_it(identifier.to_string()); 131 - 132 - // Query account DB for did and passwordScrypt based on identifier type 133 - let account_row: Option<(String, String, String)> = match id_type { 134 - IdentifierType::Email => sqlx::query_as::<_, (String, String, String)>( 135 - "SELECT did, passwordScrypt, account.email FROM account WHERE email = ? LIMIT 1", 136 - ) 137 - .bind(identifier) 138 - .fetch_optional(&state.account_pool) 139 - .await 140 - .map_err(|_| StatusCode::BAD_REQUEST)?, 141 - IdentifierType::Handle => sqlx::query_as::<_, (String, String, String)>( 142 - "SELECT account.did, account.passwordScrypt, account.email 143 - FROM actor 144 - LEFT JOIN account ON actor.did = account.did 145 - where actor.handle =? LIMIT 1", 146 - ) 147 - .bind(identifier) 148 - .fetch_optional(&state.account_pool) 149 - .await 150 - .map_err(|_| StatusCode::BAD_REQUEST)?, 151 - IdentifierType::DID => sqlx::query_as::<_, (String, String, String)>( 152 - "SELECT did, passwordScrypt, account.email FROM account WHERE did = ? LIMIT 1", 153 - ) 154 - .bind(identifier) 155 - .fetch_optional(&state.account_pool) 156 - .await 157 - .map_err(|_| StatusCode::BAD_REQUEST)?, 158 - }; 159 - 160 - if let Some((did, password_scrypt, email)) = account_row { 161 - // Check two-factor requirement for this DID in the gatekeeper DB 162 - let required_opt = sqlx::query_as::<_, (u8,)>( 163 - "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1", 164 - ) 165 - .bind(&did) 166 - .fetch_optional(&state.pds_gatekeeper_pool) 167 - .await 168 - .map_err(|_| StatusCode::BAD_REQUEST)?; 169 - 170 - let two_factor_required = match required_opt { 171 - Some(row) => row.0 != 0, 172 - None => false, 173 - }; 174 - 175 - if two_factor_required { 176 - // Verify password before proceeding to 2FA email step 177 - let verified = verify_password(password, &password_scrypt).await?; 178 - if !verified { 179 - return Ok(AuthResult::WrongIdentityOrPassword); 180 - } 181 - let mut email_data = Map::new(); 182 - //TODO these need real values 183 - let token = "test".to_string(); 184 - let handle = "baileytownsend.dev".to_string(); 185 - email_data.insert("token".to_string(), Value::from(token.clone())); 186 - email_data.insert("handle".to_string(), Value::from(handle.clone())); 187 - //TODO bad unwrap 188 - let email_body = state 189 - .template_engine 190 - .render("two_factor_code.hbs", email_data) 191 - .unwrap(); 192 - 193 - let email = Message::builder() 194 - //TODO prob get the proper type in the state 195 - .from(state.mailer_from.parse().unwrap()) 196 - .to(email.parse().unwrap()) 197 - .subject("Sign in to Bluesky") 198 - .multipart( 199 - MultiPart::alternative() // This is composed of two parts. 200 - .singlepart( 201 - SinglePart::builder() 202 - .header(header::ContentType::TEXT_PLAIN) 203 - .body(format!("We received a sign-in request for the account @{}. Use the code: {} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.", handle, token)), // Every message should have a plain text fallback. 204 - ) 205 - .singlepart( 206 - SinglePart::builder() 207 - .header(header::ContentType::TEXT_HTML) 208 - .body(email_body), 209 - ), 210 - ) 211 - //TODO bad 212 - .unwrap(); 213 - return match state.mailer.send(email).await { 214 - Ok(_) => Ok(AuthResult::TwoFactorRequired), 215 - Err(err) => { 216 - log::error!("Error sending the 2FA email: {}", err); 217 - Err(StatusCode::BAD_REQUEST) 218 - } 219 - }; 220 - } 221 - } 222 - 223 - // No local 2FA requirement (or account not found) 224 - Ok(AuthResult::ProxyThrough) 58 + #[serde(skip_serializing_if = "Option::is_none")] 59 + auth_factor_token: Option<String>, 60 + #[serde(skip_serializing_if = "Option::is_none")] 61 + allow_takendown: Option<bool>, 225 62 } 226 63 227 64 pub async fn create_session( ··· 231 68 ) -> Result<Response<Body>, StatusCode> { 232 69 let identifier = payload.identifier.clone(); 233 70 let password = payload.password.clone(); 71 + let auth_factor_token = payload.auth_factor_token.clone(); 234 72 235 73 // Run the shared pre-auth logic to validate and check 2FA requirement 236 - match preauth_check(&state, &identifier, &password).await? { 237 - AuthResult::WrongIdentityOrPassword => json_error_response( 238 - StatusCode::UNAUTHORIZED, 239 - "AuthenticationRequired", 240 - "Invalid identifier or password", 241 - ), 242 - AuthResult::TwoFactorRequired => { 243 - // Email sending step can be handled here if needed in the future. 244 - json_error_response( 74 + match preauth_check(&state, &identifier, &password, auth_factor_token, false).await { 75 + Ok(result) => match result { 76 + AuthResult::WrongIdentityOrPassword => json_error_response( 245 77 StatusCode::UNAUTHORIZED, 246 - "AuthFactorTokenRequired", 247 - "A sign in code has been sent to your email address", 248 - ) 249 - } 250 - AuthResult::TwoFactorFailed => { 251 - //Not sure what the errors are for this response is yet 252 - json_error_response(StatusCode::UNAUTHORIZED, "PLACEHOLDER", "PLACEHOLDER") 253 - } 254 - AuthResult::ProxyThrough => { 255 - //No 2FA or already passed 256 - let uri = format!( 257 - "{}{}", 258 - state.pds_base_url, "/xrpc/com.atproto.server.createSession" 259 - ); 260 - 261 - let mut req = axum::http::Request::post(uri); 262 - if let Some(req_headers) = req.headers_mut() { 263 - req_headers.extend(headers.clone()); 78 + "AuthenticationRequired", 79 + "Invalid identifier or password", 80 + ), 81 + AuthResult::TwoFactorRequired(_) => { 82 + // Email sending step can be handled here if needed in the future. 83 + json_error_response( 84 + StatusCode::UNAUTHORIZED, 85 + "AuthFactorTokenRequired", 86 + "A sign in code has been sent to your email address", 87 + ) 264 88 } 89 + AuthResult::ProxyThrough => { 90 + log::info!("Proxying through"); 91 + //No 2FA or already passed 92 + let uri = format!( 93 + "{}{}", 94 + state.pds_base_url, "/xrpc/com.atproto.server.createSession" 95 + ); 265 96 266 - let payload_bytes = 267 - serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 268 - let req = req 269 - .body(Body::from(payload_bytes)) 270 - .map_err(|_| StatusCode::BAD_REQUEST)?; 97 + let mut req = axum::http::Request::post(uri); 98 + if let Some(req_headers) = req.headers_mut() { 99 + req_headers.extend(headers.clone()); 100 + } 101 + 102 + let payload_bytes = 103 + serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 104 + let req = req 105 + .body(Body::from(payload_bytes)) 106 + .map_err(|_| StatusCode::BAD_REQUEST)?; 271 107 272 - let proxied = state 273 - .reverse_proxy_client 274 - .request(req) 275 - .await 276 - .map_err(|_| StatusCode::BAD_REQUEST)? 277 - .into_response(); 108 + let proxied = state 109 + .reverse_proxy_client 110 + .request(req) 111 + .await 112 + .map_err(|_| StatusCode::BAD_REQUEST)? 113 + .into_response(); 278 114 279 - Ok(proxied) 115 + Ok(proxied) 116 + } 117 + AuthResult::TokenCheckFailed(err) => match err { 118 + TokenCheckError::InvalidToken => { 119 + json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "Token is invalid") 120 + } 121 + TokenCheckError::ExpiredToken => { 122 + json_error_response(StatusCode::BAD_REQUEST, "ExpiredToken", "Token is expired") 123 + } 124 + }, 125 + }, 126 + Err(err) => { 127 + log::error!( 128 + "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}" 129 + ); 130 + json_error_response( 131 + StatusCode::INTERNAL_SERVER_ERROR, 132 + "InternalServerError", 133 + "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.", 134 + ) 280 135 } 281 136 } 282 137 } ··· 290 145 ) -> Result<Response<Body>, StatusCode> { 291 146 //If email auth is not set at all it is a update email address 292 147 let email_auth_not_set = payload.email_auth_factor.is_none(); 293 - //If email aurth is set it is to either turn on or off 2fa 148 + //If email auth is set it is to either turn on or off 2fa 294 149 let email_auth_update = payload.email_auth_factor.unwrap_or(false); 295 150 296 151 // Email update asked for ··· 350 205 } 351 206 } 352 207 353 - // Updating the acutal email address 208 + // Updating the actual email address by sending it on to the PDS 354 209 let uri = format!( 355 210 "{}{}", 356 211 state.pds_base_url, "/xrpc/com.atproto.server.updateEmail"
-150
src/xrpc/helpers.rs
··· 1 - use axum::body::{Body, to_bytes}; 2 - use axum::extract::Request; 3 - use axum::http::{HeaderMap, Method, StatusCode, Uri}; 4 - use axum::http::header::CONTENT_TYPE; 5 - use axum::response::{IntoResponse, Response}; 6 - use serde::de::DeserializeOwned; 7 - use tracing::error; 8 - 9 - use crate::AppState; 10 - 11 - /// The result of a proxied call that attempts to parse JSON. 12 - pub enum ProxiedResult<T> { 13 - /// Successfully parsed JSON body along with original response headers. 14 - Parsed { value: T, _headers: HeaderMap }, 15 - /// Could not or should not parse: return the original (or rebuilt) response as-is. 16 - Passthrough(Response<Body>), 17 - } 18 - 19 - /// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse 20 - /// the successful response body as JSON into `T`. 21 - /// 22 - /// Behavior: 23 - /// - If the proxied response is non-200, returns Passthrough with the original response. 24 - /// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers. 25 - /// - If parsing succeeds, returns Parsed { value, headers }. 26 - pub async fn proxy_get_json<T>( 27 - state: &AppState, 28 - mut req: Request, 29 - path: &str, 30 - ) -> Result<ProxiedResult<T>, StatusCode> 31 - where 32 - T: DeserializeOwned, 33 - { 34 - let uri = format!("{}{}", state.pds_base_url, path); 35 - *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?; 36 - 37 - let result = state 38 - .reverse_proxy_client 39 - .request(req) 40 - .await 41 - .map_err(|_| StatusCode::BAD_REQUEST)? 42 - .into_response(); 43 - 44 - if result.status() != StatusCode::OK { 45 - return Ok(ProxiedResult::Passthrough(result)); 46 - } 47 - 48 - let response_headers = result.headers().clone(); 49 - let body = result.into_body(); 50 - let body_bytes = to_bytes(body, usize::MAX) 51 - .await 52 - .map_err(|_| StatusCode::BAD_REQUEST)?; 53 - 54 - match serde_json::from_slice::<T>(&body_bytes) { 55 - Ok(value) => Ok(ProxiedResult::Parsed { 56 - value, 57 - _headers: response_headers, 58 - }), 59 - Err(err) => { 60 - error!(%err, "failed to parse proxied JSON response; returning original body"); 61 - let mut builder = Response::builder().status(StatusCode::OK); 62 - if let Some(headers) = builder.headers_mut() { 63 - *headers = response_headers; 64 - } 65 - let resp = builder 66 - .body(Body::from(body_bytes)) 67 - .map_err(|_| StatusCode::BAD_REQUEST)?; 68 - Ok(ProxiedResult::Passthrough(resp)) 69 - } 70 - } 71 - } 72 - 73 - /// Proxy the incoming request as a POST to the PDS base URL plus the provided path and attempt to parse 74 - /// the successful response body as JSON into `T`. 75 - /// 76 - /// Behavior mirrors `proxy_get_json`: 77 - /// - If the proxied response is non-200, returns Passthrough with the original response. 78 - /// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers. 79 - /// - If parsing succeeds, returns Parsed { value, headers }. 80 - pub async fn _proxy_post_json<T>( 81 - state: &AppState, 82 - mut req: Request, 83 - path: &str, 84 - ) -> Result<ProxiedResult<T>, StatusCode> 85 - where 86 - T: DeserializeOwned, 87 - { 88 - let uri = format!("{}{}", state.pds_base_url, path); 89 - *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?; 90 - *req.method_mut() = Method::POST; 91 - 92 - let result = state 93 - .reverse_proxy_client 94 - .request(req) 95 - .await 96 - .map_err(|_| StatusCode::BAD_REQUEST)? 97 - .into_response(); 98 - 99 - if result.status() != StatusCode::OK { 100 - return Ok(ProxiedResult::Passthrough(result)); 101 - } 102 - 103 - let response_headers = result.headers().clone(); 104 - let body = result.into_body(); 105 - let body_bytes = to_bytes(body, usize::MAX) 106 - .await 107 - .map_err(|_| StatusCode::BAD_REQUEST)?; 108 - 109 - match serde_json::from_slice::<T>(&body_bytes) { 110 - Ok(value) => Ok(ProxiedResult::Parsed { 111 - value, 112 - _headers: response_headers, 113 - }), 114 - Err(err) => { 115 - error!(%err, "failed to parse proxied JSON response (POST); returning original body"); 116 - let mut builder = Response::builder().status(StatusCode::OK); 117 - if let Some(headers) = builder.headers_mut() { 118 - *headers = response_headers; 119 - } 120 - let resp = builder 121 - .body(Body::from(body_bytes)) 122 - .map_err(|_| StatusCode::BAD_REQUEST)?; 123 - Ok(ProxiedResult::Passthrough(resp)) 124 - } 125 - } 126 - } 127 - 128 - 129 - /// Build a JSON error response with the required Content-Type header 130 - /// Content-Type: application/json;charset=utf-8 131 - /// Body shape: { "error": string, "message": string } 132 - pub fn json_error_response( 133 - status: StatusCode, 134 - error: impl Into<String>, 135 - message: impl Into<String>, 136 - ) -> Result<Response<Body>, StatusCode> { 137 - let body_str = match serde_json::to_string(&serde_json::json!({ 138 - "error": error.into(), 139 - "message": message.into(), 140 - })) { 141 - Ok(s) => s, 142 - Err(_) => return Err(StatusCode::BAD_REQUEST), 143 - }; 144 - 145 - Response::builder() 146 - .status(status) 147 - .header(CONTENT_TYPE, "application/json;charset=utf-8") 148 - .body(Body::from(body_str)) 149 - .map_err(|_| StatusCode::BAD_REQUEST) 150 - }
-1
src/xrpc/mod.rs
··· 1 1 pub mod com_atproto_server; 2 - pub mod helpers;