Microservice that brings two-factor authentication (2FA) to self-hosted PDSes
1#![warn(clippy::unwrap_used)] 2use crate::oauth_provider::sign_in; 3use crate::xrpc::com_atproto_server::{create_session, get_session, update_email}; 4use axum::body::Body; 5use axum::handler::Handler; 6use axum::http::{Method, header}; 7use axum::middleware as ax_middleware; 8use axum::routing::post; 9use axum::{Router, routing::get}; 10use axum_template::engine::Engine; 11use handlebars::Handlebars; 12use hyper_util::client::legacy::connect::HttpConnector; 13use hyper_util::rt::TokioExecutor; 14use lettre::{AsyncSmtpTransport, Tokio1Executor}; 15use rust_embed::RustEmbed; 16use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; 17use sqlx::{SqlitePool, sqlite::SqlitePoolOptions}; 18use std::path::Path; 19use std::time::Duration; 20use std::{env, net::SocketAddr}; 21use tower_governor::GovernorLayer; 22use tower_governor::governor::GovernorConfigBuilder; 23use tower_http::compression::CompressionLayer; 24use tower_http::cors::{Any, CorsLayer}; 25use tracing::log; 26use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 27 28pub mod helpers; 29mod middleware; 30mod oauth_provider; 31mod xrpc; 32 33type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>; 34 35#[derive(RustEmbed)] 36#[folder = "email_templates"] 37#[include = "*.hbs"] 38struct EmailTemplates; 39 40#[derive(Clone)] 41pub struct AppState { 42 account_pool: SqlitePool, 43 pds_gatekeeper_pool: SqlitePool, 44 reverse_proxy_client: HyperUtilClient, 45 pds_base_url: String, 46 mailer: AsyncSmtpTransport<Tokio1Executor>, 47 mailer_from: String, 48 template_engine: Engine<Handlebars<'static>>, 49} 50 51async fn root_handler() -> impl axum::response::IntoResponse { 52 let body = r" 53 54 ...oO _.--X~~OO~~X--._ ...oOO 55 _.-~ / \ II / \ ~-._ 56 [].-~ \ / \||/ \ / ~-.[] ...o 57 ...o _ ||/ \ / || \ / \|| _ 58 (_) |X X || X X| (_) 59 _-~-_ ||\ / \ || / \ /|| _-~-_ 60 ||||| || \ / \ /||\ / \ / || ||||| 61 | |_|| \ / \ / || \ / \ / ||_| | 62 | |~|| X X || X X ||~| | 63==============| | 
|| / \ / \ || / \ / \ || | |============== 64______________| | || / \ / \||/ \ / \ || | |______________ 65 . . | | ||/ \ / || \ / \|| | | . . 66 / | | |X X || X X| | | / / 67 / . | | ||\ / \ || / \ /|| | | . / . 68. / | | || \ / \ /||\ / \ / || | | . . 69 . . | | || \ / \ / || \ / \ / || | | . 70 / | | || X X || X X || | | . / . / 71 / . | | || / \ / \ || / \ / \ || | | / 72 / | | || / \ / \||/ \ / \ || | | . / 73. . . | | ||/ \ / /||\ \ / \|| | | /. . 74 | |_|X X / II \ X X|_| | . . / 75==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |============== 76 "; 77 78 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; 79 80 let banner = format!(" {body}\n{intro}"); 81 82 ( 83 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], 84 banner, 85 ) 86} 87 88#[tokio::main] 89async fn main() -> Result<(), Box<dyn std::error::Error>> { 90 setup_tracing(); 91 let pds_env_location = 92 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string()); 93 94 dotenvy::from_path(Path::new(&pds_env_location))?; 95 let pds_root = env::var("PDS_DATA_DIRECTORY")?; 96 let account_db_url = format!("{pds_root}/account.sqlite"); 97 98 let account_options = SqliteConnectOptions::new() 99 .filename(account_db_url) 100 .busy_timeout(Duration::from_secs(5)); 101 102 let account_pool = SqlitePoolOptions::new() 103 .max_connections(5) 104 .connect_with(account_options) 105 .await?; 106 107 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); 108 let options = SqliteConnectOptions::new() 109 .journal_mode(SqliteJournalMode::Wal) 110 .filename(bells_db_url) 111 .create_if_missing(true) 112 .busy_timeout(Duration::from_secs(5)); 113 let pds_gatekeeper_pool = SqlitePoolOptions::new() 114 .max_connections(5) 115 .connect_with(options) 116 .await?; 117 118 // Run migrations for the extra database 119 // Note: the migrations are embedded at compile time from the given directory 120 // sqlx 121 
sqlx::migrate!("./migrations") 122 .run(&pds_gatekeeper_pool) 123 .await?; 124 125 let client: HyperUtilClient = 126 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) 127 .build(HttpConnector::new()); 128 129 //Emailer set up 130 let smtp_url = 131 env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file"); 132 let sent_from = env::var("PDS_EMAIL_FROM_ADDRESS") 133 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file"); 134 135 let mailer: AsyncSmtpTransport<Tokio1Executor> = 136 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build(); 137 //Email templates setup 138 let mut hbs = Handlebars::new(); 139 140 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); 141 if let Ok(users_email_directory) = users_email_directory { 142 hbs.register_template_file( 143 "two_factor_code.hbs", 144 format!("{users_email_directory}/two_factor_code.hbs"), 145 )?; 146 } else { 147 let _ = hbs.register_embed_templates::<EmailTemplates>(); 148 } 149 150 let pds_base_url = 151 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); 152 153 let state = AppState { 154 account_pool, 155 pds_gatekeeper_pool, 156 reverse_proxy_client: client, 157 pds_base_url, 158 mailer, 159 mailer_from: sent_from, 160 template_engine: Engine::from(hbs), 161 }; 162 163 // Rate limiting 164 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 165 let create_session_governor_conf = GovernorConfigBuilder::default() 166 .per_second(60) 167 .burst_size(5) 168 .finish() 169 .expect("failed to create governor config. this should not happen and is a bug"); 170 171 // Create a second config with the same settings for the other endpoint 172 let sign_in_governor_conf = GovernorConfigBuilder::default() 173 .per_second(60) 174 .burst_size(5) 175 .finish() 176 .expect("failed to create governor config. 
this should not happen and is a bug"); 177 178 let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); 179 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 180 let interval = Duration::from_secs(60); 181 // a separate background task to clean up 182 std::thread::spawn(move || { 183 loop { 184 std::thread::sleep(interval); 185 create_session_governor_limiter.retain_recent(); 186 sign_in_governor_limiter.retain_recent(); 187 } 188 }); 189 190 let cors = CorsLayer::new() 191 .allow_origin(Any) 192 .allow_methods([Method::GET, Method::OPTIONS, Method::POST]) 193 .allow_headers(Any); 194 195 let app = Router::new() 196 .route("/", get(root_handler)) 197 .route( 198 "/xrpc/com.atproto.server.getSession", 199 get(get_session).layer(ax_middleware::from_fn(middleware::extract_did)), 200 ) 201 .route( 202 "/xrpc/com.atproto.server.updateEmail", 203 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), 204 ) 205 .route( 206 "/@atproto/oauth-provider/~api/sign-in", 207 post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)), 208 ) 209 .route( 210 "/xrpc/com.atproto.server.createSession", 211 post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), 212 ) 213 .layer(CompressionLayer::new()) 214 .layer(cors) 215 .with_state(state); 216 217 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 218 let port: u16 = env::var("GATEKEEPER_PORT") 219 .ok() 220 .and_then(|s| s.parse().ok()) 221 .unwrap_or(8080); 222 let addr: SocketAddr = format!("{host}:{port}") 223 .parse() 224 .expect("valid socket address"); 225 226 let listener = tokio::net::TcpListener::bind(addr).await?; 227 228 let server = axum::serve( 229 listener, 230 app.into_make_service_with_connect_info::<SocketAddr>(), 231 ) 232 .with_graceful_shutdown(shutdown_signal()); 233 234 if let Err(err) = server.await { 235 log::error!("server error:{err}"); 236 } 237 238 Ok(()) 239} 240 241fn 
setup_tracing() { 242 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); 243 tracing_subscriber::registry() 244 .with(env_filter) 245 .with(fmt::layer()) 246 .init(); 247} 248 249async fn shutdown_signal() { 250 // Wait for Ctrl+C 251 let ctrl_c = async { 252 tokio::signal::ctrl_c() 253 .await 254 .expect("failed to install Ctrl+C handler"); 255 }; 256 257 #[cfg(unix)] 258 let terminate = async { 259 use tokio::signal::unix::{SignalKind, signal}; 260 261 let mut sigterm = 262 signal(SignalKind::terminate()).expect("failed to install signal handler"); 263 sigterm.recv().await; 264 }; 265 266 #[cfg(not(unix))] 267 let terminate = std::future::pending::<()>(); 268 269 tokio::select! { 270 _ = ctrl_c => {}, 271 _ = terminate => {}, 272 } 273}