Microservice to bring 2FA to self-hosted PDSes
at main 16 kB view raw
1#![warn(clippy::unwrap_used)] 2use crate::gate::{get_gate, post_gate}; 3use crate::oauth_provider::sign_in; 4use crate::xrpc::com_atproto_server::{ 5 create_account, create_session, describe_server, get_session, update_email, 6}; 7use axum::{ 8 Router, 9 body::Body, 10 handler::Handler, 11 http::{Method, header}, 12 middleware as ax_middleware, 13 routing::get, 14 routing::post, 15}; 16use axum_template::engine::Engine; 17use handlebars::Handlebars; 18use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; 19use jacquard_common::types::did::Did; 20use jacquard_identity::{PublicResolver, resolver::PlcSource}; 21use lettre::{AsyncSmtpTransport, Tokio1Executor}; 22use rand::Rng; 23use rust_embed::RustEmbed; 24use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; 25use sqlx::{SqlitePool, sqlite::SqlitePoolOptions}; 26use std::path::Path; 27use std::sync::Arc; 28use std::time::Duration; 29use std::{env, net::SocketAddr}; 30use tower_governor::{ 31 GovernorLayer, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, 32}; 33use tower_http::{ 34 compression::CompressionLayer, 35 cors::{Any, CorsLayer}, 36}; 37use tracing::log; 38use tracing_subscriber::{EnvFilter, fmt, prelude::*}; 39 40mod gate; 41pub mod helpers; 42mod middleware; 43mod oauth_provider; 44mod xrpc; 45 46type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>; 47 48#[derive(RustEmbed)] 49#[folder = "email_templates"] 50#[include = "*.hbs"] 51struct EmailTemplates; 52 53#[derive(RustEmbed)] 54#[folder = "html_templates"] 55#[include = "*.hbs"] 56struct HtmlTemplates; 57 58/// Mostly the env variables that are used in the app 59#[derive(Clone, Debug)] 60pub struct AppConfig { 61 pds_base_url: String, 62 mailer_from: String, 63 email_subject: String, 64 allow_only_migrations: bool, 65 use_captcha: bool, 66 //The url to redirect to after a successful captcha. 
Defaults to https://bsky.app, but you may have another social-app fork you rather your users use 67 //that need to capture this redirect url for creating an account 68 default_successful_redirect_url: String, 69 pds_service_did: Did<'static>, 70 gate_jwe_key: Vec<u8>, 71 captcha_success_redirects: Vec<String>, 72} 73 74impl AppConfig { 75 pub fn new() -> Self { 76 let pds_base_url = 77 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); 78 let mailer_from = env::var("PDS_EMAIL_FROM_ADDRESS") 79 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file"); 80 //Hack not my favorite, but it does work 81 let allow_only_migrations = env::var("GATEKEEPER_ALLOW_ONLY_MIGRATIONS") 82 .map(|val| val.parse::<bool>().unwrap_or(false)) 83 .unwrap_or(false); 84 85 let use_captcha = env::var("GATEKEEPER_CREATE_ACCOUNT_CAPTCHA") 86 .map(|val| val.parse::<bool>().unwrap_or(false)) 87 .unwrap_or(false); 88 89 // PDS_SERVICE_DID is the did:web if set, if not it's PDS_HOSTNAME 90 let pds_service_did = 91 env::var("PDS_SERVICE_DID").unwrap_or_else(|_| match env::var("PDS_HOSTNAME") { 92 Ok(pds_hostname) => format!("did:web:{}", pds_hostname), 93 Err(_) => { 94 panic!("PDS_HOSTNAME or PDS_SERVICE_DID must be set in your pds.env file") 95 } 96 }); 97 98 let email_subject = env::var("GATEKEEPER_TWO_FACTOR_EMAIL_SUBJECT") 99 .unwrap_or("Sign in to Bluesky".to_string()); 100 101 // Load or generate JWE encryption key (32 bytes for AES-256) 102 let gate_jwe_key = env::var("GATEKEEPER_JWE_KEY") 103 .ok() 104 .and_then(|key_hex| hex::decode(key_hex).ok()) 105 .unwrap_or_else(|| { 106 // Generate a random 32-byte key if not provided 107 let key: Vec<u8> = (0..32).map(|_| rand::rng().random()).collect(); 108 log::warn!("WARNING: No GATEKEEPER_JWE_KEY found in the environment. Generated random key (hex): {}", hex::encode(&key)); 109 log::warn!("This is not strictly needed unless you scale PDS Gatekeeper. 
Will not also be able to verify tokens between reboots, but they are short lived (5mins)."); 110 key 111 }); 112 113 if gate_jwe_key.len() != 32 { 114 panic!( 115 "GATEKEEPER_JWE_KEY must be 32 bytes (64 hex characters) for AES-256 encryption" 116 ); 117 } 118 119 let captcha_success_redirects = match env::var("GATEKEEPER_CAPTCHA_SUCCESS_REDIRECTS") { 120 Ok(from_env) => from_env.split(",").map(|s| s.trim().to_string()).collect(), 121 Err(_) => { 122 vec![ 123 String::from("https://bsky.app"), 124 String::from("https://pdsmoover.com"), 125 String::from("https://blacksky.community"), 126 String::from("https://tektite.cc"), 127 ] 128 } 129 }; 130 131 AppConfig { 132 pds_base_url, 133 mailer_from, 134 email_subject, 135 allow_only_migrations, 136 use_captcha, 137 default_successful_redirect_url: env::var("GATEKEEPER_DEFAULT_CAPTCHA_REDIRECT") 138 .unwrap_or("https://bsky.app".to_string()), 139 pds_service_did: pds_service_did 140 .parse() 141 .expect("PDS_SERVICE_DID is not a valid did or could not infer from PDS_HOSTNAME"), 142 gate_jwe_key, 143 captcha_success_redirects, 144 } 145 } 146} 147 148#[derive(Clone)] 149pub struct AppState { 150 account_pool: SqlitePool, 151 pds_gatekeeper_pool: SqlitePool, 152 reverse_proxy_client: HyperUtilClient, 153 mailer: AsyncSmtpTransport<Tokio1Executor>, 154 template_engine: Engine<Handlebars<'static>>, 155 resolver: Arc<PublicResolver>, 156 app_config: AppConfig, 157} 158 159async fn root_handler() -> impl axum::response::IntoResponse { 160 let body = r" 161 162 ...oO _.--X~~OO~~X--._ ...oOO 163 _.-~ / \ II / \ ~-._ 164 [].-~ \ / \||/ \ / ~-.[] ...o 165 ...o _ ||/ \ / || \ / \|| _ 166 (_) |X X || X X| (_) 167 _-~-_ ||\ / \ || / \ /|| _-~-_ 168 ||||| || \ / \ /||\ / \ / || ||||| 169 | |_|| \ / \ / || \ / \ / ||_| | 170 | |~|| X X || X X ||~| | 171==============| | || / \ / \ || / \ / \ || | |============== 172______________| | || / \ / \||/ \ / \ || | |______________ 173 . . | | ||/ \ / || \ / \|| | | . . 
174 / | | |X X || X X| | | / / 175 / . | | ||\ / \ || / \ /|| | | . / . 176. / | | || \ / \ /||\ / \ / || | | . . 177 . . | | || \ / \ / || \ / \ / || | | . 178 / | | || X X || X X || | | . / . / 179 / . | | || / \ / \ || / \ / \ || | | / 180 / | | || / \ / \||/ \ / \ || | | . / 181. . . | | ||/ \ / /||\ \ / \|| | | /. . 182 | |_|X X / II \ X X|_| | . . / 183==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |============== 184 "; 185 186 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; 187 188 let banner = format!(" {body}\n{intro}"); 189 190 ( 191 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], 192 banner, 193 ) 194} 195 196#[tokio::main] 197async fn main() -> Result<(), Box<dyn std::error::Error>> { 198 setup_tracing(); 199 let pds_env_location = 200 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string()); 201 202 let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location)); 203 if let Err(e) = result_of_finding_pds_env { 204 log::error!( 205 "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}" 206 ); 207 } 208 209 let pds_root = 210 env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file"); 211 let account_db_url = format!("{pds_root}/account.sqlite"); 212 213 let account_options = SqliteConnectOptions::new() 214 .filename(account_db_url) 215 .busy_timeout(Duration::from_secs(5)); 216 217 let account_pool = SqlitePoolOptions::new() 218 .max_connections(5) 219 .connect_with(account_options) 220 .await?; 221 222 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); 223 let options = SqliteConnectOptions::new() 224 .journal_mode(SqliteJournalMode::Wal) 225 .filename(bells_db_url) 226 .create_if_missing(true) 227 .busy_timeout(Duration::from_secs(5)); 228 let pds_gatekeeper_pool = SqlitePoolOptions::new() 229 .max_connections(5) 230 .connect_with(options) 231 
.await?; 232 233 // Run migrations for the extra database 234 // Note: the migrations are embedded at compile time from the given directory 235 // sqlx 236 sqlx::migrate!("./migrations") 237 .run(&pds_gatekeeper_pool) 238 .await?; 239 240 let client: HyperUtilClient = 241 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) 242 .build(HttpConnector::new()); 243 244 //Emailer set up 245 let smtp_url = 246 env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file"); 247 248 let mailer: AsyncSmtpTransport<Tokio1Executor> = 249 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build(); 250 //Email templates setup 251 let mut hbs = Handlebars::new(); 252 253 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); 254 if let Ok(users_email_directory) = users_email_directory { 255 hbs.register_template_file( 256 "two_factor_code.hbs", 257 format!("{users_email_directory}/two_factor_code.hbs"), 258 )?; 259 } else { 260 let _ = hbs.register_embed_templates::<EmailTemplates>(); 261 } 262 263 let _ = hbs.register_embed_templates::<HtmlTemplates>(); 264 265 //Reads the PLC source from the pds env's or defaults to ol faithful 266 let plc_source_url = 267 env::var("PDS_DID_PLC_URL").unwrap_or_else(|_| "https://plc.directory".to_string()); 268 let plc_source = PlcSource::PlcDirectory { 269 base: plc_source_url.parse().unwrap(), 270 }; 271 let mut resolver = PublicResolver::default(); 272 resolver = resolver.with_plc_source(plc_source.clone()); 273 274 let state = AppState { 275 account_pool, 276 pds_gatekeeper_pool, 277 reverse_proxy_client: client, 278 mailer, 279 template_engine: Engine::from(hbs), 280 resolver: Arc::new(resolver), 281 app_config: AppConfig::new(), 282 }; 283 284 // Rate limiting 285 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. 
286 let captcha_governor_conf = GovernorConfigBuilder::default() 287 .per_second(60) 288 .burst_size(5) 289 .key_extractor(SmartIpKeyExtractor) 290 .finish() 291 .expect("failed to create governor config for create session. this should not happen and is a bug"); 292 293 // Create a second config with the same settings for the other endpoint 294 let sign_in_governor_conf = GovernorConfigBuilder::default() 295 .per_second(60) 296 .burst_size(5) 297 .key_extractor(SmartIpKeyExtractor) 298 .finish() 299 .expect( 300 "failed to create governor config for sign in. this should not happen and is a bug", 301 ); 302 303 let create_account_limiter_time: Option<String> = 304 env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok(); 305 let create_account_limiter_burst: Option<String> = 306 env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok(); 307 308 //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally 309 let mut create_account_governor_conf = GovernorConfigBuilder::default(); 310 if create_account_limiter_time.is_some() { 311 let time = create_account_limiter_time 312 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set") 313 .parse::<u64>() 314 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer"); 315 create_account_governor_conf.per_second(time); 316 } 317 318 if create_account_limiter_burst.is_some() { 319 let burst = create_account_limiter_burst 320 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set") 321 .parse::<u32>() 322 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer"); 323 create_account_governor_conf.burst_size(burst); 324 } 325 326 let create_account_governor_conf = create_account_governor_conf 327 .key_extractor(SmartIpKeyExtractor) 328 .finish().expect( 329 "failed to create governor config for create account. 
this should not happen and is a bug", 330 ); 331 332 let captcha_governor_limiter = captcha_governor_conf.limiter().clone(); 333 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 334 let create_account_governor_limiter = create_account_governor_conf.limiter().clone(); 335 336 let sign_in_governor_layer = GovernorLayer::new(sign_in_governor_conf); 337 338 let interval = Duration::from_secs(60); 339 // a separate background task to clean up 340 std::thread::spawn(move || { 341 loop { 342 std::thread::sleep(interval); 343 captcha_governor_limiter.retain_recent(); 344 sign_in_governor_limiter.retain_recent(); 345 create_account_governor_limiter.retain_recent(); 346 } 347 }); 348 349 let cors = CorsLayer::new() 350 .allow_origin(Any) 351 .allow_methods([Method::GET, Method::OPTIONS, Method::POST]) 352 .allow_headers(Any); 353 354 let mut app = Router::new() 355 .route("/", get(root_handler)) 356 .route("/xrpc/com.atproto.server.getSession", get(get_session)) 357 .route( 358 "/xrpc/com.atproto.server.describeServer", 359 get(describe_server), 360 ) 361 .route( 362 "/xrpc/com.atproto.server.updateEmail", 363 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), 364 ) 365 .route( 366 "/@atproto/oauth-provider/~api/sign-in", 367 post(sign_in).layer(sign_in_governor_layer.clone()), 368 ) 369 .route( 370 "/xrpc/com.atproto.server.createSession", 371 post(create_session.layer(sign_in_governor_layer)), 372 ) 373 .route( 374 "/xrpc/com.atproto.server.createAccount", 375 post(create_account).layer(GovernorLayer::new(create_account_governor_conf)), 376 ); 377 378 if state.app_config.use_captcha { 379 app = app.route( 380 "/gate/signup", 381 get(get_gate).post(post_gate.layer(GovernorLayer::new(captcha_governor_conf))), 382 ); 383 } 384 385 let app = app 386 .layer(CompressionLayer::new()) 387 .layer(cors) 388 .with_state(state); 389 390 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); 391 let port: u16 = 
env::var("GATEKEEPER_PORT") 392 .ok() 393 .and_then(|s| s.parse().ok()) 394 .unwrap_or(8080); 395 let addr: SocketAddr = format!("{host}:{port}") 396 .parse() 397 .expect("valid socket address"); 398 399 let listener = tokio::net::TcpListener::bind(addr).await?; 400 401 let server = axum::serve( 402 listener, 403 app.into_make_service_with_connect_info::<SocketAddr>(), 404 ) 405 .with_graceful_shutdown(shutdown_signal()); 406 407 if let Err(err) = server.await { 408 log::error!("server error:{err}"); 409 } 410 411 Ok(()) 412} 413 414fn setup_tracing() { 415 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); 416 tracing_subscriber::registry() 417 .with(env_filter) 418 .with(fmt::layer()) 419 .init(); 420} 421 422async fn shutdown_signal() { 423 // Wait for Ctrl+C 424 let ctrl_c = async { 425 tokio::signal::ctrl_c() 426 .await 427 .expect("failed to install Ctrl+C handler"); 428 }; 429 430 #[cfg(unix)] 431 let terminate = async { 432 use tokio::signal::unix::{SignalKind, signal}; 433 434 let mut sigterm = 435 signal(SignalKind::terminate()).expect("failed to install signal handler"); 436 sigterm.recv().await; 437 }; 438 439 #[cfg(not(unix))] 440 let terminate = std::future::pending::<()>(); 441 442 tokio::select! { 443 _ = ctrl_c => {}, 444 _ = terminate => {}, 445 } 446}