2FA logins gatekept #1

merged
opened by baileytownsend.dev targeting main from feature/2faCodeGeneration
+10
Cargo.lock
··· 287 287 dependencies = [ 288 288 "android-tzdata", 289 289 "iana-time-zone", 290 + "js-sys", 290 291 "num-traits", 292 + "wasm-bindgen", 291 293 "windows-link", 292 294 ] 293 295 ··· 1652 1654 name = "pds_gatekeeper" 1653 1655 version = "0.1.0" 1654 1656 dependencies = [ 1657 + "anyhow", 1655 1658 "axum", 1656 1659 "axum-template", 1660 + "chrono", 1657 1661 "dotenvy", 1658 1662 "handlebars", 1659 1663 "hex", 1660 1664 "hyper-util", 1661 1665 "jwt-compact", 1662 1666 "lettre", 1667 + "rand 0.9.2", 1663 1668 "rust-embed", 1664 1669 "scrypt", 1665 1670 "serde", 1666 1671 "serde_json", 1672 + "sha2", 1667 1673 "sqlx", 1668 1674 "tokio", 1669 1675 "tower-http", ··· 2393 2399 dependencies = [ 2394 2400 "base64", 2395 2401 "bytes", 2402 + "chrono", 2396 2403 "crc", 2397 2404 "crossbeam-queue", 2398 2405 "either", ··· 2470 2477 "bitflags", 2471 2478 "byteorder", 2472 2479 "bytes", 2480 + "chrono", 2473 2481 "crc", 2474 2482 "digest", 2475 2483 "dotenvy", ··· 2511 2519 "base64", 2512 2520 "bitflags", 2513 2521 "byteorder", 2522 + "chrono", 2514 2523 "crc", 2515 2524 "dotenvy", 2516 2525 "etcetera", ··· 2545 2554 checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea" 2546 2555 dependencies = [ 2547 2556 "atoi", 2557 + "chrono", 2548 2558 "flume", 2549 2559 "futures-channel", 2550 2560 "futures-core",
+5 -1
Cargo.toml
··· 6 6 [dependencies] 7 7 axum = { version = "0.8.4", features = ["macros", "json"] } 8 8 tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros", "signal"] } 9 - sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "migrate"] } 9 + sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "migrate", "chrono"] } 10 10 dotenvy = "0.15.7" 11 11 serde = { version = "1.0", features = ["derive"] } 12 12 serde_json = "1.0" ··· 22 22 handlebars = { version = "6.3.2", features = ["rust-embed"] } 23 23 rust-embed = "8.7.2" 24 24 axum-template = { version = "3.0.0", features = ["handlebars"] } 25 + rand = "0.9.2" 26 + anyhow = "1.0.99" 27 + chrono = "0.4.41" 28 + sha2 = "0.10"
+5 -6
README.md
··· 12 12 13 13 ## 2FA 14 14 15 - - [x] Ability to turn on/off 2FA 16 - - [x] getSession overwrite to set the `emailAuthFactor` flag if the user has 2FA turned on 17 - - [x] send an email using the `PDS_EMAIL_SMTP_URL` with a handlebar email template like Bluesky's 2FA sign in email. 18 - - [ ] generate a 2FA code 19 - - [ ] createSession gatekeeping (It does stop logins, just eh, doesn't actually send a real code or check it yet) 20 - - [ ] oauth endpoint gatekeeping 15 + - Overrides The login endpoint to add 2FA for both Bluesky client logged in and OAuth logins 16 + - Overrides the settings endpoints as well. As long as you have a confirmed email you can turn on 2FA 21 17 22 18 ## Captcha on Create Account 23 19 ··· 25 21 26 22 # Setup 27 23 24 + We are getting close! Testing now 25 + 28 26 Nothing here yet! If you are brave enough to try before full release, let me know and I'll help you set it up. 29 27 But I want to run it locally on my own PDS first to test run it a bit. 30 28 ··· 37 35 path /xrpc/com.atproto.server.getSession 38 36 path /xrpc/com.atproto.server.updateEmail 39 37 path /xrpc/com.atproto.server.createSession 38 + path /@atproto/oauth-provider/~api/sign-in 40 39 } 41 40 42 41 handle @gatekeeper {
-3
migrations_bells_and_whistles/.keep
··· 1 - # This directory holds SQLx migrations for the bells_and_whistles.sqlite database. 2 - # It is intentionally empty for now; running `sqlx::migrate!` will still ensure the 3 - # migrations table exists and succeed with zero migrations.
+524
src/helpers.rs
··· 1 + use crate::AppState; 2 + use crate::helpers::TokenCheckError::InvalidToken; 3 + use anyhow::anyhow; 4 + use axum::body::{Body, to_bytes}; 5 + use axum::extract::Request; 6 + use axum::http::header::CONTENT_TYPE; 7 + use axum::http::{HeaderMap, StatusCode, Uri}; 8 + use axum::response::{IntoResponse, Response}; 9 + use axum_template::TemplateEngine; 10 + use chrono::Utc; 11 + use lettre::message::{MultiPart, SinglePart, header}; 12 + use lettre::{AsyncTransport, Message}; 13 + use rand::Rng; 14 + use serde::de::DeserializeOwned; 15 + use serde_json::{Map, Value}; 16 + use sha2::{Digest, Sha256}; 17 + use sqlx::SqlitePool; 18 + use tracing::{error, log}; 19 + 20 + ///Used to generate the email 2fa code 21 + const UPPERCASE_BASE32_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567"; 22 + 23 + /// The result of a proxied call that attempts to parse JSON. 24 + pub enum ProxiedResult<T> { 25 + /// Successfully parsed JSON body along with original response headers. 26 + Parsed { value: T, _headers: HeaderMap }, 27 + /// Could not or should not parse: return the original (or rebuilt) response as-is. 28 + Passthrough(Response<Body>), 29 + } 30 + 31 + /// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse 32 + /// the successful response body as JSON into `T`. 33 + /// 34 + pub async fn proxy_get_json<T>( 35 + state: &AppState, 36 + mut req: Request, 37 + path: &str, 38 + ) -> Result<ProxiedResult<T>, StatusCode> 39 + where 40 + T: DeserializeOwned, 41 + { 42 + let uri = format!("{}{}", state.pds_base_url, path); 43 + *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?; 44 + 45 + let result = state 46 + .reverse_proxy_client 47 + .request(req) 48 + .await 49 + .map_err(|_| StatusCode::BAD_REQUEST)? 50 + .into_response(); 51 + 52 + if result.status() != StatusCode::OK { 53 + return Ok(ProxiedResult::Passthrough(result)); 54 + } 55 + 56 + let response_headers = result.headers().clone(); 57 + let body = result.into_body(); 58 + let body_bytes = to_bytes(body, usize::MAX) 59 + .await 60 + .map_err(|_| StatusCode::BAD_REQUEST)?; 61 + 62 + match serde_json::from_slice::<T>(&body_bytes) { 63 + Ok(value) => Ok(ProxiedResult::Parsed { 64 + value, 65 + _headers: response_headers, 66 + }), 67 + Err(err) => { 68 + error!(%err, "failed to parse proxied JSON response; returning original body"); 69 + let mut builder = Response::builder().status(StatusCode::OK); 70 + if let Some(headers) = builder.headers_mut() { 71 + *headers = response_headers; 72 + } 73 + let resp = builder 74 + .body(Body::from(body_bytes)) 75 + .map_err(|_| StatusCode::BAD_REQUEST)?; 76 + Ok(ProxiedResult::Passthrough(resp)) 77 + } 78 + } 79 + } 80 + 81 + /// Build a JSON error response with the required Content-Type header 82 + /// Content-Type: application/json;charset=utf-8 83 + /// Body shape: { "error": string, "message": string } 84 + pub fn json_error_response( 85 + status: StatusCode, 86 + error: impl Into<String>, 87 + message: impl Into<String>, 88 + ) -> Result<Response<Body>, StatusCode> { 89 + let body_str = match serde_json::to_string(&serde_json::json!({ 90 + "error": error.into(), 91 + "message": message.into(), 92 + })) { 93 + Ok(s) => s, 94 + Err(_) => return Err(StatusCode::BAD_REQUEST), 95 + }; 96 + 97 + Response::builder() 98 + .status(status) 99 + .header(CONTENT_TYPE, "application/json;charset=utf-8") 100 + .body(Body::from(body_str)) 101 + .map_err(|_| StatusCode::BAD_REQUEST) 102 + } 103 + 104 + /// Build a JSON error response with the required Content-Type header 105 + /// Content-Type: application/json (oauth endpoint does not like utf ending) 106 + /// Body shape: { "error": string, "error_description": string } 107 + pub fn oauth_json_error_response( 108 + status: StatusCode, 109 + error: impl Into<String>, 110 + message: impl Into<String>, 111 + ) -> Result<Response<Body>, StatusCode> { 112 + let body_str = match serde_json::to_string(&serde_json::json!({ 113 + "error": error.into(), 114 + "error_description": message.into(), 115 + })) { 116 + Ok(s) => s, 117 + Err(_) => return Err(StatusCode::BAD_REQUEST), 118 + }; 119 + 120 + Response::builder() 121 + .status(status) 122 + .header(CONTENT_TYPE, "application/json") 123 + .body(Body::from(body_str)) 124 + .map_err(|_| StatusCode::BAD_REQUEST) 125 + } 126 + 127 + /// Creates a random token of 10 characters for email 2FA 128 + pub fn get_random_token() -> String { 129 + let mut rng = rand::rng(); 130 + 131 + let mut full_code = String::with_capacity(10); 132 + for _ in 0..10 { 133 + let idx = rng.random_range(0..UPPERCASE_BASE32_CHARS.len()); 134 + full_code.push(UPPERCASE_BASE32_CHARS[idx] as char); 135 + } 136 + 137 + //The PDS implementation creates in lowercase, then converts to uppercase. 138 + //Just going a head and doing uppercase here. 139 + let slice_one = &full_code[0..5].to_ascii_uppercase(); 140 + let slice_two = &full_code[5..10].to_ascii_uppercase(); 141 + format!("{slice_one}-{slice_two}") 142 + } 143 + 144 + pub enum TokenCheckError { 145 + InvalidToken, 146 + ExpiredToken, 147 + } 148 + 149 + pub enum AuthResult { 150 + WrongIdentityOrPassword, 151 + /// The string here is the email address to create a hint for oauth 152 + TwoFactorRequired(String), 153 + /// User does not have 2FA enabled, or using an app password, or passes it 154 + ProxyThrough, 155 + TokenCheckFailed(TokenCheckError), 156 + } 157 + 158 + pub enum IdentifierType { 159 + Email, 160 + Did, 161 + Handle, 162 + } 163 + 164 + impl IdentifierType { 165 + fn what_is_it(identifier: String) -> Self { 166 + if identifier.contains("@") { 167 + IdentifierType::Email 168 + } else if identifier.contains("did:") { 169 + IdentifierType::Did 170 + } else { 171 + IdentifierType::Handle 172 + } 173 + } 174 + } 175 + 176 + /// Creates a hex string from the password and salt to find app passwords 177 + fn scrypt_hex(password: &str, salt: &str) -> anyhow::Result<String> { 178 + let params = scrypt::Params::new(14, 8, 1, 64)?; 179 + let mut derived = [0u8; 64]; 180 + scrypt::scrypt(password.as_bytes(), salt.as_bytes(), &params, &mut derived)?; 181 + Ok(hex::encode(derived)) 182 + } 183 + 184 + /// Hashes the app password. did is used as the salt. 185 + pub fn hash_app_password(did: &str, password: &str) -> anyhow::Result<String> { 186 + let mut hasher = Sha256::new(); 187 + hasher.update(did.as_bytes()); 188 + let sha = hasher.finalize(); 189 + let salt = hex::encode(&sha[..16]); 190 + let hash_hex = scrypt_hex(password, &salt)?; 191 + Ok(format!("{salt}:{hash_hex}")) 192 + } 193 + 194 + async fn verify_password(password: &str, password_scrypt: &str) -> anyhow::Result<bool> { 195 + // Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes) 196 + let mut parts = password_scrypt.splitn(2, ':'); 197 + let salt = match parts.next() { 198 + Some(s) if !s.is_empty() => s, 199 + _ => return Ok(false), 200 + }; 201 + let stored_hash_hex = match parts.next() { 202 + Some(h) if !h.is_empty() => h, 203 + _ => return Ok(false), 204 + }; 205 + 206 + // Derive using the shared helper and compare 207 + let derived_hex = match scrypt_hex(password, salt) { 208 + Ok(h) => h, 209 + Err(_) => return Ok(false), 210 + }; 211 + 212 + Ok(derived_hex.as_str() == stored_hash_hex) 213 + } 214 + 215 + /// Handles the auth checks along with sending a 2fa email 216 + pub async fn preauth_check( 217 + state: &AppState, 218 + identifier: &str, 219 + password: &str, 220 + two_factor_code: Option<String>, 221 + oauth: bool, 222 + ) -> anyhow::Result<AuthResult> { 223 + // Determine identifier type 224 + let id_type = IdentifierType::what_is_it(identifier.to_string()); 225 + 226 + // Query account DB for did and passwordScrypt based on identifier type 227 + let account_row: Option<(String, String, String, String)> = match id_type { 228 + IdentifierType::Email => { 229 + sqlx::query_as::<_, (String, String, String, String)>( 230 + "SELECT account.did, account.passwordScrypt, account.email, actor.handle 231 + FROM actor 232 + LEFT JOIN account ON actor.did = account.did 233 + where account.email = ? LIMIT 1", 234 + ) 235 + .bind(identifier) 236 + .fetch_optional(&state.account_pool) 237 + .await? 238 + } 239 + IdentifierType::Handle => { 240 + sqlx::query_as::<_, (String, String, String, String)>( 241 + "SELECT account.did, account.passwordScrypt, account.email, actor.handle 242 + FROM actor 243 + LEFT JOIN account ON actor.did = account.did 244 + where actor.handle = ? LIMIT 1", 245 + ) 246 + .bind(identifier) 247 + .fetch_optional(&state.account_pool) 248 + .await? 249 + } 250 + IdentifierType::Did => { 251 + sqlx::query_as::<_, (String, String, String, String)>( 252 + "SELECT account.did, account.passwordScrypt, account.email, actor.handle 253 + FROM actor 254 + LEFT JOIN account ON actor.did = account.did 255 + where account.did = ? LIMIT 1", 256 + ) 257 + .bind(identifier) 258 + .fetch_optional(&state.account_pool) 259 + .await? 260 + } 261 + }; 262 + 263 + if let Some((did, password_scrypt, email, handle)) = account_row { 264 + // Verify password before proceeding to 2FA email step 265 + let verified = verify_password(password, &password_scrypt).await?; 266 + if !verified { 267 + if oauth { 268 + //OAuth does not allow app password logins so just go ahead and send it along it's way 269 + return Ok(AuthResult::WrongIdentityOrPassword); 270 + } 271 + //Theres a chance it could be an app password so check that as well 272 + return match verify_app_password(&state.account_pool, &did, password).await { 273 + Ok(valid) => { 274 + if valid { 275 + //Was a valid app password up to the PDS now 276 + Ok(AuthResult::ProxyThrough) 277 + } else { 278 + Ok(AuthResult::WrongIdentityOrPassword) 279 + } 280 + } 281 + Err(err) => { 282 + log::error!("Error checking the app password: {err}"); 283 + Err(err) 284 + } 285 + }; 286 + } 287 + 288 + // Check two-factor requirement for this DID in the gatekeeper DB 289 + let required_opt = sqlx::query_as::<_, (u8,)>( 290 + "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1", 291 + ) 292 + .bind(did.clone()) 293 + .fetch_optional(&state.pds_gatekeeper_pool) 294 + .await?; 295 + 296 + let two_factor_required = match required_opt { 297 + Some(row) => row.0 != 0, 298 + None => false, 299 + }; 300 + 301 + if two_factor_required { 302 + //Two factor is required and a taken was provided 303 + if let Some(two_factor_code) = two_factor_code { 304 + //if the two_factor_code is set need to see if we have a valid token 305 + if !two_factor_code.is_empty() { 306 + return match assert_valid_token( 307 + &state.account_pool, 308 + did.clone(), 309 + two_factor_code, 310 + ) 311 + .await 312 + { 313 + Ok(_) => { 314 + let result_of_cleanup = 315 + delete_all_email_tokens(&state.account_pool, did.clone()).await; 316 + if result_of_cleanup.is_err() { 317 + log::error!( 318 + "There was an error deleting the email tokens after login: {:?}", 319 + result_of_cleanup.err() 320 + ) 321 + } 322 + Ok(AuthResult::ProxyThrough) 323 + } 324 + Err(err) => Ok(AuthResult::TokenCheckFailed(err)), 325 + }; 326 + } 327 + } 328 + 329 + return match create_two_factor_token(&state.account_pool, did).await { 330 + Ok(code) => { 331 + let mut email_data = Map::new(); 332 + email_data.insert("token".to_string(), Value::from(code.clone())); 333 + email_data.insert("handle".to_string(), Value::from(handle.clone())); 334 + let email_body = state 335 + .template_engine 336 + .render("two_factor_code.hbs", email_data)?; 337 + 338 + let email_message = Message::builder() 339 + //TODO prob get the proper type in the state 340 + .from(state.mailer_from.parse()?) 341 + .to(email.parse()?) 342 + .subject("Sign in to Bluesky") 343 + .multipart( 344 + MultiPart::alternative() // This is composed of two parts. 345 + .singlepart( 346 + SinglePart::builder() 347 + .header(header::ContentType::TEXT_PLAIN) 348 + .body(format!("We received a sign-in request for the account @{handle}. Use the code: {code} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.")), // Every message should have a plain text fallback. 349 + ) 350 + .singlepart( 351 + SinglePart::builder() 352 + .header(header::ContentType::TEXT_HTML) 353 + .body(email_body), 354 + ), 355 + )?; 356 + match state.mailer.send(email_message).await { 357 + Ok(_) => Ok(AuthResult::TwoFactorRequired(mask_email(email))), 358 + Err(err) => { 359 + log::error!("Error sending the 2FA email: {err}"); 360 + Err(anyhow!(err)) 361 + } 362 + } 363 + } 364 + Err(err) => { 365 + log::error!("error on creating a 2fa token: {err}"); 366 + Err(anyhow!(err)) 367 + } 368 + }; 369 + } 370 + } 371 + 372 + // No local 2FA requirement (or account not found) 373 + Ok(AuthResult::ProxyThrough) 374 + } 375 + 376 + pub async fn create_two_factor_token( 377 + account_db: &SqlitePool, 378 + did: String, 379 + ) -> anyhow::Result<String> { 380 + let purpose = "2fa_code"; 381 + 382 + let token = get_random_token(); 383 + let right_now = Utc::now(); 384 + 385 + let res = sqlx::query( 386 + "INSERT INTO email_token (purpose, did, token, requestedAt) 387 + VALUES (?, ?, ?, ?) 388 + ON CONFLICT(purpose, did) DO UPDATE SET 389 + token=excluded.token, 390 + requestedAt=excluded.requestedAt 391 + WHERE did=excluded.did", 392 + ) 393 + .bind(purpose) 394 + .bind(&did) 395 + .bind(&token) 396 + .bind(right_now) 397 + .execute(account_db) 398 + .await; 399 + 400 + match res { 401 + Ok(_) => Ok(token), 402 + Err(err) => { 403 + log::error!("Error creating a two factor token: {err}"); 404 + Err(anyhow::anyhow!(err)) 405 + } 406 + } 407 + } 408 + 409 + pub async fn delete_all_email_tokens(account_db: &SqlitePool, did: String) -> anyhow::Result<()> { 410 + sqlx::query("DELETE FROM email_token WHERE did = ?") 411 + .bind(did) 412 + .execute(account_db) 413 + .await?; 414 + Ok(()) 415 + } 416 + 417 + pub async fn assert_valid_token( 418 + account_db: &SqlitePool, 419 + did: String, 420 + token: String, 421 + ) -> Result<(), TokenCheckError> { 422 + let token_upper = token.to_ascii_uppercase(); 423 + let purpose = "2fa_code"; 424 + 425 + let row: Option<(String,)> = sqlx::query_as( 426 + "SELECT requestedAt FROM email_token WHERE purpose = ? AND did = ? AND token = ? LIMIT 1", 427 + ) 428 + .bind(purpose) 429 + .bind(did) 430 + .bind(token_upper) 431 + .fetch_optional(account_db) 432 + .await 433 + .map_err(|err| { 434 + log::error!("Error getting the 2fa token: {err}"); 435 + InvalidToken 436 + })?; 437 + 438 + match row { 439 + None => Err(InvalidToken), 440 + Some(row) => { 441 + // Token lives for 15 minutes 442 + let expiration_ms = 15 * 60_000; 443 + 444 + let requested_at_utc = match chrono::DateTime::parse_from_rfc3339(&row.0) { 445 + Ok(dt) => dt.with_timezone(&Utc), 446 + Err(_) => { 447 + return Err(TokenCheckError::InvalidToken); 448 + } 449 + }; 450 + 451 + let now = Utc::now(); 452 + let age_ms = (now - requested_at_utc).num_milliseconds(); 453 + let expired = age_ms > expiration_ms; 454 + if expired { 455 + return Err(TokenCheckError::ExpiredToken); 456 + } 457 + 458 + Ok(()) 459 + } 460 + } 461 + } 462 + 463 + /// We just need to confirm if it's there or not. Will let the PDS do the actual figuring of permissions 464 + pub async fn verify_app_password( 465 + account_db: &SqlitePool, 466 + did: &str, 467 + password: &str, 468 + ) -> anyhow::Result<bool> { 469 + let password_scrypt = hash_app_password(did, password)?; 470 + 471 + let row: Option<(i64,)> = sqlx::query_as( 472 + "SELECT Count(*) FROM app_password WHERE did = ? AND passwordScrypt = ? LIMIT 1", 473 + ) 474 + .bind(did) 475 + .bind(password_scrypt) 476 + .fetch_optional(account_db) 477 + .await?; 478 + 479 + Ok(match row { 480 + None => false, 481 + Some((count,)) => count > 0, 482 + }) 483 + } 484 + 485 + /// Mask an email address into a hint like "2***0@p***m". 486 + pub fn mask_email(email: String) -> String { 487 + // Basic split on first '@' 488 + let mut parts = email.splitn(2, '@'); 489 + let local = match parts.next() { 490 + Some(l) => l, 491 + None => return email.to_string(), 492 + }; 493 + let domain_rest = match parts.next() { 494 + Some(d) if !d.is_empty() => d, 495 + _ => return email.to_string(), 496 + }; 497 + 498 + // Helper to mask a single label (keep first and last, middle becomes ***). 499 + fn mask_label(s: &str) -> String { 500 + let chars: Vec<char> = s.chars().collect(); 501 + match chars.len() { 502 + 0 => String::new(), 503 + 1 => format!("{}***", chars[0]), 504 + 2 => format!("{}***{}", chars[0], chars[1]), 505 + _ => format!("{}***{}", chars[0], chars[chars.len() - 1]), 506 + } 507 + } 508 + 509 + // Mask local 510 + let masked_local = mask_label(local); 511 + 512 + // Mask first domain label only, keep the rest of the domain intact 513 + let mut dom_parts = domain_rest.splitn(2, '.'); 514 + let first_label = dom_parts.next().unwrap_or(""); 515 + let rest = dom_parts.next(); 516 + let masked_first = mask_label(first_label); 517 + let masked_domain = if let Some(rest) = rest { 518 + format!("{}.{rest}", masked_first) 519 + } else { 520 + masked_first 521 + }; 522 + 523 + format!("{masked_local}@{masked_domain}") 524 + }
+53 -26
src/main.rs
··· 1 + #![warn(clippy::unwrap_used)] 2 + use crate::oauth_provider::sign_in; 1 3 use crate::xrpc::com_atproto_server::{create_session, get_session, update_email}; 2 - use axum::middleware as ax_middleware; 3 - mod middleware; 4 4 use axum::body::Body; 5 5 use axum::handler::Handler; 6 6 use axum::http::{Method, header}; 7 + use axum::middleware as ax_middleware; 7 8 use axum::routing::post; 8 9 use axum::{Router, routing::get}; 9 10 use axum_template::engine::Engine; ··· 21 22 use tower_governor::governor::GovernorConfigBuilder; 22 23 use tower_http::compression::CompressionLayer; 23 24 use tower_http::cors::{Any, CorsLayer}; 24 - use tracing::{error, log}; 25 + use tracing::log; 25 26 use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 26 27 28 + pub mod helpers; 29 + mod middleware; 30 + mod oauth_provider; 27 31 mod xrpc; 28 32 29 33 type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>; ··· 34 38 struct EmailTemplates; 35 39 36 40 #[derive(Clone)] 37 - struct AppState { 41 + pub struct AppState { 38 42 account_pool: SqlitePool, 39 43 pds_gatekeeper_pool: SqlitePool, 40 44 reverse_proxy_client: HyperUtilClient, ··· 73 77 74 78 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; 75 79 76 - let banner = format!(" {}\n{}", body, intro); 80 + let banner = format!(" {body}\n{intro}"); 77 81 78 82 ( 79 83 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], ··· 84 88 #[tokio::main] 85 89 async fn main() -> Result<(), Box<dyn std::error::Error>> { 86 90 setup_tracing(); 87 - //TODO prod 91 + //TODO may need to change where this reads from? Like an env variable for it's location? Or arg? 88 92 dotenvy::from_path(Path::new("./pds.env"))?; 89 93 let pds_root = env::var("PDS_DATA_DIRECTORY")?; 90 - // let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data"; 91 - let account_db_url = format!("{}/account.sqlite", pds_root); 92 - log::info!("accounts_db_url: {}", account_db_url); 94 + let account_db_url = format!("{pds_root}/account.sqlite"); 93 95 94 96 let account_options = SqliteConnectOptions::new() 95 - .journal_mode(SqliteJournalMode::Wal) 96 - .filename(account_db_url); 97 + .filename(account_db_url) 98 + .busy_timeout(Duration::from_secs(5)); 97 99 98 100 let account_pool = SqlitePoolOptions::new() 99 101 .max_connections(5) 100 102 .connect_with(account_options) 101 103 .await?; 102 104 103 - let bells_db_url = format!("{}/pds_gatekeeper.sqlite", pds_root); 105 + let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); 104 106 let options = SqliteConnectOptions::new() 105 107 .journal_mode(SqliteJournalMode::Wal) 106 108 .filename(bells_db_url) 107 - .create_if_missing(true); 109 + .create_if_missing(true) 110 + .busy_timeout(Duration::from_secs(5)); 108 111 let pds_gatekeeper_pool = SqlitePoolOptions::new() 109 112 .max_connections(5) 110 113 .connect_with(options) 111 114 .await?; 112 115 113 - // Run migrations for the bells_and_whistles database 116 + // Run migrations for the extra database 114 117 // Note: the migrations are embedded at compile time from the given directory 115 118 // sqlx 116 119 sqlx::migrate!("./migrations") ··· 130 133 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build(); 131 134 //Email templates setup 132 135 let mut hbs = Handlebars::new(); 133 - let _ = hbs.register_embed_templates::<EmailTemplates>(); 136 + 137 + let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); 138 + if let Ok(users_email_directory) = users_email_directory { 139 + hbs.register_template_file( 140 + "two_factor_code.hbs", 141 + format!("{users_email_directory}/two_factor_code.hbs"), 142 + )?; 143 + } else { 144 + let _ = hbs.register_embed_templates::<EmailTemplates>(); 145 + } 146 + 147 + let pds_base_url = 148 + env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); 134 149 135 150 let state = AppState { 136 151 account_pool, 137 152 pds_gatekeeper_pool, 138 153 reverse_proxy_client: client, 139 - //TODO should be env prob 140 - pds_base_url: "http://localhost:3000".to_string(), 154 + pds_base_url, 141 155 mailer, 142 156 mailer_from: sent_from, 143 157 template_engine: Engine::from(hbs), ··· 145 159 146 160 // Rate limiting 147 161 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 148 - let governor_conf = GovernorConfigBuilder::default() 162 + let create_session_governor_conf = GovernorConfigBuilder::default() 149 163 .per_second(60) 150 164 .burst_size(5) 151 165 .finish() 152 - .unwrap(); 153 - let governor_limiter = governor_conf.limiter().clone(); 166 + .expect("failed to create governor config. this should not happen and is a bug"); 167 + 168 + // Create a second config with the same settings for the other endpoint 169 + let sign_in_governor_conf = GovernorConfigBuilder::default() 170 + .per_second(60) 171 + .burst_size(5) 172 + .finish() 173 + .expect("failed to create governor config. this should not happen and is a bug"); 174 + 175 + let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); 176 + let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 154 177 let interval = Duration::from_secs(60); 155 178 // a separate background task to clean up 156 179 std::thread::spawn(move || { 157 180 loop { 158 181 std::thread::sleep(interval); 159 - tracing::info!("rate limiting storage size: {}", governor_limiter.len()); 160 - governor_limiter.retain_recent(); 182 + create_session_governor_limiter.retain_recent(); 183 + sign_in_governor_limiter.retain_recent(); 161 184 } 162 185 }); 163 186 ··· 176 199 "/xrpc/com.atproto.server.updateEmail", 177 200 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), 178 201 ) 202 + .route( 203 + "/@atproto/oauth-provider/~api/sign-in", 204 + post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)), 205 + ) 179 206 .route( 180 207 "/xrpc/com.atproto.server.createSession", 181 - post(create_session.layer(GovernorLayer::new(governor_conf))), 208 + post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), 182 209 ) 183 210 .layer(CompressionLayer::new()) 184 211 .layer(cors) 185 212 .with_state(state); 186 213 187 - let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 188 - let port: u16 = env::var("PORT") 214 + let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 215 + let port: u16 = env::var("GATEKEEPER_PORT") 189 216 .ok() 190 217 .and_then(|s| s.parse().ok()) 191 218 .unwrap_or(8080); ··· 202 229 .with_graceful_shutdown(shutdown_signal()); 203 230 204 231 if let Err(err) = server.await { 205 - error!(error = %err, "server error"); 232 + log::error!("server error:{err}"); 206 233 } 207 234 208 235 Ok(())
+19 -34
src/middleware.rs
··· 1 - use crate::xrpc::helpers::json_error_response; 1 + use crate::helpers::json_error_response; 2 2 use axum::extract::Request; 3 3 use axum::http::{HeaderMap, StatusCode}; 4 4 use axum::middleware::Next; ··· 7 7 use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError}; 8 8 use serde::{Deserialize, Serialize}; 9 9 use std::env; 10 + use tracing::log; 10 11 11 12 #[derive(Clone, Debug)] 12 13 pub struct Did(pub Option<String>); ··· 22 23 match token { 23 24 Ok(token) => { 24 25 match token { 25 - None => { 26 - return json_error_response( 27 - StatusCode::BAD_REQUEST, 28 - "TokenRequired", 29 - "", 30 - ).unwrap(); 31 - } 26 + None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 27 + .expect("Error creating an error response"), 32 28 Some(token) => { 33 29 let token = UntrustedToken::new(&token); 34 - //Doing weird unwraps cause I can't do Result for middleware? 35 30 if token.is_err() { 36 - return json_error_response( 37 - StatusCode::BAD_REQUEST, 38 - "TokenRequired", 39 - "", 40 - ).unwrap(); 31 + return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 32 + .expect("Error creating an error response"); 41 33 } 42 - let parsed_token = token.unwrap(); 34 + let parsed_token = token.expect("Already checked for error"); 43 35 let claims: Result<Claims<TokenClaims>, ValidationError> = 44 36 parsed_token.deserialize_claims_unchecked(); 45 37 if claims.is_err() { 46 - return json_error_response( 47 - StatusCode::BAD_REQUEST, 48 - "TokenRequired", 49 - "", 50 - ).unwrap(); 38 + return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 39 + .expect("Error creating an error response"); 51 40 } 52 41 53 - let key = Hs256Key::new(env::var("PDS_JWT_SECRET").unwrap()); 42 + let key = Hs256Key::new( 43 + env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"), 44 + ); 54 45 let token: Result<Token<TokenClaims>, ValidationError> = 55 46 Hs256.validator(&key).validate(&parsed_token); 56 47 if token.is_err() { 57 - return json_error_response( 58 - StatusCode::BAD_REQUEST, 59 - "InvalidToken", 60 - "", 61 - ).unwrap(); 48 + return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "") 49 + .expect("Error creating an error response"); 62 50 } 63 - let token = token.unwrap(); 51 + let token = token.expect("Already checked for error,"); 64 52 //Not going to worry about expiration since it still goes to the PDS 65 - 66 53 req.extensions_mut() 67 54 .insert(Did(Some(token.claims().custom.sub.clone()))); 68 55 next.run(req).await 69 56 } 70 57 } 71 58 } 72 - Err(_) => { 73 - return json_error_response( 74 - StatusCode::BAD_REQUEST, 75 - "InvalidToken", 76 - "", 77 - ).unwrap(); 59 + Err(err) => { 60 + log::error!("Error extracting token: {err}"); 61 + json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "") 62 + .expect("Error creating an error response") 78 63 } 79 64 } 80 65 }
+141
src/oauth_provider.rs
··· 1 + use crate::AppState; 2 + use crate::helpers::{AuthResult, oauth_json_error_response, preauth_check}; 3 + use axum::body::Body; 4 + use axum::extract::State; 5 + use axum::http::header::CONTENT_TYPE; 6 + use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode}; 7 + use axum::response::{IntoResponse, Response}; 8 + use axum::{Json, extract}; 9 + use serde::{Deserialize, Serialize}; 10 + use tracing::log; 11 + 12 + #[derive(Serialize, Deserialize, Clone)] 13 + pub struct SignInRequest { 14 + pub username: String, 15 + pub password: String, 16 + pub remember: bool, 17 + pub locale: String, 18 + #[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")] 19 + pub email_otp: Option<String>, 20 + } 21 + 22 + pub async fn sign_in( 23 + State(state): State<AppState>, 24 + headers: HeaderMap, 25 + Json(mut payload): extract::Json<SignInRequest>, 26 + ) -> Result<Response<Body>, StatusCode> { 27 + let identifier = payload.username.clone(); 28 + let password = payload.password.clone(); 29 + let auth_factor_token = payload.email_otp.clone(); 30 + 31 + match preauth_check(&state, &identifier, &password, auth_factor_token, true).await { 32 + Ok(result) => match result { 33 + AuthResult::WrongIdentityOrPassword => oauth_json_error_response( 34 + StatusCode::BAD_REQUEST, 35 + "invalid_request", 36 + "Invalid identifier or password", 37 + ), 38 + AuthResult::TwoFactorRequired(masked_email) => { 39 + // Email sending step can be handled here if needed in the future. 40 + 41 + // {"error":"second_authentication_factor_required","error_description":"emailOtp authentication factor required (hint: 2***0@p***m)","type":"emailOtp","hint":"2***0@p***m"} 42 + let body_str = match serde_json::to_string(&serde_json::json!({ 43 + "error": "second_authentication_factor_required", 44 + "error_description": format!("emailOtp authentication factor required (hint: {})", masked_email), 45 + "type": "emailOtp", 46 + "hint": masked_email, 47 + })) { 48 + Ok(s) => s, 49 + Err(_) => return Err(StatusCode::BAD_REQUEST), 50 + }; 51 + 52 + Response::builder() 53 + .status(StatusCode::BAD_REQUEST) 54 + .header(CONTENT_TYPE, "application/json") 55 + .body(Body::from(body_str)) 56 + .map_err(|_| StatusCode::BAD_REQUEST) 57 + } 58 + AuthResult::ProxyThrough => { 59 + //No 2FA or already passed 60 + let uri = format!( 61 + "{}{}", 62 + state.pds_base_url, "/@atproto/oauth-provider/~api/sign-in" 63 + ); 64 + 65 + let mut req = axum::http::Request::post(uri); 66 + if let Some(req_headers) = req.headers_mut() { 67 + // Copy headers but remove problematic ones. There was an issue with the PDS not parsing the body fully if i forwarded all headers 68 + copy_filtered_headers(&headers, req_headers); 69 + //Setting the content type to application/json manually 70 + req_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); 71 + } 72 + 73 + //Clears the email_otp because the pds will reject a request with it. 74 + payload.email_otp = None; 75 + let payload_bytes = 76 + serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 77 + 78 + let req = req 79 + .body(Body::from(payload_bytes)) 80 + .map_err(|_| StatusCode::BAD_REQUEST)?; 81 + 82 + let proxied = state 83 + .reverse_proxy_client 84 + .request(req) 85 + .await 86 + .map_err(|_| StatusCode::BAD_REQUEST)? 87 + .into_response(); 88 + 89 + Ok(proxied) 90 + } 91 + //Ignoring the type of token check failure. Looks like oauth on the entry treads them the same. 92 + AuthResult::TokenCheckFailed(_) => oauth_json_error_response( 93 + StatusCode::BAD_REQUEST, 94 + "invalid_request", 95 + "Unable to sign-in due to an unexpected server error", 96 + ), 97 + }, 98 + Err(err) => { 99 + log::error!( 100 + "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}" 101 + ); 102 + oauth_json_error_response( 103 + StatusCode::BAD_REQUEST, 104 + "pds_gatekeeper_error", 105 + "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.", 106 + ) 107 + } 108 + } 109 + } 110 + 111 + fn is_disallowed_header(name: &HeaderName) -> bool { 112 + // possible problematic headers with proxying 113 + matches!( 114 + name.as_str(), 115 + "connection" 116 + | "keep-alive" 117 + | "proxy-authenticate" 118 + | "proxy-authorization" 119 + | "te" 120 + | "trailer" 121 + | "transfer-encoding" 122 + | "upgrade" 123 + | "host" 124 + | "content-length" 125 + | "content-encoding" 126 + | "expect" 127 + | "accept-encoding" 128 + ) 129 + } 130 + 131 + fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) { 132 + for (name, value) in src.iter() { 133 + if is_disallowed_header(name) { 134 + continue; 135 + } 136 + // Only copy valid headers 137 + if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) { 138 + dst.insert(name.clone(), hv); 139 + } 140 + } 141 + }
+66 -211
src/xrpc/com_atproto_server.rs
··· 1 1 use crate::AppState; 2 + use crate::helpers::{ 3 + AuthResult, ProxiedResult, TokenCheckError, json_error_response, preauth_check, proxy_get_json, 4 + }; 2 5 use crate::middleware::Did; 3 - use crate::xrpc::helpers::{ProxiedResult, json_error_response, proxy_get_json}; 4 6 use axum::body::Body; 5 7 use axum::extract::State; 6 8 use axum::http::{HeaderMap, StatusCode}; 7 9 use axum::response::{IntoResponse, Response}; 8 10 use axum::{Extension, Json, debug_handler, extract, extract::Request}; 9 - use axum_template::TemplateEngine; 10 - use lettre::message::{MultiPart, SinglePart, header}; 11 - use lettre::{AsyncTransport, Message}; 12 11 use serde::{Deserialize, Serialize}; 13 12 use serde_json; 14 - use serde_json::Value; 15 - use serde_json::value::Map; 16 13 use tracing::log; 17 14 18 15 #[derive(Serialize, Deserialize, Debug, Clone)] ··· 58 55 pub struct CreateSessionRequest { 59 56 identifier: String, 60 57 password: String, 61 - auth_factor_token: String, 62 - allow_takendown: bool, 63 - } 64 - 65 - pub enum AuthResult { 66 - WrongIdentityOrPassword, 67 - TwoFactorRequired, 68 - TwoFactorFailed, 69 - /// User does not have 2FA enabled, or passes it 70 - ProxyThrough, 71 - } 72 - 73 - pub enum IdentifierType { 74 - Email, 75 - DID, 76 - Handle, 77 - } 78 - 79 - impl IdentifierType { 80 - fn what_is_it(identifier: String) -> Self { 81 - if identifier.contains("@") { 82 - IdentifierType::Email 83 - } else if identifier.contains("did:") { 84 - IdentifierType::DID 85 - } else { 86 - IdentifierType::Handle 87 - } 88 - } 89 - } 90 - 91 - async fn verify_password(password: &str, password_scrypt: &str) -> Result<bool, StatusCode> { 92 - // Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes) 93 - let mut parts = password_scrypt.splitn(2, ':'); 94 - let salt = match parts.next() { 95 - Some(s) if !s.is_empty() => s, 96 - _ => return Ok(false), 97 - }; 98 - let stored_hash_hex = match parts.next() { 99 - Some(h) if !h.is_empty() => h, 100 - _ => return Ok(false), 101 - }; 102 - 103 - //Sets up scrypt to mimic node's scrypt 104 - let params = match scrypt::Params::new(14, 8, 1, 64) { 105 - Ok(p) => p, 106 - Err(_) => return Ok(false), 107 - }; 108 - let mut derived = [0u8; 64]; 109 - if scrypt::scrypt(password.as_bytes(), salt.as_bytes(), &params, &mut derived).is_err() { 110 - return Ok(false); 111 - } 112 - 113 - let stored_bytes = match hex::decode(stored_hash_hex) { 114 - Ok(b) => b, 115 - Err(e) => { 116 - log::error!("Error decoding stored hash: {}", e); 117 - return Ok(false); 118 - } 119 - }; 120 - 121 - Ok(derived.as_slice() == stored_bytes.as_slice()) 122 - } 123 - 124 - async fn preauth_check( 125 - state: &AppState, 126 - identifier: &str, 127 - password: &str, 128 - ) -> Result<AuthResult, StatusCode> { 129 - // Determine identifier type 130 - let id_type = IdentifierType::what_is_it(identifier.to_string()); 131 - 132 - // Query account DB for did and passwordScrypt based on identifier type 133 - let account_row: Option<(String, String, String)> = match id_type { 134 - IdentifierType::Email => sqlx::query_as::<_, (String, String, String)>( 135 - "SELECT did, passwordScrypt, account.email FROM account WHERE email = ? LIMIT 1", 136 - ) 137 - .bind(identifier) 138 - .fetch_optional(&state.account_pool) 139 - .await 140 - .map_err(|_| StatusCode::BAD_REQUEST)?, 141 - IdentifierType::Handle => sqlx::query_as::<_, (String, String, String)>( 142 - "SELECT account.did, account.passwordScrypt, account.email 143 - FROM actor 144 - LEFT JOIN account ON actor.did = account.did 145 - where actor.handle =? LIMIT 1", 146 - ) 147 - .bind(identifier) 148 - .fetch_optional(&state.account_pool) 149 - .await 150 - .map_err(|_| StatusCode::BAD_REQUEST)?, 151 - IdentifierType::DID => sqlx::query_as::<_, (String, String, String)>( 152 - "SELECT did, passwordScrypt, account.email FROM account WHERE did = ? LIMIT 1", 153 - ) 154 - .bind(identifier) 155 - .fetch_optional(&state.account_pool) 156 - .await 157 - .map_err(|_| StatusCode::BAD_REQUEST)?, 158 - }; 159 - 160 - if let Some((did, password_scrypt, email)) = account_row { 161 - // Check two-factor requirement for this DID in the gatekeeper DB 162 - let required_opt = sqlx::query_as::<_, (u8,)>( 163 - "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1", 164 - ) 165 - .bind(&did) 166 - .fetch_optional(&state.pds_gatekeeper_pool) 167 - .await 168 - .map_err(|_| StatusCode::BAD_REQUEST)?; 169 - 170 - let two_factor_required = match required_opt { 171 - Some(row) => row.0 != 0, 172 - None => false, 173 - }; 174 - 175 - if two_factor_required { 176 - // Verify password before proceeding to 2FA email step 177 - let verified = verify_password(password, &password_scrypt).await?; 178 - if !verified { 179 - return Ok(AuthResult::WrongIdentityOrPassword); 180 - } 181 - let mut email_data = Map::new(); 182 - //TODO these need real values 183 - let token = "test".to_string(); 184 - let handle = "baileytownsend.dev".to_string(); 185 - email_data.insert("token".to_string(), Value::from(token.clone())); 186 - email_data.insert("handle".to_string(), Value::from(handle.clone())); 187 - //TODO bad unwrap 188 - let email_body = state 189 - .template_engine 190 - .render("two_factor_code.hbs", email_data) 191 - .unwrap(); 192 - 193 - let email = Message::builder() 194 - //TODO prob get the proper type in the state 195 - .from(state.mailer_from.parse().unwrap()) 196 - .to(email.parse().unwrap()) 197 - .subject("Sign in to Bluesky") 198 - .multipart( 199 - MultiPart::alternative() // This is composed of two parts. 200 - .singlepart( 201 - SinglePart::builder() 202 - .header(header::ContentType::TEXT_PLAIN) 203 - .body(format!("We received a sign-in request for the account @{}. Use the code: {} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.", handle, token)), // Every message should have a plain text fallback. 204 - ) 205 - .singlepart( 206 - SinglePart::builder() 207 - .header(header::ContentType::TEXT_HTML) 208 - .body(email_body), 209 - ), 210 - ) 211 - //TODO bad 212 - .unwrap(); 213 - return match state.mailer.send(email).await { 214 - Ok(_) => Ok(AuthResult::TwoFactorRequired), 215 - Err(err) => { 216 - log::error!("Error sending the 2FA email: {}", err); 217 - Err(StatusCode::BAD_REQUEST) 218 - } 219 - }; 220 - } 221 - } 222 - 223 - // No local 2FA requirement (or account not found) 224 - Ok(AuthResult::ProxyThrough) 58 + #[serde(skip_serializing_if = "Option::is_none")] 59 + auth_factor_token: Option<String>, 60 + #[serde(skip_serializing_if = "Option::is_none")] 61 + allow_takendown: Option<bool>, 225 62 } 226 63 227 64 pub async fn create_session( ··· 231 68 ) -> Result<Response<Body>, StatusCode> { 232 69 let identifier = payload.identifier.clone(); 233 70 let password = payload.password.clone(); 71 + let auth_factor_token = payload.auth_factor_token.clone(); 234 72 235 73 // Run the shared pre-auth logic to validate and check 2FA requirement 236 - match preauth_check(&state, &identifier, &password).await? { 237 - AuthResult::WrongIdentityOrPassword => json_error_response( 238 - StatusCode::UNAUTHORIZED, 239 - "AuthenticationRequired", 240 - "Invalid identifier or password", 241 - ), 242 - AuthResult::TwoFactorRequired => { 243 - // Email sending step can be handled here if needed in the future. 244 - json_error_response( 74 + match preauth_check(&state, &identifier, &password, auth_factor_token, false).await { 75 + Ok(result) => match result { 76 + AuthResult::WrongIdentityOrPassword => json_error_response( 245 77 StatusCode::UNAUTHORIZED, 246 - "AuthFactorTokenRequired", 247 - "A sign in code has been sent to your email address", 248 - ) 249 - } 250 - AuthResult::TwoFactorFailed => { 251 - //Not sure what the errors are for this response is yet 252 - json_error_response(StatusCode::UNAUTHORIZED, "PLACEHOLDER", "PLACEHOLDER") 253 - } 254 - AuthResult::ProxyThrough => { 255 - //No 2FA or already passed 256 - let uri = format!( 257 - "{}{}", 258 - state.pds_base_url, "/xrpc/com.atproto.server.createSession" 259 - ); 260 - 261 - let mut req = axum::http::Request::post(uri); 262 - if let Some(req_headers) = req.headers_mut() { 263 - req_headers.extend(headers.clone()); 78 + "AuthenticationRequired", 79 + "Invalid identifier or password", 80 + ), 81 + AuthResult::TwoFactorRequired(_) => { 82 + // Email sending step can be handled here if needed in the future. 83 + json_error_response( 84 + StatusCode::UNAUTHORIZED, 85 + "AuthFactorTokenRequired", 86 + "A sign in code has been sent to your email address", 87 + ) 264 88 } 89 + AuthResult::ProxyThrough => { 90 + log::info!("Proxying through"); 91 + //No 2FA or already passed 92 + let uri = format!( 93 + "{}{}", 94 + state.pds_base_url, "/xrpc/com.atproto.server.createSession" 95 + ); 96 + 97 + let mut req = axum::http::Request::post(uri); 98 + if let Some(req_headers) = req.headers_mut() { 99 + req_headers.extend(headers.clone()); 100 + } 265 101 266 - let payload_bytes = 267 - serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 268 - let req = req 269 - .body(Body::from(payload_bytes)) 270 - .map_err(|_| StatusCode::BAD_REQUEST)?; 102 + let payload_bytes = 103 + serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 104 + let req = req 105 + .body(Body::from(payload_bytes)) 106 + .map_err(|_| StatusCode::BAD_REQUEST)?; 271 107 272 - let proxied = state 273 - .reverse_proxy_client 274 - .request(req) 275 - .await 276 - .map_err(|_| StatusCode::BAD_REQUEST)? 277 - .into_response(); 108 + let proxied = state 109 + .reverse_proxy_client 110 + .request(req) 111 + .await 112 + .map_err(|_| StatusCode::BAD_REQUEST)? 113 + .into_response(); 278 114 279 - Ok(proxied) 115 + Ok(proxied) 116 + } 117 + AuthResult::TokenCheckFailed(err) => match err { 118 + TokenCheckError::InvalidToken => { 119 + json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "Token is invalid") 120 + } 121 + TokenCheckError::ExpiredToken => { 122 + json_error_response(StatusCode::BAD_REQUEST, "ExpiredToken", "Token is expired") 123 + } 124 + }, 125 + }, 126 + Err(err) => { 127 + log::error!( 128 + "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}" 129 + ); 130 + json_error_response( 131 + StatusCode::INTERNAL_SERVER_ERROR, 132 + "InternalServerError", 133 + "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.", 134 + ) 280 135 } 281 136 } 282 137 } ··· 290 145 ) -> Result<Response<Body>, StatusCode> { 291 146 //If email auth is not set at all it is a update email address 292 147 let email_auth_not_set = payload.email_auth_factor.is_none(); 293 - //If email aurth is set it is to either turn on or off 2fa 148 + //If email auth is set it is to either turn on or off 2fa 294 149 let email_auth_update = payload.email_auth_factor.unwrap_or(false); 295 150 296 151 // Email update asked for ··· 350 205 } 351 206 } 352 207 353 - // Updating the acutal email address 208 + // Updating the actual email address by sending it on to the PDS 354 209 let uri = format!( 355 210 "{}{}", 356 211 state.pds_base_url, "/xrpc/com.atproto.server.updateEmail"
-150
src/xrpc/helpers.rs
··· 1 - use axum::body::{Body, to_bytes}; 2 - use axum::extract::Request; 3 - use axum::http::{HeaderMap, Method, StatusCode, Uri}; 4 - use axum::http::header::CONTENT_TYPE; 5 - use axum::response::{IntoResponse, Response}; 6 - use serde::de::DeserializeOwned; 7 - use tracing::error; 8 - 9 - use crate::AppState; 10 - 11 - /// The result of a proxied call that attempts to parse JSON. 12 - pub enum ProxiedResult<T> { 13 - /// Successfully parsed JSON body along with original response headers. 14 - Parsed { value: T, _headers: HeaderMap }, 15 - /// Could not or should not parse: return the original (or rebuilt) response as-is. 16 - Passthrough(Response<Body>), 17 - } 18 - 19 - /// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse 20 - /// the successful response body as JSON into `T`. 21 - /// 22 - /// Behavior: 23 - /// - If the proxied response is non-200, returns Passthrough with the original response. 24 - /// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers. 25 - /// - If parsing succeeds, returns Parsed { value, headers }. 26 - pub async fn proxy_get_json<T>( 27 - state: &AppState, 28 - mut req: Request, 29 - path: &str, 30 - ) -> Result<ProxiedResult<T>, StatusCode> 31 - where 32 - T: DeserializeOwned, 33 - { 34 - let uri = format!("{}{}", state.pds_base_url, path); 35 - *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?; 36 - 37 - let result = state 38 - .reverse_proxy_client 39 - .request(req) 40 - .await 41 - .map_err(|_| StatusCode::BAD_REQUEST)? 42 - .into_response(); 43 - 44 - if result.status() != StatusCode::OK { 45 - return Ok(ProxiedResult::Passthrough(result)); 46 - } 47 - 48 - let response_headers = result.headers().clone(); 49 - let body = result.into_body(); 50 - let body_bytes = to_bytes(body, usize::MAX) 51 - .await 52 - .map_err(|_| StatusCode::BAD_REQUEST)?; 53 - 54 - match serde_json::from_slice::<T>(&body_bytes) { 55 - Ok(value) => Ok(ProxiedResult::Parsed { 56 - value, 57 - _headers: response_headers, 58 - }), 59 - Err(err) => { 60 - error!(%err, "failed to parse proxied JSON response; returning original body"); 61 - let mut builder = Response::builder().status(StatusCode::OK); 62 - if let Some(headers) = builder.headers_mut() { 63 - *headers = response_headers; 64 - } 65 - let resp = builder 66 - .body(Body::from(body_bytes)) 67 - .map_err(|_| StatusCode::BAD_REQUEST)?; 68 - Ok(ProxiedResult::Passthrough(resp)) 69 - } 70 - } 71 - } 72 - 73 - /// Proxy the incoming request as a POST to the PDS base URL plus the provided path and attempt to parse 74 - /// the successful response body as JSON into `T`. 75 - /// 76 - /// Behavior mirrors `proxy_get_json`: 77 - /// - If the proxied response is non-200, returns Passthrough with the original response. 78 - /// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers. 79 - /// - If parsing succeeds, returns Parsed { value, headers }. 80 - pub async fn _proxy_post_json<T>( 81 - state: &AppState, 82 - mut req: Request, 83 - path: &str, 84 - ) -> Result<ProxiedResult<T>, StatusCode> 85 - where 86 - T: DeserializeOwned, 87 - { 88 - let uri = format!("{}{}", state.pds_base_url, path); 89 - *req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?; 90 - *req.method_mut() = Method::POST; 91 - 92 - let result = state 93 - .reverse_proxy_client 94 - .request(req) 95 - .await 96 - .map_err(|_| StatusCode::BAD_REQUEST)? 97 - .into_response(); 98 - 99 - if result.status() != StatusCode::OK { 100 - return Ok(ProxiedResult::Passthrough(result)); 101 - } 102 - 103 - let response_headers = result.headers().clone(); 104 - let body = result.into_body(); 105 - let body_bytes = to_bytes(body, usize::MAX) 106 - .await 107 - .map_err(|_| StatusCode::BAD_REQUEST)?; 108 - 109 - match serde_json::from_slice::<T>(&body_bytes) { 110 - Ok(value) => Ok(ProxiedResult::Parsed { 111 - value, 112 - _headers: response_headers, 113 - }), 114 - Err(err) => { 115 - error!(%err, "failed to parse proxied JSON response (POST); returning original body"); 116 - let mut builder = Response::builder().status(StatusCode::OK); 117 - if let Some(headers) = builder.headers_mut() { 118 - *headers = response_headers; 119 - } 120 - let resp = builder 121 - .body(Body::from(body_bytes)) 122 - .map_err(|_| StatusCode::BAD_REQUEST)?; 123 - Ok(ProxiedResult::Passthrough(resp)) 124 - } 125 - } 126 - } 127 - 128 - 129 - /// Build a JSON error response with the required Content-Type header 130 - /// Content-Type: application/json;charset=utf-8 131 - /// Body shape: { "error": string, "message": string } 132 - pub fn json_error_response( 133 - status: StatusCode, 134 - error: impl Into<String>, 135 - message: impl Into<String>, 136 - ) -> Result<Response<Body>, StatusCode> { 137 - let body_str = match serde_json::to_string(&serde_json::json!({ 138 - "error": error.into(), 139 - "message": message.into(), 140 - })) { 141 - Ok(s) => s, 142 - Err(_) => return Err(StatusCode::BAD_REQUEST), 143 - }; 144 - 145 - Response::builder() 146 - .status(status) 147 - .header(CONTENT_TYPE, "application/json;charset=utf-8") 148 - .body(Body::from(body_str)) 149 - .map_err(|_| StatusCode::BAD_REQUEST) 150 - }
-1
src/xrpc/mod.rs
··· 1 1 pub mod com_atproto_server; 2 - pub mod helpers;