Microservice to bring 2FA to self-hosted PDSes
1#![warn(clippy::unwrap_used)] 2use crate::oauth_provider::sign_in; 3use crate::xrpc::com_atproto_server::{create_account, create_session, get_session, update_email}; 4use axum::body::Body; 5use axum::handler::Handler; 6use axum::http::{Method, header}; 7use axum::middleware as ax_middleware; 8use axum::routing::post; 9use axum::{Router, routing::get}; 10use axum_template::engine::Engine; 11use handlebars::Handlebars; 12use hyper_util::client::legacy::connect::HttpConnector; 13use hyper_util::rt::TokioExecutor; 14use lettre::{AsyncSmtpTransport, Tokio1Executor}; 15use rust_embed::RustEmbed; 16use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; 17use sqlx::{SqlitePool, sqlite::SqlitePoolOptions}; 18use std::path::Path; 19use std::time::Duration; 20use std::{env, net::SocketAddr}; 21use tower_governor::GovernorLayer; 22use tower_governor::governor::{GovernorConfig, GovernorConfigBuilder}; 23use tower_http::compression::CompressionLayer; 24use tower_http::cors::{Any, CorsLayer}; 25use tracing::log; 26use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 27 28pub mod helpers; 29mod middleware; 30mod oauth_provider; 31mod xrpc; 32 33type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>; 34 35#[derive(RustEmbed)] 36#[folder = "email_templates"] 37#[include = "*.hbs"] 38struct EmailTemplates; 39 40#[derive(Clone)] 41pub struct AppState { 42 account_pool: SqlitePool, 43 pds_gatekeeper_pool: SqlitePool, 44 reverse_proxy_client: HyperUtilClient, 45 pds_base_url: String, 46 mailer: AsyncSmtpTransport<Tokio1Executor>, 47 mailer_from: String, 48 template_engine: Engine<Handlebars<'static>>, 49} 50 51async fn root_handler() -> impl axum::response::IntoResponse { 52 let body = r" 53 54 ...oO _.--X~~OO~~X--._ ...oOO 55 _.-~ / \ II / \ ~-._ 56 [].-~ \ / \||/ \ / ~-.[] ...o 57 ...o _ ||/ \ / || \ / \|| _ 58 (_) |X X || X X| (_) 59 _-~-_ ||\ / \ || / \ /|| _-~-_ 60 ||||| || \ / \ /||\ / \ / || ||||| 61 | |_|| \ / \ / || \ / \ / ||_| | 62 | |~|| X X 
|| X X ||~| | 63==============| | || / \ / \ || / \ / \ || | |============== 64______________| | || / \ / \||/ \ / \ || | |______________ 65 . . | | ||/ \ / || \ / \|| | | . . 66 / | | |X X || X X| | | / / 67 / . | | ||\ / \ || / \ /|| | | . / . 68. / | | || \ / \ /||\ / \ / || | | . . 69 . . | | || \ / \ / || \ / \ / || | | . 70 / | | || X X || X X || | | . / . / 71 / . | | || / \ / \ || / \ / \ || | | / 72 / | | || / \ / \||/ \ / \ || | | . / 73. . . | | ||/ \ / /||\ \ / \|| | | /. . 74 | |_|X X / II \ X X|_| | . . / 75==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |============== 76 "; 77 78 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; 79 80 let banner = format!(" {body}\n{intro}"); 81 82 ( 83 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], 84 banner, 85 ) 86} 87 88#[tokio::main] 89async fn main() -> Result<(), Box<dyn std::error::Error>> { 90 setup_tracing(); 91 let pds_env_location = 92 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string()); 93 94 let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location)); 95 if let Err(e) = result_of_finding_pds_env { 96 log::error!( 97 "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}" 98 ); 99 } 100 let pds_root = env::var("PDS_DATA_DIRECTORY")?; 101 let account_db_url = format!("{pds_root}/account.sqlite"); 102 103 let account_options = SqliteConnectOptions::new() 104 .filename(account_db_url) 105 .busy_timeout(Duration::from_secs(5)); 106 107 let account_pool = SqlitePoolOptions::new() 108 .max_connections(5) 109 .connect_with(account_options) 110 .await?; 111 112 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); 113 let options = SqliteConnectOptions::new() 114 .journal_mode(SqliteJournalMode::Wal) 115 .filename(bells_db_url) 116 .create_if_missing(true) 117 .busy_timeout(Duration::from_secs(5)); 118 let pds_gatekeeper_pool = 
SqlitePoolOptions::new() 119 .max_connections(5) 120 .connect_with(options) 121 .await?; 122 123 // Run migrations for the extra database 124 // Note: the migrations are embedded at compile time from the given directory 125 // sqlx 126 sqlx::migrate!("./migrations") 127 .run(&pds_gatekeeper_pool) 128 .await?; 129 130 let client: HyperUtilClient = 131 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) 132 .build(HttpConnector::new()); 133 134 //Emailer set up 135 let smtp_url = 136 env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file"); 137 let sent_from = env::var("PDS_EMAIL_FROM_ADDRESS") 138 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file"); 139 140 let mailer: AsyncSmtpTransport<Tokio1Executor> = 141 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build(); 142 //Email templates setup 143 let mut hbs = Handlebars::new(); 144 145 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); 146 if let Ok(users_email_directory) = users_email_directory { 147 hbs.register_template_file( 148 "two_factor_code.hbs", 149 format!("{users_email_directory}/two_factor_code.hbs"), 150 )?; 151 } else { 152 let _ = hbs.register_embed_templates::<EmailTemplates>(); 153 } 154 155 let pds_base_url = 156 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); 157 158 let state = AppState { 159 account_pool, 160 pds_gatekeeper_pool, 161 reverse_proxy_client: client, 162 pds_base_url, 163 mailer, 164 mailer_from: sent_from, 165 template_engine: Engine::from(hbs), 166 }; 167 168 // Rate limiting 169 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 170 let create_session_governor_conf = GovernorConfigBuilder::default() 171 .per_second(60) 172 .burst_size(5) 173 .finish() 174 .expect("failed to create governor config for create session. 
this should not happen and is a bug"); 175 176 // Create a second config with the same settings for the other endpoint 177 let sign_in_governor_conf = GovernorConfigBuilder::default() 178 .per_second(60) 179 .burst_size(5) 180 .finish() 181 .expect( 182 "failed to create governor config for sign in. this should not happen and is a bug", 183 ); 184 185 let create_account_limiter_time: Option<String> = 186 env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok(); 187 let create_account_limiter_burst: Option<String> = 188 env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok(); 189 190 //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally 191 let mut create_account_governor_conf = GovernorConfigBuilder::default(); 192 if create_account_limiter_time.is_some() { 193 let time = create_account_limiter_time 194 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set") 195 .parse::<u64>() 196 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer"); 197 create_account_governor_conf.per_second(time); 198 } 199 200 if create_account_limiter_burst.is_some() { 201 let burst = create_account_limiter_burst 202 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set") 203 .parse::<u32>() 204 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer"); 205 create_account_governor_conf.burst_size(burst); 206 } 207 208 let create_account_governor_conf = create_account_governor_conf.finish().expect( 209 "failed to create governor config for create account. 
this should not happen and is a bug", 210 ); 211 212 let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); 213 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 214 let create_account_governor_limiter = create_account_governor_conf.limiter().clone(); 215 216 let interval = Duration::from_secs(60); 217 // a separate background task to clean up 218 std::thread::spawn(move || { 219 loop { 220 std::thread::sleep(interval); 221 create_session_governor_limiter.retain_recent(); 222 sign_in_governor_limiter.retain_recent(); 223 create_account_governor_limiter.retain_recent(); 224 } 225 }); 226 227 let cors = CorsLayer::new() 228 .allow_origin(Any) 229 .allow_methods([Method::GET, Method::OPTIONS, Method::POST]) 230 .allow_headers(Any); 231 232 let app = Router::new() 233 .route("/", get(root_handler)) 234 .route("/xrpc/com.atproto.server.getSession", get(get_session)) 235 .route( 236 "/xrpc/com.atproto.server.updateEmail", 237 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), 238 ) 239 .route( 240 "/@atproto/oauth-provider/~api/sign-in", 241 post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)), 242 ) 243 .route( 244 "/xrpc/com.atproto.server.createSession", 245 post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), 246 ) 247 .route( 248 "/xrpc/com.atproto.server.createAccount", 249 post(create_account).layer(GovernorLayer::new(create_account_governor_conf)), 250 ) 251 .layer(CompressionLayer::new()) 252 .layer(cors) 253 .with_state(state); 254 255 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 256 let port: u16 = env::var("GATEKEEPER_PORT") 257 .ok() 258 .and_then(|s| s.parse().ok()) 259 .unwrap_or(8080); 260 let addr: SocketAddr = format!("{host}:{port}") 261 .parse() 262 .expect("valid socket address"); 263 264 let listener = tokio::net::TcpListener::bind(addr).await?; 265 266 let server = axum::serve( 267 listener, 268 
app.into_make_service_with_connect_info::<SocketAddr>(), 269 ) 270 .with_graceful_shutdown(shutdown_signal()); 271 272 if let Err(err) = server.await { 273 log::error!("server error:{err}"); 274 } 275 276 Ok(()) 277} 278 279fn setup_tracing() { 280 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); 281 tracing_subscriber::registry() 282 .with(env_filter) 283 .with(fmt::layer()) 284 .init(); 285} 286 287async fn shutdown_signal() { 288 // Wait for Ctrl+C 289 let ctrl_c = async { 290 tokio::signal::ctrl_c() 291 .await 292 .expect("failed to install Ctrl+C handler"); 293 }; 294 295 #[cfg(unix)] 296 let terminate = async { 297 use tokio::signal::unix::{SignalKind, signal}; 298 299 let mut sigterm = 300 signal(SignalKind::terminate()).expect("failed to install signal handler"); 301 sigterm.recv().await; 302 }; 303 304 #[cfg(not(unix))] 305 let terminate = std::future::pending::<()>(); 306 307 tokio::select! { 308 _ = ctrl_c => {}, 309 _ = terminate => {}, 310 } 311}