Microservice to bring 2FA to self hosted PDSes
1#![warn(clippy::unwrap_used)] 2use crate::oauth_provider::sign_in; 3use crate::xrpc::com_atproto_server::{create_account, create_session, get_session, update_email}; 4use axum::body::Body; 5use axum::handler::Handler; 6use axum::http::{Method, header}; 7use axum::middleware as ax_middleware; 8use axum::routing::post; 9use axum::{Router, routing::get}; 10use axum_template::engine::Engine; 11use handlebars::Handlebars; 12use hyper_util::client::legacy::connect::HttpConnector; 13use hyper_util::rt::TokioExecutor; 14use lettre::{AsyncSmtpTransport, Tokio1Executor}; 15use rust_embed::RustEmbed; 16use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; 17use sqlx::{SqlitePool, sqlite::SqlitePoolOptions}; 18use std::path::Path; 19use std::time::Duration; 20use std::{env, net::SocketAddr}; 21use tower_governor::GovernorLayer; 22use tower_governor::governor::GovernorConfigBuilder; 23use tower_http::compression::CompressionLayer; 24use tower_http::cors::{Any, CorsLayer}; 25use tracing::log; 26use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 27 28pub mod helpers; 29mod middleware; 30mod oauth_provider; 31mod xrpc; 32 33type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>; 34 35#[derive(RustEmbed)] 36#[folder = "email_templates"] 37#[include = "*.hbs"] 38struct EmailTemplates; 39 40#[derive(Clone)] 41pub struct AppState { 42 account_pool: SqlitePool, 43 pds_gatekeeper_pool: SqlitePool, 44 reverse_proxy_client: HyperUtilClient, 45 pds_base_url: String, 46 mailer: AsyncSmtpTransport<Tokio1Executor>, 47 mailer_from: String, 48 template_engine: Engine<Handlebars<'static>>, 49} 50 51async fn root_handler() -> impl axum::response::IntoResponse { 52 let body = r" 53 54 ...oO _.--X~~OO~~X--._ ...oOO 55 _.-~ / \ II / \ ~-._ 56 [].-~ \ / \||/ \ / ~-.[] ...o 57 ...o _ ||/ \ / || \ / \|| _ 58 (_) |X X || X X| (_) 59 _-~-_ ||\ / \ || / \ /|| _-~-_ 60 ||||| || \ / \ /||\ / \ / || ||||| 61 | |_|| \ / \ / || \ / \ / ||_| | 62 | |~|| X X || X X ||~| | 63==============| | || / \ / \ || / \ / \ || | |============== 64______________| | || / \ / \||/ \ / \ || | |______________ 65 . . | | ||/ \ / || \ / \|| | | . . 66 / | | |X X || X X| | | / / 67 / . | | ||\ / \ || / \ /|| | | . / . 68. / | | || \ / \ /||\ / \ / || | | . . 69 . . | | || \ / \ / || \ / \ / || | | . 70 / | | || X X || X X || | | . / . / 71 / . | | || / \ / \ || / \ / \ || | | / 72 / | | || / \ / \||/ \ / \ || | | . / 73. . . | | ||/ \ / /||\ \ / \|| | | /. . 74 | |_|X X / II \ X X|_| | . . / 75==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |============== 76 "; 77 78 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; 79 80 let banner = format!(" {body}\n{intro}"); 81 82 ( 83 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], 84 banner, 85 ) 86} 87 88#[tokio::main] 89async fn main() -> Result<(), Box<dyn std::error::Error>> { 90 setup_tracing(); 91 let pds_env_location = 92 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string()); 93 94 let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location)); 95 if let Err(e) = result_of_finding_pds_env { 96 log::error!( 97 "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}" 98 ); 99 } 100 101 let pds_root = 102 env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file"); 103 let account_db_url = format!("{pds_root}/account.sqlite"); 104 105 let account_options = SqliteConnectOptions::new() 106 .filename(account_db_url) 107 .busy_timeout(Duration::from_secs(5)); 108 109 let account_pool = SqlitePoolOptions::new() 110 .max_connections(5) 111 .connect_with(account_options) 112 .await?; 113 114 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); 115 let options = SqliteConnectOptions::new() 116 .journal_mode(SqliteJournalMode::Wal) 117 .filename(bells_db_url) 118 .create_if_missing(true) 119 .busy_timeout(Duration::from_secs(5)); 120 let pds_gatekeeper_pool = SqlitePoolOptions::new() 121 .max_connections(5) 122 .connect_with(options) 123 .await?; 124 125 // Run migrations for the extra database 126 // Note: the migrations are embedded at compile time from the given directory 127 // sqlx 128 sqlx::migrate!("./migrations") 129 .run(&pds_gatekeeper_pool) 130 .await?; 131 132 let client: HyperUtilClient = 133 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) 134 .build(HttpConnector::new()); 135 136 //Emailer set up 137 let smtp_url = 138 env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file"); 139 let sent_from = env::var("PDS_EMAIL_FROM_ADDRESS") 140 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file"); 141 142 let mailer: AsyncSmtpTransport<Tokio1Executor> = 143 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build(); 144 //Email templates setup 145 let mut hbs = Handlebars::new(); 146 147 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); 148 if let Ok(users_email_directory) = users_email_directory { 149 hbs.register_template_file( 150 "two_factor_code.hbs", 151 format!("{users_email_directory}/two_factor_code.hbs"), 152 )?; 153 } else { 154 let _ = hbs.register_embed_templates::<EmailTemplates>(); 155 } 156 157 let pds_base_url = 158 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); 159 160 let state = AppState { 161 account_pool, 162 pds_gatekeeper_pool, 163 reverse_proxy_client: client, 164 pds_base_url, 165 mailer, 166 mailer_from: sent_from, 167 template_engine: Engine::from(hbs), 168 }; 169 170 // Rate limiting 171 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 172 let create_session_governor_conf = GovernorConfigBuilder::default() 173 .per_second(60) 174 .burst_size(5) 175 .finish() 176 .expect("failed to create governor config for create session. this should not happen and is a bug"); 177 178 // Create a second config with the same settings for the other endpoint 179 let sign_in_governor_conf = GovernorConfigBuilder::default() 180 .per_second(60) 181 .burst_size(5) 182 .finish() 183 .expect( 184 "failed to create governor config for sign in. this should not happen and is a bug", 185 ); 186 187 let create_account_limiter_time: Option<String> = 188 env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok(); 189 let create_account_limiter_burst: Option<String> = 190 env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok(); 191 192 //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally 193 let mut create_account_governor_conf = GovernorConfigBuilder::default(); 194 if create_account_limiter_time.is_some() { 195 let time = create_account_limiter_time 196 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set") 197 .parse::<u64>() 198 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer"); 199 create_account_governor_conf.per_second(time); 200 } 201 202 if create_account_limiter_burst.is_some() { 203 let burst = create_account_limiter_burst 204 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set") 205 .parse::<u32>() 206 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer"); 207 create_account_governor_conf.burst_size(burst); 208 } 209 210 let create_account_governor_conf = create_account_governor_conf.finish().expect( 211 "failed to create governor config for create account. this should not happen and is a bug", 212 ); 213 214 let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); 215 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 216 let create_account_governor_limiter = create_account_governor_conf.limiter().clone(); 217 218 let interval = Duration::from_secs(60); 219 // a separate background task to clean up 220 std::thread::spawn(move || { 221 loop { 222 std::thread::sleep(interval); 223 create_session_governor_limiter.retain_recent(); 224 sign_in_governor_limiter.retain_recent(); 225 create_account_governor_limiter.retain_recent(); 226 } 227 }); 228 229 let cors = CorsLayer::new() 230 .allow_origin(Any) 231 .allow_methods([Method::GET, Method::OPTIONS, Method::POST]) 232 .allow_headers(Any); 233 234 let app = Router::new() 235 .route("/", get(root_handler)) 236 .route("/xrpc/com.atproto.server.getSession", get(get_session)) 237 .route( 238 "/xrpc/com.atproto.server.updateEmail", 239 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), 240 ) 241 .route( 242 "/@atproto/oauth-provider/~api/sign-in", 243 post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)), 244 ) 245 .route( 246 "/xrpc/com.atproto.server.createSession", 247 post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), 248 ) 249 .route( 250 "/xrpc/com.atproto.server.createAccount", 251 post(create_account).layer(GovernorLayer::new(create_account_governor_conf)), 252 ) 253 .layer(CompressionLayer::new()) 254 .layer(cors) 255 .with_state(state); 256 257 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 258 let port: u16 = env::var("GATEKEEPER_PORT") 259 .ok() 260 .and_then(|s| s.parse().ok()) 261 .unwrap_or(8080); 262 let addr: SocketAddr = format!("{host}:{port}") 263 .parse() 264 .expect("valid socket address"); 265 266 let listener = tokio::net::TcpListener::bind(addr).await?; 267 268 let server = axum::serve( 269 listener, 270 app.into_make_service_with_connect_info::<SocketAddr>(), 271 ) 272 .with_graceful_shutdown(shutdown_signal()); 273 274 if let Err(err) = server.await { 275 log::error!("server error:{err}"); 276 } 277 278 Ok(()) 279} 280 281fn setup_tracing() { 282 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); 283 tracing_subscriber::registry() 284 .with(env_filter) 285 .with(fmt::layer()) 286 .init(); 287} 288 289async fn shutdown_signal() { 290 // Wait for Ctrl+C 291 let ctrl_c = async { 292 tokio::signal::ctrl_c() 293 .await 294 .expect("failed to install Ctrl+C handler"); 295 }; 296 297 #[cfg(unix)] 298 let terminate = async { 299 use tokio::signal::unix::{SignalKind, signal}; 300 301 let mut sigterm = 302 signal(SignalKind::terminate()).expect("failed to install signal handler"); 303 sigterm.recv().await; 304 }; 305 306 #[cfg(not(unix))] 307 let terminate = std::future::pending::<()>(); 308 309 tokio::select! { 310 _ = ctrl_c => {}, 311 _ = terminate => {}, 312 } 313}