Microservice to bring 2FA to self-hosted PDSes

Some more cleanups

+1 -1
src/helpers.rs
··· 3 3 use anyhow::anyhow; 4 4 use axum::body::{Body, to_bytes}; 5 5 use axum::extract::Request; 6 - use axum::http::header::{CONTENT_LENGTH, CONTENT_TYPE}; 6 + use axum::http::header::CONTENT_TYPE; 7 7 use axum::http::{HeaderMap, StatusCode, Uri}; 8 8 use axum::response::{IntoResponse, Response}; 9 9 use axum_template::TemplateEngine;
+27 -16
src/main.rs
··· 22 22 use tower_governor::governor::GovernorConfigBuilder; 23 23 use tower_http::compression::CompressionLayer; 24 24 use tower_http::cors::{Any, CorsLayer}; 25 - use tracing::error; 25 + use tracing::log; 26 26 use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 27 27 28 28 pub mod helpers; ··· 88 88 #[tokio::main] 89 89 async fn main() -> Result<(), Box<dyn std::error::Error>> { 90 90 setup_tracing(); 91 - //TODO may need to change where this reads from? Like an env variable for it's location? 91 + //TODO may need to change where this reads from? Like an env variable for it's location? Or arg? 92 92 dotenvy::from_path(Path::new("./pds.env"))?; 93 93 let pds_root = env::var("PDS_DATA_DIRECTORY")?; 94 94 let account_db_url = format!("{pds_root}/account.sqlite"); 95 95 96 96 let account_options = SqliteConnectOptions::new() 97 - .journal_mode(SqliteJournalMode::Wal) 98 - .filename(account_db_url); 97 + .filename(account_db_url) 98 + .busy_timeout(Duration::from_secs(5)); 99 99 100 100 let account_pool = SqlitePoolOptions::new() 101 101 .max_connections(5) ··· 106 106 let options = SqliteConnectOptions::new() 107 107 .journal_mode(SqliteJournalMode::Wal) 108 108 .filename(bells_db_url) 109 - .create_if_missing(true); 109 + .create_if_missing(true) 110 + .busy_timeout(Duration::from_secs(5)); 110 111 let pds_gatekeeper_pool = SqlitePoolOptions::new() 111 112 .max_connections(5) 112 113 .connect_with(options) ··· 135 136 //TODO add an override to manually load in the hbs templates 136 137 let _ = hbs.register_embed_templates::<EmailTemplates>(); 137 138 139 + let pds_base_url = 140 + env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); 141 + 138 142 let state = AppState { 139 143 account_pool, 140 144 pds_gatekeeper_pool, 141 145 reverse_proxy_client: client, 142 - //TODO should be env prob 143 - pds_base_url: "http://localhost:3000".to_string(), 146 + pds_base_url, 144 147 mailer, 145 148 mailer_from: sent_from, 146 149 
template_engine: Engine::from(hbs), ··· 148 151 149 152 // Rate limiting 150 153 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 151 - let governor_conf = GovernorConfigBuilder::default() 154 + let create_session_governor_conf = GovernorConfigBuilder::default() 155 + .per_second(60) 156 + .burst_size(5) 157 + .finish() 158 + .expect("failed to create governor config. this should not happen and is a bug"); 159 + 160 + // Create a second config with the same settings for the other endpoint 161 + let sign_in_governor_conf = GovernorConfigBuilder::default() 152 162 .per_second(60) 153 163 .burst_size(5) 154 164 .finish() 155 165 .expect("failed to create governor config. this should not happen and is a bug"); 156 166 157 - let governor_limiter = governor_conf.limiter().clone(); 167 + let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); 168 + let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 158 169 let interval = Duration::from_secs(60); 159 170 // a separate background task to clean up 160 171 std::thread::spawn(move || { 161 172 loop { 162 173 std::thread::sleep(interval); 163 - tracing::info!("rate limiting storage size: {}", governor_limiter.len()); 164 - governor_limiter.retain_recent(); 174 + create_session_governor_limiter.retain_recent(); 175 + sign_in_governor_limiter.retain_recent(); 165 176 } 166 177 }); 167 178 ··· 182 193 ) 183 194 .route( 184 195 "/@atproto/oauth-provider/~api/sign-in", 185 - post(sign_in), // .layer(GovernorLayer::new(governor_conf.clone()))), 196 + post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)), 186 197 ) 187 198 .route( 188 199 "/xrpc/com.atproto.server.createSession", 189 - post(create_session.layer(GovernorLayer::new(governor_conf))), 200 + post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), 190 201 ) 191 202 .layer(CompressionLayer::new()) 192 203 .layer(cors) 193 204 
.with_state(state); 194 205 195 - let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 196 - let port: u16 = env::var("PORT") 206 + let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 207 + let port: u16 = env::var("GATEKEEPER_PORT") 197 208 .ok() 198 209 .and_then(|s| s.parse().ok()) 199 210 .unwrap_or(8080); ··· 210 221 .with_graceful_shutdown(shutdown_signal()); 211 222 212 223 if let Err(err) = server.await { 213 - error!(error = %err, "server error"); 224 + log::error!("server error:{err}"); 214 225 } 215 226 216 227 Ok(())
+37 -45
src/oauth_provider.rs
··· 11 11 use serde::{Deserialize, Serialize}; 12 12 use tracing::log; 13 13 14 - fn is_disallowed_header(name: &HeaderName) -> bool { 15 - // RFC 7230 hop-by-hop headers and other problematic ones for proxying a new body 16 - // Use lowercase comparison; HeaderName equality is case-insensitive but we compare by string for a set 17 - match name.as_str() { 18 - // Hop-by-hop 19 - "connection" 20 - | "keep-alive" 21 - | "proxy-authenticate" 22 - | "proxy-authorization" 23 - | "te" 24 - | "trailer" 25 - | "transfer-encoding" 26 - | "upgrade" => true, 27 - // Payload or routing related we should not forward 28 - "host" | "content-length" | "content-encoding" | "expect" => true, 29 - // Compression negotiation can interfere; let upstream decide defaults 30 - // We can drop Accept-Encoding to avoid getting compressed payloads if not needed 31 - "accept-encoding" => true, 32 - _ => false, 33 - } 34 - } 35 - 36 - fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) { 37 - for (name, value) in src.iter() { 38 - if is_disallowed_header(name) { 39 - continue; 40 - } 41 - // Only copy valid headers 42 - if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) { 43 - dst.insert(name.clone(), hv); 44 - } 45 - } 46 - } 47 - 48 14 #[derive(Serialize, Deserialize, Clone)] 49 15 pub struct SignInRequest { 50 16 pub username: String, ··· 64 30 let password = payload.password.clone(); 65 31 let auth_factor_token = payload.email_otp.clone(); 66 32 67 - //TODO need to pass in a flag to ignore app passwords for Oauth 68 - // Run the shared pre-auth logic to validate and check 2FA requirement 69 33 match preauth_check(&state, &identifier, &password, auth_factor_token, true).await { 70 34 Ok(result) => match result { 71 35 AuthResult::WrongIdentityOrPassword => oauth_json_error_response( ··· 102 66 ); 103 67 104 68 let mut req = axum::http::Request::post(uri); 105 - // if let Some(cookie) = headers.get("Cookie") { 106 - // req = req.header("Cookie", cookie.to_str().unwrap()); 
107 - // } 108 69 if let Some(req_headers) = req.headers_mut() { 109 - // Copy headers but remove hop-by-hop and problematic ones 70 + // Copy headers but remove problematic ones. There was an issue with the PDS not parsing the body fully if i forwarded all headers 110 71 copy_filtered_headers(&headers, req_headers); 111 - // Ensure JSON content type is set explicitly 72 + //Setting the content type to application/json manually 112 73 req_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); 113 74 } 114 75 115 76 payload.email_otp = None; 116 - // let payload_bytes = 117 - // serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 118 - let body = serde_json::to_string(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 77 + let payload_bytes = 78 + serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 119 79 120 80 let req = req 121 - .body(Body::from(body)) 81 + .body(Body::from(payload_bytes)) 122 82 .map_err(|_| StatusCode::BAD_REQUEST)?; 123 83 124 84 let proxied = state ··· 155 115 } 156 116 } 157 117 } 118 + 119 + fn is_disallowed_header(name: &HeaderName) -> bool { 120 + // possible problematic headers with proxying 121 + matches!( 122 + name.as_str(), 123 + "connection" 124 + | "keep-alive" 125 + | "proxy-authenticate" 126 + | "proxy-authorization" 127 + | "te" 128 + | "trailer" 129 + | "transfer-encoding" 130 + | "upgrade" 131 + | "host" 132 + | "content-length" 133 + | "content-encoding" 134 + | "expect" 135 + | "accept-encoding" 136 + ) 137 + } 138 + 139 + fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) { 140 + for (name, value) in src.iter() { 141 + if is_disallowed_header(name) { 142 + continue; 143 + } 144 + // Only copy valid headers 145 + if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) { 146 + dst.insert(name.clone(), hv); 147 + } 148 + } 149 + }
+1 -2
src/xrpc/com_atproto_server.rs
··· 11 11 use serde::{Deserialize, Serialize}; 12 12 use serde_json; 13 13 use tracing::log; 14 - use tracing::log::log; 15 14 16 15 #[derive(Serialize, Deserialize, Debug, Clone)] 17 16 #[serde(rename_all = "camelCase")] ··· 65 64 pub async fn create_session( 66 65 State(state): State<AppState>, 67 66 headers: HeaderMap, 68 - Json(mut payload): extract::Json<CreateSessionRequest>, 67 + Json(payload): extract::Json<CreateSessionRequest>, 69 68 ) -> Result<Response<Body>, StatusCode> { 70 69 let identifier = payload.identifier.clone(); 71 70 let password = payload.password.clone();