Microservice to bring 2FA to self-hosted PDSes
1#![warn(clippy::unwrap_used)] 2use crate::oauth_provider::sign_in; 3use crate::xrpc::com_atproto_server::{create_session, get_session, update_email}; 4use axum::body::Body; 5use axum::handler::Handler; 6use axum::http::{Method, header}; 7use axum::middleware as ax_middleware; 8use axum::routing::post; 9use axum::{Router, routing::get}; 10use axum_template::engine::Engine; 11use handlebars::Handlebars; 12use hyper_util::client::legacy::connect::HttpConnector; 13use hyper_util::rt::TokioExecutor; 14use lettre::{AsyncSmtpTransport, Tokio1Executor}; 15use rust_embed::RustEmbed; 16use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; 17use sqlx::{SqlitePool, sqlite::SqlitePoolOptions}; 18use std::path::Path; 19use std::time::Duration; 20use std::{env, net::SocketAddr}; 21use tower_governor::GovernorLayer; 22use tower_governor::governor::GovernorConfigBuilder; 23use tower_http::compression::CompressionLayer; 24use tower_http::cors::{Any, CorsLayer}; 25use tracing::log; 26use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 27 28pub mod helpers; 29mod middleware; 30mod oauth_provider; 31mod xrpc; 32 33type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>; 34 35#[derive(RustEmbed)] 36#[folder = "email_templates"] 37#[include = "*.hbs"] 38struct EmailTemplates; 39 40#[derive(Clone)] 41pub struct AppState { 42 account_pool: SqlitePool, 43 pds_gatekeeper_pool: SqlitePool, 44 reverse_proxy_client: HyperUtilClient, 45 pds_base_url: String, 46 mailer: AsyncSmtpTransport<Tokio1Executor>, 47 mailer_from: String, 48 template_engine: Engine<Handlebars<'static>>, 49} 50 51async fn root_handler() -> impl axum::response::IntoResponse { 52 let body = r" 53 54 ...oO _.--X~~OO~~X--._ ...oOO 55 _.-~ / \ II / \ ~-._ 56 [].-~ \ / \||/ \ / ~-.[] ...o 57 ...o _ ||/ \ / || \ / \|| _ 58 (_) |X X || X X| (_) 59 _-~-_ ||\ / \ || / \ /|| _-~-_ 60 ||||| || \ / \ /||\ / \ / || ||||| 61 | |_|| \ / \ / || \ / \ / ||_| | 62 | |~|| X X || X X ||~| | 63==============| | 
|| / \ / \ || / \ / \ || | |============== 64______________| | || / \ / \||/ \ / \ || | |______________ 65 . . | | ||/ \ / || \ / \|| | | . . 66 / | | |X X || X X| | | / / 67 / . | | ||\ / \ || / \ /|| | | . / . 68. / | | || \ / \ /||\ / \ / || | | . . 69 . . | | || \ / \ / || \ / \ / || | | . 70 / | | || X X || X X || | | . / . / 71 / . | | || / \ / \ || / \ / \ || | | / 72 / | | || / \ / \||/ \ / \ || | | . / 73. . . | | ||/ \ / /||\ \ / \|| | | /. . 74 | |_|X X / II \ X X|_| | . . / 75==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |============== 76 "; 77 78 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; 79 80 let banner = format!(" {body}\n{intro}"); 81 82 ( 83 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], 84 banner, 85 ) 86} 87 88#[tokio::main] 89async fn main() -> Result<(), Box<dyn std::error::Error>> { 90 setup_tracing(); 91 //TODO may need to change where this reads from? Like an env variable for it's location? Or arg? 
92 dotenvy::from_path(Path::new("./pds.env"))?; 93 let pds_root = env::var("PDS_DATA_DIRECTORY")?; 94 let account_db_url = format!("{pds_root}/account.sqlite"); 95 96 let account_options = SqliteConnectOptions::new() 97 .filename(account_db_url) 98 .busy_timeout(Duration::from_secs(5)); 99 100 let account_pool = SqlitePoolOptions::new() 101 .max_connections(5) 102 .connect_with(account_options) 103 .await?; 104 105 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); 106 let options = SqliteConnectOptions::new() 107 .journal_mode(SqliteJournalMode::Wal) 108 .filename(bells_db_url) 109 .create_if_missing(true) 110 .busy_timeout(Duration::from_secs(5)); 111 let pds_gatekeeper_pool = SqlitePoolOptions::new() 112 .max_connections(5) 113 .connect_with(options) 114 .await?; 115 116 // Run migrations for the extra database 117 // Note: the migrations are embedded at compile time from the given directory 118 // sqlx 119 sqlx::migrate!("./migrations") 120 .run(&pds_gatekeeper_pool) 121 .await?; 122 123 let client: HyperUtilClient = 124 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) 125 .build(HttpConnector::new()); 126 127 //Emailer set up 128 let smtp_url = 129 env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file"); 130 let sent_from = env::var("PDS_EMAIL_FROM_ADDRESS") 131 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file"); 132 let mailer: AsyncSmtpTransport<Tokio1Executor> = 133 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build(); 134 //Email templates setup 135 let mut hbs = Handlebars::new(); 136 137 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); 138 if let Ok(users_email_directory) = users_email_directory { 139 hbs.register_template_file( 140 "two_factor_code.hbs", 141 format!("{users_email_directory}/two_factor_code.hbs"), 142 )?; 143 } else { 144 let _ = hbs.register_embed_templates::<EmailTemplates>(); 145 } 146 147 let 
pds_base_url = 148 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); 149 150 let state = AppState { 151 account_pool, 152 pds_gatekeeper_pool, 153 reverse_proxy_client: client, 154 pds_base_url, 155 mailer, 156 mailer_from: sent_from, 157 template_engine: Engine::from(hbs), 158 }; 159 160 // Rate limiting 161 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 162 let create_session_governor_conf = GovernorConfigBuilder::default() 163 .per_second(60) 164 .burst_size(5) 165 .finish() 166 .expect("failed to create governor config. this should not happen and is a bug"); 167 168 // Create a second config with the same settings for the other endpoint 169 let sign_in_governor_conf = GovernorConfigBuilder::default() 170 .per_second(60) 171 .burst_size(5) 172 .finish() 173 .expect("failed to create governor config. this should not happen and is a bug"); 174 175 let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); 176 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 177 let interval = Duration::from_secs(60); 178 // a separate background task to clean up 179 std::thread::spawn(move || { 180 loop { 181 std::thread::sleep(interval); 182 create_session_governor_limiter.retain_recent(); 183 sign_in_governor_limiter.retain_recent(); 184 } 185 }); 186 187 let cors = CorsLayer::new() 188 .allow_origin(Any) 189 .allow_methods([Method::GET, Method::OPTIONS, Method::POST]) 190 .allow_headers(Any); 191 192 let app = Router::new() 193 .route("/", get(root_handler)) 194 .route( 195 "/xrpc/com.atproto.server.getSession", 196 get(get_session).layer(ax_middleware::from_fn(middleware::extract_did)), 197 ) 198 .route( 199 "/xrpc/com.atproto.server.updateEmail", 200 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), 201 ) 202 .route( 203 "/@atproto/oauth-provider/~api/sign-in", 204 
post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)), 205 ) 206 .route( 207 "/xrpc/com.atproto.server.createSession", 208 post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), 209 ) 210 .layer(CompressionLayer::new()) 211 .layer(cors) 212 .with_state(state); 213 214 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 215 let port: u16 = env::var("GATEKEEPER_PORT") 216 .ok() 217 .and_then(|s| s.parse().ok()) 218 .unwrap_or(8080); 219 let addr: SocketAddr = format!("{host}:{port}") 220 .parse() 221 .expect("valid socket address"); 222 223 let listener = tokio::net::TcpListener::bind(addr).await?; 224 225 let server = axum::serve( 226 listener, 227 app.into_make_service_with_connect_info::<SocketAddr>(), 228 ) 229 .with_graceful_shutdown(shutdown_signal()); 230 231 if let Err(err) = server.await { 232 log::error!("server error:{err}"); 233 } 234 235 Ok(()) 236} 237 238fn setup_tracing() { 239 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); 240 tracing_subscriber::registry() 241 .with(env_filter) 242 .with(fmt::layer()) 243 .init(); 244} 245 246async fn shutdown_signal() { 247 // Wait for Ctrl+C 248 let ctrl_c = async { 249 tokio::signal::ctrl_c() 250 .await 251 .expect("failed to install Ctrl+C handler"); 252 }; 253 254 #[cfg(unix)] 255 let terminate = async { 256 use tokio::signal::unix::{SignalKind, signal}; 257 258 let mut sigterm = 259 signal(SignalKind::terminate()).expect("failed to install signal handler"); 260 sigterm.recv().await; 261 }; 262 263 #[cfg(not(unix))] 264 let terminate = std::future::pending::<()>(); 265 266 tokio::select! { 267 _ = ctrl_c => {}, 268 _ = terminate => {}, 269 } 270}