back interdiff of round #2 and #1

2FA logins gatekept #1

merged
opened by baileytownsend.dev targeting main from feature/2faCodeGeneration
ERROR
migrations_bells_and_whistles/.keep

Failed to calculate interdiff for this file.

REBASED
Cargo.lock

This patch was likely rebased, as context lines do not match.

ERROR
Cargo.toml

Failed to calculate interdiff for this file.

REBASED
src/xrpc/helpers.rs

This patch was likely rebased, as context lines do not match.

REBASED
src/middleware.rs

This patch was likely rebased, as context lines do not match.

ERROR
src/xrpc/mod.rs

Failed to calculate interdiff for this file.

ERROR
README.md

Failed to calculate interdiff for this file.

ERROR
src/helpers.rs

Failed to calculate interdiff for this file.

ERROR
src/oauth_provider.rs

Failed to calculate interdiff for this file.

NEW
src/main.rs
··· 1 + #![warn(clippy::unwrap_used)] 2 + use crate::oauth_provider::sign_in; 1 3 use crate::xrpc::com_atproto_server::{create_session, get_session, update_email}; 2 - use axum::middleware as ax_middleware; 3 - mod middleware; 4 4 use axum::body::Body; 5 5 use axum::handler::Handler; 6 6 use axum::http::{Method, header}; 7 + use axum::middleware as ax_middleware; 7 8 use axum::routing::post; 8 9 use axum::{Router, routing::get}; 9 10 use axum_template::engine::Engine; ··· 21 22 use tower_governor::governor::GovernorConfigBuilder; 22 23 use tower_http::compression::CompressionLayer; 23 24 use tower_http::cors::{Any, CorsLayer}; 24 - use tracing::{error, log}; 25 + use tracing::log; 25 26 use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 26 27 28 + pub mod helpers; 29 + mod middleware; 30 + mod oauth_provider; 27 31 mod xrpc; 28 32 29 33 type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>; ··· 34 38 struct EmailTemplates; 35 39 36 40 #[derive(Clone)] 37 - struct AppState { 41 + pub struct AppState { 38 42 account_pool: SqlitePool, 39 43 pds_gatekeeper_pool: SqlitePool, 40 44 reverse_proxy_client: HyperUtilClient, ··· 73 77 74 78 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; 75 79 76 - let banner = format!(" {}\n{}", body, intro); 80 + let banner = format!(" {body}\n{intro}"); 77 81 78 82 ( 79 83 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], ··· 84 88 #[tokio::main] 85 89 async fn main() -> Result<(), Box<dyn std::error::Error>> { 86 90 setup_tracing(); 87 - //TODO prod 91 + //TODO may need to change where this reads from? Like an env variable for it's location? Or arg? 88 92 dotenvy::from_path(Path::new("./pds.env"))?; 89 93 let pds_root = env::var("PDS_DATA_DIRECTORY")?; 90 - // let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data"; 91 - let account_db_url = format!("{}/account.sqlite", pds_root); 92 - log::info!("accounts_db_url: {}", account_db_url); 94 + let account_db_url = format!("{pds_root}/account.sqlite"); 93 95 94 96 let account_options = SqliteConnectOptions::new() 95 - .journal_mode(SqliteJournalMode::Wal) 96 - .filename(account_db_url); 97 + .filename(account_db_url) 98 + .busy_timeout(Duration::from_secs(5)); 97 99 98 100 let account_pool = SqlitePoolOptions::new() 99 101 .max_connections(5) 100 102 .connect_with(account_options) 101 103 .await?; 102 104 103 - let bells_db_url = format!("{}/pds_gatekeeper.sqlite", pds_root); 105 + let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); 104 106 let options = SqliteConnectOptions::new() 105 107 .journal_mode(SqliteJournalMode::Wal) 106 108 .filename(bells_db_url) 107 - .create_if_missing(true); 109 + .create_if_missing(true) 110 + .busy_timeout(Duration::from_secs(5)); 108 111 let pds_gatekeeper_pool = SqlitePoolOptions::new() 109 112 .max_connections(5) 110 113 .connect_with(options) 111 114 .await?; 112 115 113 - // Run migrations for the bells_and_whistles database 116 + // Run migrations for the extra database 114 117 // Note: the migrations are embedded at compile time from the given directory 115 118 // sqlx 116 119 sqlx::migrate!("./migrations") ··· 130 133 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build(); 131 134 //Email templates setup 132 135 let mut hbs = Handlebars::new(); 133 - let _ = hbs.register_embed_templates::<EmailTemplates>(); 136 + 137 + let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); 138 + if let Ok(users_email_directory) = users_email_directory { 139 + hbs.register_template_file( 140 + "two_factor_code.hbs", 141 + format!("{users_email_directory}/two_factor_code.hbs"), 142 + )?; 143 + } else { 144 + let _ = hbs.register_embed_templates::<EmailTemplates>(); 145 + } 146 + 147 + let pds_base_url = 148 + env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); 134 149 135 150 let state = AppState { 136 151 account_pool, 137 152 pds_gatekeeper_pool, 138 153 reverse_proxy_client: client, 139 - //TODO should be env prob 140 - pds_base_url: "http://localhost:3000".to_string(), 154 + pds_base_url, 141 155 mailer, 142 156 mailer_from: sent_from, 143 157 template_engine: Engine::from(hbs), ··· 145 159 146 160 // Rate limiting 147 161 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 148 - let governor_conf = GovernorConfigBuilder::default() 162 + let create_session_governor_conf = GovernorConfigBuilder::default() 149 163 .per_second(60) 150 164 .burst_size(5) 151 165 .finish() 152 - .unwrap(); 153 - let governor_limiter = governor_conf.limiter().clone(); 166 + .expect("failed to create governor config. this should not happen and is a bug"); 167 + 168 + // Create a second config with the same settings for the other endpoint 169 + let sign_in_governor_conf = GovernorConfigBuilder::default() 170 + .per_second(60) 171 + .burst_size(5) 172 + .finish() 173 + .expect("failed to create governor config. this should not happen and is a bug"); 174 + 175 + let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); 176 + let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 154 177 let interval = Duration::from_secs(60); 155 178 // a separate background task to clean up 156 179 std::thread::spawn(move || { 157 180 loop { 158 181 std::thread::sleep(interval); 159 - tracing::info!("rate limiting storage size: {}", governor_limiter.len()); 160 - governor_limiter.retain_recent(); 182 + create_session_governor_limiter.retain_recent(); 183 + sign_in_governor_limiter.retain_recent(); 161 184 } 162 185 }); 163 186 ··· 176 199 "/xrpc/com.atproto.server.updateEmail", 177 200 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), 178 201 ) 202 + .route( 203 + "/@atproto/oauth-provider/~api/sign-in", 204 + post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)), 205 + ) 179 206 .route( 180 207 "/xrpc/com.atproto.server.createSession", 181 - post(create_session.layer(GovernorLayer::new(governor_conf))), 208 + post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), 182 209 ) 183 210 .layer(CompressionLayer::new()) 184 211 .layer(cors) 185 212 .with_state(state); 186 213 187 - let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 188 - let port: u16 = env::var("PORT") 214 + let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 215 + let port: u16 = env::var("GATEKEEPER_PORT") 189 216 .ok() 190 217 .and_then(|s| s.parse().ok()) 191 218 .unwrap_or(8080); ··· 202 229 .with_graceful_shutdown(shutdown_signal()); 203 230 204 231 if let Err(err) = server.await { 205 - error!(error = %err, "server error"); 232 + log::error!("server error:{err}"); 206 233 } 207 234 208 235 Ok(())
NEW
src/xrpc/com_atproto_server.rs
··· 1 1 use crate::AppState; 2 + use crate::helpers::{ 3 + AuthResult, ProxiedResult, TokenCheckError, json_error_response, preauth_check, proxy_get_json, 4 + }; 2 5 use crate::middleware::Did; 3 - use crate::xrpc::helpers::{ProxiedResult, json_error_response, proxy_get_json}; 4 6 use axum::body::Body; 5 7 use axum::extract::State; 6 8 use axum::http::{HeaderMap, StatusCode}; 7 9 use axum::response::{IntoResponse, Response}; 8 10 use axum::{Extension, Json, debug_handler, extract, extract::Request}; 9 - use axum_template::TemplateEngine; 10 - use lettre::message::{MultiPart, SinglePart, header}; 11 - use lettre::{AsyncTransport, Message}; 12 11 use serde::{Deserialize, Serialize}; 13 12 use serde_json; 14 - use serde_json::Value; 15 - use serde_json::value::Map; 16 13 use tracing::log; 17 14 18 15 #[derive(Serialize, Deserialize, Debug, Clone)] ··· 58 55 pub struct CreateSessionRequest { 59 56 identifier: String, 60 57 password: String, 61 - auth_factor_token: String, 62 - allow_takendown: bool, 63 - } 64 - 65 - pub enum AuthResult { 66 - WrongIdentityOrPassword, 67 - TwoFactorRequired, 68 - TwoFactorFailed, 69 - /// User does not have 2FA enabled, or passes it 70 - ProxyThrough, 71 - } 72 - 73 - pub enum IdentifierType { 74 - Email, 75 - DID, 76 - Handle, 77 - } 78 - 79 - impl IdentifierType { 80 - fn what_is_it(identifier: String) -> Self { 81 - if identifier.contains("@") { 82 - IdentifierType::Email 83 - } else if identifier.contains("did:") { 84 - IdentifierType::DID 85 - } else { 86 - IdentifierType::Handle 87 - } 88 - } 89 - } 90 - 91 - async fn verify_password(password: &str, password_scrypt: &str) -> Result<bool, StatusCode> { 92 - // Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes) 93 - let mut parts = password_scrypt.splitn(2, ':'); 94 - let salt = match parts.next() { 95 - Some(s) if !s.is_empty() => s, 96 - _ => return Ok(false), 97 - }; 98 - let stored_hash_hex = match parts.next() { 99 - Some(h) if !h.is_empty() => h, 100 - _ => return Ok(false), 101 - }; 102 - 103 - //Sets up scrypt to mimic node's scrypt 104 - let params = match scrypt::Params::new(14, 8, 1, 64) { 105 - Ok(p) => p, 106 - Err(_) => return Ok(false), 107 - }; 108 - let mut derived = [0u8; 64]; 109 - if scrypt::scrypt(password.as_bytes(), salt.as_bytes(), &params, &mut derived).is_err() { 110 - return Ok(false); 111 - } 112 - 113 - let stored_bytes = match hex::decode(stored_hash_hex) { 114 - Ok(b) => b, 115 - Err(e) => { 116 - log::error!("Error decoding stored hash: {}", e); 117 - return Ok(false); 118 - } 119 - }; 120 - 121 - Ok(derived.as_slice() == stored_bytes.as_slice()) 122 - } 123 - 124 - async fn preauth_check( 125 - state: &AppState, 126 - identifier: &str, 127 - password: &str, 128 - ) -> Result<AuthResult, StatusCode> { 129 - // Determine identifier type 130 - let id_type = IdentifierType::what_is_it(identifier.to_string()); 131 - 132 - // Query account DB for did and passwordScrypt based on identifier type 133 - let account_row: Option<(String, String, String)> = match id_type { 134 - IdentifierType::Email => sqlx::query_as::<_, (String, String, String)>( 135 - "SELECT did, passwordScrypt, account.email FROM account WHERE email = ? LIMIT 1", 136 - ) 137 - .bind(identifier) 138 - .fetch_optional(&state.account_pool) 139 - .await 140 - .map_err(|_| StatusCode::BAD_REQUEST)?, 141 - IdentifierType::Handle => sqlx::query_as::<_, (String, String, String)>( 142 - "SELECT account.did, account.passwordScrypt, account.email 143 - FROM actor 144 - LEFT JOIN account ON actor.did = account.did 145 - where actor.handle =? LIMIT 1", 146 - ) 147 - .bind(identifier) 148 - .fetch_optional(&state.account_pool) 149 - .await 150 - .map_err(|_| StatusCode::BAD_REQUEST)?, 151 - IdentifierType::DID => sqlx::query_as::<_, (String, String, String)>( 152 - "SELECT did, passwordScrypt, account.email FROM account WHERE did = ? LIMIT 1", 153 - ) 154 - .bind(identifier) 155 - .fetch_optional(&state.account_pool) 156 - .await 157 - .map_err(|_| StatusCode::BAD_REQUEST)?, 158 - }; 159 - 160 - if let Some((did, password_scrypt, email)) = account_row { 161 - // Check two-factor requirement for this DID in the gatekeeper DB 162 - let required_opt = sqlx::query_as::<_, (u8,)>( 163 - "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1", 164 - ) 165 - .bind(&did) 166 - .fetch_optional(&state.pds_gatekeeper_pool) 167 - .await 168 - .map_err(|_| StatusCode::BAD_REQUEST)?; 169 - 170 - let two_factor_required = match required_opt { 171 - Some(row) => row.0 != 0, 172 - None => false, 173 - }; 174 - 175 - if two_factor_required { 176 - // Verify password before proceeding to 2FA email step 177 - let verified = verify_password(password, &password_scrypt).await?; 178 - if !verified { 179 - return Ok(AuthResult::WrongIdentityOrPassword); 180 - } 181 - let mut email_data = Map::new(); 182 - //TODO these need real values 183 - let token = "test".to_string(); 184 - let handle = "baileytownsend.dev".to_string(); 185 - email_data.insert("token".to_string(), Value::from(token.clone())); 186 - email_data.insert("handle".to_string(), Value::from(handle.clone())); 187 - //TODO bad unwrap 188 - let email_body = state 189 - .template_engine 190 - .render("two_factor_code.hbs", email_data) 191 - .unwrap(); 192 - 193 - let email = Message::builder() 194 - //TODO prob get the proper type in the state 195 - .from(state.mailer_from.parse().unwrap()) 196 - .to(email.parse().unwrap()) 197 - .subject("Sign in to Bluesky") 198 - .multipart( 199 - MultiPart::alternative() // This is composed of two parts. 200 - .singlepart( 201 - SinglePart::builder() 202 - .header(header::ContentType::TEXT_PLAIN) 203 - .body(format!("We received a sign-in request for the account @{}. Use the code: {} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.", handle, token)), // Every message should have a plain text fallback. 204 - ) 205 - .singlepart( 206 - SinglePart::builder() 207 - .header(header::ContentType::TEXT_HTML) 208 - .body(email_body), 209 - ), 210 - ) 211 - //TODO bad 212 - .unwrap(); 213 - return match state.mailer.send(email).await { 214 - Ok(_) => Ok(AuthResult::TwoFactorRequired), 215 - Err(err) => { 216 - log::error!("Error sending the 2FA email: {}", err); 217 - Err(StatusCode::BAD_REQUEST) 218 - } 219 - }; 220 - } 221 - } 222 - 223 - // No local 2FA requirement (or account not found) 224 - Ok(AuthResult::ProxyThrough) 58 + #[serde(skip_serializing_if = "Option::is_none")] 59 + auth_factor_token: Option<String>, 60 + #[serde(skip_serializing_if = "Option::is_none")] 61 + allow_takendown: Option<bool>, 225 62 } 226 63 227 64 pub async fn create_session( ··· 231 68 ) -> Result<Response<Body>, StatusCode> { 232 69 let identifier = payload.identifier.clone(); 233 70 let password = payload.password.clone(); 71 + let auth_factor_token = payload.auth_factor_token.clone(); 234 72 235 73 // Run the shared pre-auth logic to validate and check 2FA requirement 236 - match preauth_check(&state, &identifier, &password).await? { 237 - AuthResult::WrongIdentityOrPassword => json_error_response( 238 - StatusCode::UNAUTHORIZED, 239 - "AuthenticationRequired", 240 - "Invalid identifier or password", 241 - ), 242 - AuthResult::TwoFactorRequired => { 243 - // Email sending step can be handled here if needed in the future. 244 - json_error_response( 74 + match preauth_check(&state, &identifier, &password, auth_factor_token, false).await { 75 + Ok(result) => match result { 76 + AuthResult::WrongIdentityOrPassword => json_error_response( 245 77 StatusCode::UNAUTHORIZED, 246 - "AuthFactorTokenRequired", 247 - "A sign in code has been sent to your email address", 248 - ) 249 - } 250 - AuthResult::TwoFactorFailed => { 251 - //Not sure what the errors are for this response is yet 252 - json_error_response(StatusCode::UNAUTHORIZED, "PLACEHOLDER", "PLACEHOLDER") 253 - } 254 - AuthResult::ProxyThrough => { 255 - //No 2FA or already passed 256 - let uri = format!( 257 - "{}{}", 258 - state.pds_base_url, "/xrpc/com.atproto.server.createSession" 259 - ); 260 - 261 - let mut req = axum::http::Request::post(uri); 262 - if let Some(req_headers) = req.headers_mut() { 263 - req_headers.extend(headers.clone()); 78 + "AuthenticationRequired", 79 + "Invalid identifier or password", 80 + ), 81 + AuthResult::TwoFactorRequired(_) => { 82 + // Email sending step can be handled here if needed in the future. 83 + json_error_response( 84 + StatusCode::UNAUTHORIZED, 85 + "AuthFactorTokenRequired", 86 + "A sign in code has been sent to your email address", 87 + ) 264 88 } 89 + AuthResult::ProxyThrough => { 90 + log::info!("Proxying through"); 91 + //No 2FA or already passed 92 + let uri = format!( 93 + "{}{}", 94 + state.pds_base_url, "/xrpc/com.atproto.server.createSession" 95 + ); 96 + 97 + let mut req = axum::http::Request::post(uri); 98 + if let Some(req_headers) = req.headers_mut() { 99 + req_headers.extend(headers.clone()); 100 + } 265 101 266 - let payload_bytes = 267 - serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 268 - let req = req 269 - .body(Body::from(payload_bytes)) 270 - .map_err(|_| StatusCode::BAD_REQUEST)?; 102 + let payload_bytes = 103 + serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 104 + let req = req 105 + .body(Body::from(payload_bytes)) 106 + .map_err(|_| StatusCode::BAD_REQUEST)?; 271 107 272 - let proxied = state 273 - .reverse_proxy_client 274 - .request(req) 275 - .await 276 - .map_err(|_| StatusCode::BAD_REQUEST)? 277 - .into_response(); 108 + let proxied = state 109 + .reverse_proxy_client 110 + .request(req) 111 + .await 112 + .map_err(|_| StatusCode::BAD_REQUEST)? 113 + .into_response(); 278 114 279 - Ok(proxied) 115 + Ok(proxied) 116 + } 117 + AuthResult::TokenCheckFailed(err) => match err { 118 + TokenCheckError::InvalidToken => { 119 + json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "Token is invalid") 120 + } 121 + TokenCheckError::ExpiredToken => { 122 + json_error_response(StatusCode::BAD_REQUEST, "ExpiredToken", "Token is expired") 123 + } 124 + }, 125 + }, 126 + Err(err) => { 127 + log::error!( 128 + "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}" 129 + ); 130 + json_error_response( 131 + StatusCode::INTERNAL_SERVER_ERROR, 132 + "InternalServerError", 133 + "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.", 134 + ) 280 135 } 281 136 } 282 137 } ··· 290 145 ) -> Result<Response<Body>, StatusCode> { 291 146 //If email auth is not set at all it is a update email address 292 147 let email_auth_not_set = payload.email_auth_factor.is_none(); 293 - //If email aurth is set it is to either turn on or off 2fa 148 + //If email auth is set it is to either turn on or off 2fa 294 149 let email_auth_update = payload.email_auth_factor.unwrap_or(false); 295 150 296 151 // Email update asked for ··· 350 205 } 351 206 } 352 207 353 - // Updating the acutal email address 208 + // Updating the actual email address by sending it on to the PDS 354 209 let uri = format!( 355 210 "{}{}", 356 211 state.pds_base_url, "/xrpc/com.atproto.server.updateEmail"