#![warn(clippy::unwrap_used)] use crate::gate::{get_gate, post_gate}; use crate::oauth_provider::sign_in; use crate::xrpc::com_atproto_server::{ create_account, create_session, describe_server, get_session, update_email, }; use axum::{ Router, body::Body, handler::Handler, http::{Method, header}, middleware as ax_middleware, routing::get, routing::post, }; use axum_template::engine::Engine; use handlebars::Handlebars; use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor}; use jacquard_common::types::did::Did; use jacquard_identity::{PublicResolver, resolver::PlcSource}; use lettre::{AsyncSmtpTransport, Tokio1Executor}; use rand::Rng; use rust_embed::RustEmbed; use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode}; use sqlx::{SqlitePool, sqlite::SqlitePoolOptions}; use std::path::Path; use std::sync::Arc; use std::time::Duration; use std::{env, net::SocketAddr}; use tower_governor::{ GovernorLayer, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor, }; use tower_http::{ compression::CompressionLayer, cors::{Any, CorsLayer}, }; use tracing::log; use tracing_subscriber::{EnvFilter, fmt, prelude::*}; mod gate; pub mod helpers; mod middleware; mod oauth_provider; mod xrpc; type HyperUtilClient = hyper_util::client::legacy::Client; #[derive(RustEmbed)] #[folder = "email_templates"] #[include = "*.hbs"] struct EmailTemplates; #[derive(RustEmbed)] #[folder = "html_templates"] #[include = "*.hbs"] struct HtmlTemplates; /// Mostly the env variables that are used in the app #[derive(Clone, Debug)] pub struct AppConfig { pds_base_url: String, mailer_from: String, email_subject: String, allow_only_migrations: bool, use_captcha: bool, //The url to redirect to after a successful captcha. Defaults to https://bsky.app, but you may have another social-app fork you rather your users use //that need to capture this redirect url for creating an account default_successful_redirect_url: String, pds_service_did: Did<'static>, gate_jwe_key: Vec, captcha_success_redirects: Vec, } impl AppConfig { pub fn new() -> Self { let pds_base_url = env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string()); let mailer_from = env::var("PDS_EMAIL_FROM_ADDRESS") .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file"); //Hack not my favorite, but it does work let allow_only_migrations = env::var("GATEKEEPER_ALLOW_ONLY_MIGRATIONS") .map(|val| val.parse::().unwrap_or(false)) .unwrap_or(false); let use_captcha = env::var("GATEKEEPER_CREATE_ACCOUNT_CAPTCHA") .map(|val| val.parse::().unwrap_or(false)) .unwrap_or(false); // PDS_SERVICE_DID is the did:web if set, if not it's PDS_HOSTNAME let pds_service_did = env::var("PDS_SERVICE_DID").unwrap_or_else(|_| match env::var("PDS_HOSTNAME") { Ok(pds_hostname) => format!("did:web:{}", pds_hostname), Err(_) => { panic!("PDS_HOSTNAME or PDS_SERVICE_DID must be set in your pds.env file") } }); let email_subject = env::var("GATEKEEPER_TWO_FACTOR_EMAIL_SUBJECT") .unwrap_or("Sign in to Bluesky".to_string()); // Load or generate JWE encryption key (32 bytes for AES-256) let gate_jwe_key = env::var("GATEKEEPER_JWE_KEY") .ok() .and_then(|key_hex| hex::decode(key_hex).ok()) .unwrap_or_else(|| { // Generate a random 32-byte key if not provided let key: Vec = (0..32).map(|_| rand::rng().random()).collect(); log::warn!("WARNING: No GATEKEEPER_JWE_KEY found in the environment. Generated random key (hex): {}", hex::encode(&key)); log::warn!("This is not strictly needed unless you scale PDS Gatekeeper. Will not also be able to verify tokens between reboots, but they are short lived (5mins)."); key }); if gate_jwe_key.len() != 32 { panic!( "GATEKEEPER_JWE_KEY must be 32 bytes (64 hex characters) for AES-256 encryption" ); } let captcha_success_redirects = match env::var("GATEKEEPER_CAPTCHA_SUCCESS_REDIRECTS") { Ok(from_env) => from_env.split(",").map(|s| s.trim().to_string()).collect(), Err(_) => { vec![ String::from("https://bsky.app"), String::from("https://pdsmoover.com"), String::from("https://blacksky.community"), String::from("https://tektite.cc"), ] } }; AppConfig { pds_base_url, mailer_from, email_subject, allow_only_migrations, use_captcha, default_successful_redirect_url: env::var("GATEKEEPER_DEFAULT_CAPTCHA_REDIRECT") .unwrap_or("https://bsky.app".to_string()), pds_service_did: pds_service_did .parse() .expect("PDS_SERVICE_DID is not a valid did or could not infer from PDS_HOSTNAME"), gate_jwe_key, captcha_success_redirects, } } } #[derive(Clone)] pub struct AppState { account_pool: SqlitePool, pds_gatekeeper_pool: SqlitePool, reverse_proxy_client: HyperUtilClient, mailer: AsyncSmtpTransport, template_engine: Engine>, resolver: Arc, app_config: AppConfig, } async fn root_handler() -> impl axum::response::IntoResponse { let body = r" ...oO _.--X~~OO~~X--._ ...oOO _.-~ / \ II / \ ~-._ [].-~ \ / \||/ \ / ~-.[] ...o ...o _ ||/ \ / || \ / \|| _ (_) |X X || X X| (_) _-~-_ ||\ / \ || / \ /|| _-~-_ ||||| || \ / \ /||\ / \ / || ||||| | |_|| \ / \ / || \ / \ / ||_| | | |~|| X X || X X ||~| | ==============| | || / \ / \ || / \ / \ || | |============== ______________| | || / \ / \||/ \ / \ || | |______________ . . | | ||/ \ / || \ / \|| | | . . / | | |X X || X X| | | / / / . | | ||\ / \ || / \ /|| | | . / . . / | | || \ / \ /||\ / \ / || | | . . . . | | || \ / \ / || \ / \ / || | | . / | | || X X || X X || | | . / . / / . | | || / \ / \ || / \ / \ || | | / / | | || / \ / \||/ \ / \ || | | . / . . . | | ||/ \ / /||\ \ / \|| | | /. . | |_|X X / II \ X X|_| | . . / ==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |============== "; let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n"; let banner = format!(" {body}\n{intro}"); ( [(header::CONTENT_TYPE, "text/plain; charset=utf-8")], banner, ) } #[tokio::main] async fn main() -> Result<(), Box> { setup_tracing(); let pds_env_location = env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string()); let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location)); if let Err(e) = result_of_finding_pds_env { log::error!( "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}" ); } let pds_root = env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file"); let account_db_url = format!("{pds_root}/account.sqlite"); let account_options = SqliteConnectOptions::new() .filename(account_db_url) .busy_timeout(Duration::from_secs(5)); let account_pool = SqlitePoolOptions::new() .max_connections(5) .connect_with(account_options) .await?; let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite"); let options = SqliteConnectOptions::new() .journal_mode(SqliteJournalMode::Wal) .filename(bells_db_url) .create_if_missing(true) .busy_timeout(Duration::from_secs(5)); let pds_gatekeeper_pool = SqlitePoolOptions::new() .max_connections(5) .connect_with(options) .await?; // Run migrations for the extra database // Note: the migrations are embedded at compile time from the given directory // sqlx sqlx::migrate!("./migrations") .run(&pds_gatekeeper_pool) .await?; let client: HyperUtilClient = hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new()) .build(HttpConnector::new()); //Emailer set up let smtp_url = env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file"); let mailer: AsyncSmtpTransport = AsyncSmtpTransport::::from_url(smtp_url.as_str())?.build(); //Email templates setup let mut hbs = Handlebars::new(); let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY"); if let Ok(users_email_directory) = users_email_directory { hbs.register_template_file( "two_factor_code.hbs", format!("{users_email_directory}/two_factor_code.hbs"), )?; } else { let _ = hbs.register_embed_templates::(); } let _ = hbs.register_embed_templates::(); //Reads the PLC source from the pds env's or defaults to ol faithful let plc_source_url = env::var("PDS_DID_PLC_URL").unwrap_or_else(|_| "https://plc.directory".to_string()); let plc_source = PlcSource::PlcDirectory { base: plc_source_url.parse().unwrap(), }; let mut resolver = PublicResolver::default(); resolver = resolver.with_plc_source(plc_source.clone()); let state = AppState { account_pool, pds_gatekeeper_pool, reverse_proxy_client: client, mailer, template_engine: Engine::from(hbs), resolver: Arc::new(resolver), app_config: AppConfig::new(), }; // Rate limiting //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds. let captcha_governor_conf = GovernorConfigBuilder::default() .per_second(60) .burst_size(5) .key_extractor(SmartIpKeyExtractor) .finish() .expect("failed to create governor config for create session. this should not happen and is a bug"); // Create a second config with the same settings for the other endpoint let sign_in_governor_conf = GovernorConfigBuilder::default() .per_second(60) .burst_size(5) .key_extractor(SmartIpKeyExtractor) .finish() .expect( "failed to create governor config for sign in. this should not happen and is a bug", ); let create_account_limiter_time: Option = env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok(); let create_account_limiter_burst: Option = env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok(); //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally let mut create_account_governor_conf = GovernorConfigBuilder::default(); if create_account_limiter_time.is_some() { let time = create_account_limiter_time .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set") .parse::() .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer"); create_account_governor_conf.per_second(time); } if create_account_limiter_burst.is_some() { let burst = create_account_limiter_burst .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set") .parse::() .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer"); create_account_governor_conf.burst_size(burst); } let create_account_governor_conf = create_account_governor_conf .key_extractor(SmartIpKeyExtractor) .finish().expect( "failed to create governor config for create account. this should not happen and is a bug", ); let captcha_governor_limiter = captcha_governor_conf.limiter().clone(); let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); let create_account_governor_limiter = create_account_governor_conf.limiter().clone(); let sign_in_governor_layer = GovernorLayer::new(sign_in_governor_conf); let interval = Duration::from_secs(60); // a separate background task to clean up std::thread::spawn(move || { loop { std::thread::sleep(interval); captcha_governor_limiter.retain_recent(); sign_in_governor_limiter.retain_recent(); create_account_governor_limiter.retain_recent(); } }); let cors = CorsLayer::new() .allow_origin(Any) .allow_methods([Method::GET, Method::OPTIONS, Method::POST]) .allow_headers(Any); let mut app = Router::new() .route("/", get(root_handler)) .route("/xrpc/com.atproto.server.getSession", get(get_session)) .route( "/xrpc/com.atproto.server.describeServer", get(describe_server), ) .route( "/xrpc/com.atproto.server.updateEmail", post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), ) .route( "/@atproto/oauth-provider/~api/sign-in", post(sign_in).layer(sign_in_governor_layer.clone()), ) .route( "/xrpc/com.atproto.server.createSession", post(create_session.layer(sign_in_governor_layer)), ) .route( "/xrpc/com.atproto.server.createAccount", post(create_account).layer(GovernorLayer::new(create_account_governor_conf)), ); if state.app_config.use_captcha { app = app.route( "/gate/signup", get(get_gate).post(post_gate.layer(GovernorLayer::new(captcha_governor_conf))), ); } let app = app .layer(CompressionLayer::new()) .layer(cors) .with_state(state); let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string()); let port: u16 = env::var("GATEKEEPER_PORT") .ok() .and_then(|s| s.parse().ok()) .unwrap_or(8080); let addr: SocketAddr = format!("{host}:{port}") .parse() .expect("valid socket address"); let listener = tokio::net::TcpListener::bind(addr).await?; let server = axum::serve( listener, app.into_make_service_with_connect_info::(), ) .with_graceful_shutdown(shutdown_signal()); if let Err(err) = server.await { log::error!("server error:{err}"); } Ok(()) } fn setup_tracing() { let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); tracing_subscriber::registry() .with(env_filter) .with(fmt::layer()) .init(); } async fn shutdown_signal() { // Wait for Ctrl+C let ctrl_c = async { tokio::signal::ctrl_c() .await .expect("failed to install Ctrl+C handler"); }; #[cfg(unix)] let terminate = async { use tokio::signal::unix::{SignalKind, signal}; let mut sigterm = signal(SignalKind::terminate()).expect("failed to install signal handler"); sigterm.recv().await; }; #[cfg(not(unix))] let terminate = std::future::pending::<()>(); tokio::select! { _ = ctrl_c => {}, _ = terminate => {}, } }