//! Microservice to bring 2FA to self-hosted PDSes.
// File listing metadata: branch `main`, 12 kB.
1#![warn(clippy::unwrap_used)] 2use crate::oauth_provider::sign_in; 3use crate::xrpc::com_atproto_server::{create_account, create_session, get_session, update_email}; 4use axum::body::Body; 5use axum::handler::Handler; 6use axum::http::{Method, header}; 7use axum::middleware as ax_middleware; 8use axum::routing::post; 9use axum::{Router, routing::get}; 10use axum_template::engine::Engine; 11use handlebars::Handlebars; 12use hyper_util::client::legacy::connect::HttpConnector; 13use hyper_util::rt::TokioExecutor; 14use lettre::{AsyncSmtpTransport, Tokio1Executor}; 15use rust_embed::RustEmbed; 16use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; 17use sqlx::{SqlitePool, sqlite::SqlitePoolOptions}; 18use std::path::Path; 19use std::time::Duration; 20use std::{env, net::SocketAddr}; 21use tower_governor::GovernorLayer; 22use tower_governor::governor::GovernorConfigBuilder; 23use tower_governor::key_extractor::SmartIpKeyExtractor; 24use tower_http::compression::CompressionLayer; 25use tower_http::cors::{Any, CorsLayer}; 26use tracing::log; 27use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 28 29pub mod helpers; 30mod middleware; 31mod oauth_provider; 32mod xrpc; 33 34type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>; 35 36#[derive(RustEmbed)] 37#[folder = "email_templates"] 38#[include = "*.hbs"] 39struct EmailTemplates; 40 41#[derive(Clone)] 42pub struct AppState { 43 account_pool: SqlitePool, 44 pds_gatekeeper_pool: SqlitePool, 45 reverse_proxy_client: HyperUtilClient, 46 pds_base_url: String, 47 mailer: AsyncSmtpTransport<Tokio1Executor>, 48 mailer_from: String, 49 template_engine: Engine<Handlebars<'static>>, 50} 51 52async fn root_handler() -> impl axum::response::IntoResponse { 53 let body = r" 54 55 ...oO _.--X~~OO~~X--._ ...oOO 56 _.-~ / \ II / \ ~-._ 57 [].-~ \ / \||/ \ / ~-.[] ...o 58 ...o _ ||/ \ / || \ / \|| _ 59 (_) |X X || X X| (_) 60 _-~-_ ||\ / \ || / \ /|| _-~-_ 61 ||||| || \ / \ /||\ / \ / || ||||| 62 | |_|| \ 
/ \ / || \ / \ / ||_| | 63 | |~|| X X || X X ||~| | 64==============| | || / \ / \ || / \ / \ || | |============== 65______________| | || / \ / \||/ \ / \ || | |______________ 66 . . | | ||/ \ / || \ / \|| | | . . 67 / | | |X X || X X| | | / / 68 / . | | ||\ / \ || / \ /|| | | . / . 69. / | | || \ / \ /||\ / \ / || | | . . 70 . . | | || \ / \ / || \ / \ / || | | . 71 / | | || X X || X X || | | . / . / 72 / . | | || / \ / \ || / \ / \ || | | / 73 / | | || / \ / \||/ \ / \ || | | . / 74. . . | | ||/ \ / /||\ \ / \|| | | /. . 75 | |_|X X / II \ X X|_| | . . / 76==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |============== 77 "; 78 79 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; 80 81 let banner = format!(" {body}\n{intro}"); 82 83 ( 84 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], 85 banner, 86 ) 87} 88 89#[tokio::main] 90async fn main() -> Result<(), Box<dyn std::error::Error>> { 91 setup_tracing(); 92 let pds_env_location = 93 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string()); 94 95 let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location)); 96 if let Err(e) = result_of_finding_pds_env { 97 log::error!( 98 "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}" 99 ); 100 } 101 102 let pds_root = 103 env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file"); 104 let account_db_url = format!("{pds_root}/account.sqlite"); 105 106 let account_options = SqliteConnectOptions::new() 107 .filename(account_db_url) 108 .busy_timeout(Duration::from_secs(5)); 109 110 let account_pool = SqlitePoolOptions::new() 111 .max_connections(5) 112 .connect_with(account_options) 113 .await?; 114 115 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); 116 let options = SqliteConnectOptions::new() 117 .journal_mode(SqliteJournalMode::Wal) 118 .filename(bells_db_url) 119 
.create_if_missing(true) 120 .busy_timeout(Duration::from_secs(5)); 121 let pds_gatekeeper_pool = SqlitePoolOptions::new() 122 .max_connections(5) 123 .connect_with(options) 124 .await?; 125 126 // Run migrations for the extra database 127 // Note: the migrations are embedded at compile time from the given directory 128 // sqlx 129 sqlx::migrate!("./migrations") 130 .run(&pds_gatekeeper_pool) 131 .await?; 132 133 let client: HyperUtilClient = 134 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) 135 .build(HttpConnector::new()); 136 137 //Emailer set up 138 let smtp_url = 139 env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file"); 140 let sent_from = env::var("PDS_EMAIL_FROM_ADDRESS") 141 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file"); 142 143 let mailer: AsyncSmtpTransport<Tokio1Executor> = 144 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build(); 145 //Email templates setup 146 let mut hbs = Handlebars::new(); 147 148 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); 149 if let Ok(users_email_directory) = users_email_directory { 150 hbs.register_template_file( 151 "two_factor_code.hbs", 152 format!("{users_email_directory}/two_factor_code.hbs"), 153 )?; 154 } else { 155 let _ = hbs.register_embed_templates::<EmailTemplates>(); 156 } 157 158 let pds_base_url = 159 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); 160 161 let state = AppState { 162 account_pool, 163 pds_gatekeeper_pool, 164 reverse_proxy_client: client, 165 pds_base_url, 166 mailer, 167 mailer_from: sent_from, 168 template_engine: Engine::from(hbs), 169 }; 170 171 // Rate limiting 172 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 
173 let create_session_governor_conf = GovernorConfigBuilder::default() 174 .per_second(60) 175 .burst_size(5) 176 .key_extractor(SmartIpKeyExtractor) 177 .finish() 178 .expect("failed to create governor config for create session. this should not happen and is a bug"); 179 180 // Create a second config with the same settings for the other endpoint 181 let sign_in_governor_conf = GovernorConfigBuilder::default() 182 .per_second(60) 183 .burst_size(5) 184 .key_extractor(SmartIpKeyExtractor) 185 .finish() 186 .expect( 187 "failed to create governor config for sign in. this should not happen and is a bug", 188 ); 189 190 let create_account_limiter_time: Option<String> = 191 env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok(); 192 let create_account_limiter_burst: Option<String> = 193 env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok(); 194 195 //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally 196 let mut create_account_governor_conf = GovernorConfigBuilder::default(); 197 if create_account_limiter_time.is_some() { 198 let time = create_account_limiter_time 199 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set") 200 .parse::<u64>() 201 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer"); 202 create_account_governor_conf.per_second(time); 203 } 204 205 if create_account_limiter_burst.is_some() { 206 let burst = create_account_limiter_burst 207 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set") 208 .parse::<u32>() 209 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer"); 210 create_account_governor_conf.burst_size(burst); 211 } 212 213 let create_account_governor_conf = create_account_governor_conf 214 .key_extractor(SmartIpKeyExtractor) 215 .finish().expect( 216 "failed to create governor config for create account. 
this should not happen and is a bug", 217 ); 218 219 let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); 220 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 221 let create_account_governor_limiter = create_account_governor_conf.limiter().clone(); 222 223 let interval = Duration::from_secs(60); 224 // a separate background task to clean up 225 std::thread::spawn(move || { 226 loop { 227 std::thread::sleep(interval); 228 create_session_governor_limiter.retain_recent(); 229 sign_in_governor_limiter.retain_recent(); 230 create_account_governor_limiter.retain_recent(); 231 } 232 }); 233 234 let cors = CorsLayer::new() 235 .allow_origin(Any) 236 .allow_methods([Method::GET, Method::OPTIONS, Method::POST]) 237 .allow_headers(Any); 238 239 let app = Router::new() 240 .route("/", get(root_handler)) 241 .route("/xrpc/com.atproto.server.getSession", get(get_session)) 242 .route( 243 "/xrpc/com.atproto.server.updateEmail", 244 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), 245 ) 246 .route( 247 "/@atproto/oauth-provider/~api/sign-in", 248 post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)), 249 ) 250 .route( 251 "/xrpc/com.atproto.server.createSession", 252 post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), 253 ) 254 .route( 255 "/xrpc/com.atproto.server.createAccount", 256 post(create_account).layer(GovernorLayer::new(create_account_governor_conf)), 257 ) 258 .layer(CompressionLayer::new()) 259 .layer(cors) 260 .with_state(state); 261 262 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 263 let port: u16 = env::var("GATEKEEPER_PORT") 264 .ok() 265 .and_then(|s| s.parse().ok()) 266 .unwrap_or(8080); 267 let addr: SocketAddr = format!("{host}:{port}") 268 .parse() 269 .expect("valid socket address"); 270 271 let listener = tokio::net::TcpListener::bind(addr).await?; 272 273 let server = axum::serve( 274 listener, 275 
app.into_make_service_with_connect_info::<SocketAddr>(), 276 ) 277 .with_graceful_shutdown(shutdown_signal()); 278 279 if let Err(err) = server.await { 280 log::error!("server error:{err}"); 281 } 282 283 Ok(()) 284} 285 286fn setup_tracing() { 287 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); 288 tracing_subscriber::registry() 289 .with(env_filter) 290 .with(fmt::layer()) 291 .init(); 292} 293 294async fn shutdown_signal() { 295 // Wait for Ctrl+C 296 let ctrl_c = async { 297 tokio::signal::ctrl_c() 298 .await 299 .expect("failed to install Ctrl+C handler"); 300 }; 301 302 #[cfg(unix)] 303 let terminate = async { 304 use tokio::signal::unix::{SignalKind, signal}; 305 306 let mut sigterm = 307 signal(SignalKind::terminate()).expect("failed to install signal handler"); 308 sigterm.recv().await; 309 }; 310 311 #[cfg(not(unix))] 312 let terminate = std::future::pending::<()>(); 313 314 tokio::select! { 315 _ = ctrl_c => {}, 316 _ = terminate => {}, 317 } 318}