forked from
baileytownsend.dev/pds-gatekeeper
Microservice to bring 2FA to self-hosted PDSes
1#![warn(clippy::unwrap_used)]
2use crate::gate::{get_gate, post_gate};
3use crate::oauth_provider::sign_in;
4use crate::xrpc::com_atproto_server::{
5 create_account, create_session, describe_server, get_session, update_email,
6};
7use axum::{
8 Router,
9 body::Body,
10 handler::Handler,
11 http::{Method, header},
12 middleware as ax_middleware,
13 routing::get,
14 routing::post,
15};
16use axum_template::engine::Engine;
17use handlebars::Handlebars;
18use hyper_util::{client::legacy::connect::HttpConnector, rt::TokioExecutor};
19use jacquard_common::types::did::Did;
20use jacquard_identity::{PublicResolver, resolver::PlcSource};
21use lettre::{AsyncSmtpTransport, Tokio1Executor};
22use rand::Rng;
23use rust_embed::RustEmbed;
24use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode};
25use sqlx::{SqlitePool, sqlite::SqlitePoolOptions};
26use std::path::Path;
27use std::sync::Arc;
28use std::time::Duration;
29use std::{env, net::SocketAddr};
30use tower_governor::{
31 GovernorLayer, governor::GovernorConfigBuilder, key_extractor::SmartIpKeyExtractor,
32};
33use tower_http::{
34 compression::CompressionLayer,
35 cors::{Any, CorsLayer},
36};
37use tracing::log;
38use tracing_subscriber::{EnvFilter, fmt, prelude::*};
39
40mod gate;
41pub mod helpers;
42mod middleware;
43mod oauth_provider;
44mod xrpc;
45
46type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
47
48#[derive(RustEmbed)]
49#[folder = "email_templates"]
50#[include = "*.hbs"]
51struct EmailTemplates;
52
53#[derive(RustEmbed)]
54#[folder = "html_templates"]
55#[include = "*.hbs"]
56struct HtmlTemplates;
57
58/// Mostly the env variables that are used in the app
59#[derive(Clone, Debug)]
60pub struct AppConfig {
61 pds_base_url: String,
62 mailer_from: String,
63 email_subject: String,
64 allow_only_migrations: bool,
65 use_captcha: bool,
66 //The url to redirect to after a successful captcha. Defaults to https://bsky.app, but you may have another social-app fork you rather your users use
67 //that need to capture this redirect url for creating an account
68 default_successful_redirect_url: String,
69 pds_service_did: Did<'static>,
70 gate_jwe_key: Vec<u8>,
71 captcha_success_redirects: Vec<String>,
72}
73
74impl AppConfig {
75 pub fn new() -> Self {
76 let pds_base_url =
77 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string());
78 let mailer_from = env::var("PDS_EMAIL_FROM_ADDRESS")
79 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file");
80 //Hack not my favorite, but it does work
81 let allow_only_migrations = env::var("GATEKEEPER_ALLOW_ONLY_MIGRATIONS")
82 .map(|val| val.parse::<bool>().unwrap_or(false))
83 .unwrap_or(false);
84
85 let use_captcha = env::var("GATEKEEPER_CREATE_ACCOUNT_CAPTCHA")
86 .map(|val| val.parse::<bool>().unwrap_or(false))
87 .unwrap_or(false);
88
89 // PDS_SERVICE_DID is the did:web if set, if not it's PDS_HOSTNAME
90 let pds_service_did =
91 env::var("PDS_SERVICE_DID").unwrap_or_else(|_| match env::var("PDS_HOSTNAME") {
92 Ok(pds_hostname) => format!("did:web:{}", pds_hostname),
93 Err(_) => {
94 panic!("PDS_HOSTNAME or PDS_SERVICE_DID must be set in your pds.env file")
95 }
96 });
97
98 let email_subject = env::var("GATEKEEPER_TWO_FACTOR_EMAIL_SUBJECT")
99 .unwrap_or("Sign in to Bluesky".to_string());
100
101 // Load or generate JWE encryption key (32 bytes for AES-256)
102 let gate_jwe_key = env::var("GATEKEEPER_JWE_KEY")
103 .ok()
104 .and_then(|key_hex| hex::decode(key_hex).ok())
105 .unwrap_or_else(|| {
106 // Generate a random 32-byte key if not provided
107 let key: Vec<u8> = (0..32).map(|_| rand::rng().random()).collect();
108 log::warn!("WARNING: No GATEKEEPER_JWE_KEY found in the environment. Generated random key (hex): {}", hex::encode(&key));
109 log::warn!("This is not strictly needed unless you scale PDS Gatekeeper. Will not also be able to verify tokens between reboots, but they are short lived (5mins).");
110 key
111 });
112
113 if gate_jwe_key.len() != 32 {
114 panic!(
115 "GATEKEEPER_JWE_KEY must be 32 bytes (64 hex characters) for AES-256 encryption"
116 );
117 }
118
119 let captcha_success_redirects = match env::var("GATEKEEPER_CAPTCHA_SUCCESS_REDIRECTS") {
120 Ok(from_env) => from_env.split(",").map(|s| s.trim().to_string()).collect(),
121 Err(_) => {
122 vec![
123 String::from("https://bsky.app"),
124 String::from("https://pdsmoover.com"),
125 String::from("https://blacksky.community"),
126 String::from("https://tektite.cc"),
127 ]
128 }
129 };
130
131 AppConfig {
132 pds_base_url,
133 mailer_from,
134 email_subject,
135 allow_only_migrations,
136 use_captcha,
137 default_successful_redirect_url: env::var("GATEKEEPER_DEFAULT_CAPTCHA_REDIRECT")
138 .unwrap_or("https://bsky.app".to_string()),
139 pds_service_did: pds_service_did
140 .parse()
141 .expect("PDS_SERVICE_DID is not a valid did or could not infer from PDS_HOSTNAME"),
142 gate_jwe_key,
143 captcha_success_redirects,
144 }
145 }
146}
147
148#[derive(Clone)]
149pub struct AppState {
150 account_pool: SqlitePool,
151 pds_gatekeeper_pool: SqlitePool,
152 reverse_proxy_client: HyperUtilClient,
153 mailer: AsyncSmtpTransport<Tokio1Executor>,
154 template_engine: Engine<Handlebars<'static>>,
155 resolver: Arc<PublicResolver>,
156 app_config: AppConfig,
157}
158
159async fn root_handler() -> impl axum::response::IntoResponse {
160 let body = r"
161
162 ...oO _.--X~~OO~~X--._ ...oOO
163 _.-~ / \ II / \ ~-._
164 [].-~ \ / \||/ \ / ~-.[] ...o
165 ...o _ ||/ \ / || \ / \|| _
166 (_) |X X || X X| (_)
167 _-~-_ ||\ / \ || / \ /|| _-~-_
168 ||||| || \ / \ /||\ / \ / || |||||
169 | |_|| \ / \ / || \ / \ / ||_| |
170 | |~|| X X || X X ||~| |
171==============| | || / \ / \ || / \ / \ || | |==============
172______________| | || / \ / \||/ \ / \ || | |______________
173 . . | | ||/ \ / || \ / \|| | | . .
174 / | | |X X || X X| | | / /
175 / . | | ||\ / \ || / \ /|| | | . / .
176. / | | || \ / \ /||\ / \ / || | | . .
177 . . | | || \ / \ / || \ / \ / || | | .
178 / | | || X X || X X || | | . / . /
179 / . | | || / \ / \ || / \ / \ || | | /
180 / | | || / \ / \||/ \ / \ || | | . /
181. . . | | ||/ \ / /||\ \ / \|| | | /. .
182 | |_|X X / II \ X X|_| | . . /
183==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |==============
184 ";
185
186 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";
187
188 let banner = format!(" {body}\n{intro}");
189
190 (
191 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
192 banner,
193 )
194}
195
196#[tokio::main]
197async fn main() -> Result<(), Box<dyn std::error::Error>> {
198 setup_tracing();
199 let pds_env_location =
200 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string());
201
202 let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location));
203 if let Err(e) = result_of_finding_pds_env {
204 log::error!(
205 "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}"
206 );
207 }
208
209 let pds_root =
210 env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file");
211 let account_db_url = format!("{pds_root}/account.sqlite");
212
213 let account_options = SqliteConnectOptions::new()
214 .filename(account_db_url)
215 .busy_timeout(Duration::from_secs(5));
216
217 let account_pool = SqlitePoolOptions::new()
218 .max_connections(5)
219 .connect_with(account_options)
220 .await?;
221
222 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite");
223 let options = SqliteConnectOptions::new()
224 .journal_mode(SqliteJournalMode::Wal)
225 .filename(bells_db_url)
226 .create_if_missing(true)
227 .busy_timeout(Duration::from_secs(5));
228 let pds_gatekeeper_pool = SqlitePoolOptions::new()
229 .max_connections(5)
230 .connect_with(options)
231 .await?;
232
233 // Run migrations for the extra database
234 // Note: the migrations are embedded at compile time from the given directory
235 // sqlx
236 sqlx::migrate!("./migrations")
237 .run(&pds_gatekeeper_pool)
238 .await?;
239
240 let client: HyperUtilClient =
241 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
242 .build(HttpConnector::new());
243
244 //Emailer set up
245 let smtp_url =
246 env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file");
247
248 let mailer: AsyncSmtpTransport<Tokio1Executor> =
249 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build();
250 //Email templates setup
251 let mut hbs = Handlebars::new();
252
253 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY");
254 if let Ok(users_email_directory) = users_email_directory {
255 hbs.register_template_file(
256 "two_factor_code.hbs",
257 format!("{users_email_directory}/two_factor_code.hbs"),
258 )?;
259 } else {
260 let _ = hbs.register_embed_templates::<EmailTemplates>();
261 }
262
263 let _ = hbs.register_embed_templates::<HtmlTemplates>();
264
265 //Reads the PLC source from the pds env's or defaults to ol faithful
266 let plc_source_url =
267 env::var("PDS_DID_PLC_URL").unwrap_or_else(|_| "https://plc.directory".to_string());
268 let plc_source = PlcSource::PlcDirectory {
269 base: plc_source_url.parse().unwrap(),
270 };
271 let mut resolver = PublicResolver::default();
272 resolver = resolver.with_plc_source(plc_source.clone());
273
274 let state = AppState {
275 account_pool,
276 pds_gatekeeper_pool,
277 reverse_proxy_client: client,
278 mailer,
279 template_engine: Engine::from(hbs),
280 resolver: Arc::new(resolver),
281 app_config: AppConfig::new(),
282 };
283
284 // Rate limiting
285 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
286 let captcha_governor_conf = GovernorConfigBuilder::default()
287 .per_second(60)
288 .burst_size(5)
289 .key_extractor(SmartIpKeyExtractor)
290 .finish()
291 .expect("failed to create governor config for create session. this should not happen and is a bug");
292
293 // Create a second config with the same settings for the other endpoint
294 let sign_in_governor_conf = GovernorConfigBuilder::default()
295 .per_second(60)
296 .burst_size(5)
297 .key_extractor(SmartIpKeyExtractor)
298 .finish()
299 .expect(
300 "failed to create governor config for sign in. this should not happen and is a bug",
301 );
302
303 let create_account_limiter_time: Option<String> =
304 env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok();
305 let create_account_limiter_burst: Option<String> =
306 env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok();
307
308 //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally
309 let mut create_account_governor_conf = GovernorConfigBuilder::default();
310 if create_account_limiter_time.is_some() {
311 let time = create_account_limiter_time
312 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set")
313 .parse::<u64>()
314 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer");
315 create_account_governor_conf.per_second(time);
316 }
317
318 if create_account_limiter_burst.is_some() {
319 let burst = create_account_limiter_burst
320 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set")
321 .parse::<u32>()
322 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer");
323 create_account_governor_conf.burst_size(burst);
324 }
325
326 let create_account_governor_conf = create_account_governor_conf
327 .key_extractor(SmartIpKeyExtractor)
328 .finish().expect(
329 "failed to create governor config for create account. this should not happen and is a bug",
330 );
331
332 let captcha_governor_limiter = captcha_governor_conf.limiter().clone();
333 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
334 let create_account_governor_limiter = create_account_governor_conf.limiter().clone();
335
336 let sign_in_governor_layer = GovernorLayer::new(sign_in_governor_conf);
337
338 let interval = Duration::from_secs(60);
339 // a separate background task to clean up
340 std::thread::spawn(move || {
341 loop {
342 std::thread::sleep(interval);
343 captcha_governor_limiter.retain_recent();
344 sign_in_governor_limiter.retain_recent();
345 create_account_governor_limiter.retain_recent();
346 }
347 });
348
349 let cors = CorsLayer::new()
350 .allow_origin(Any)
351 .allow_methods([Method::GET, Method::OPTIONS, Method::POST])
352 .allow_headers(Any);
353
354 let mut app = Router::new()
355 .route("/", get(root_handler))
356 .route("/xrpc/com.atproto.server.getSession", get(get_session))
357 .route(
358 "/xrpc/com.atproto.server.describeServer",
359 get(describe_server),
360 )
361 .route(
362 "/xrpc/com.atproto.server.updateEmail",
363 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
364 )
365 .route(
366 "/@atproto/oauth-provider/~api/sign-in",
367 post(sign_in).layer(sign_in_governor_layer.clone()),
368 )
369 .route(
370 "/xrpc/com.atproto.server.createSession",
371 post(create_session.layer(sign_in_governor_layer)),
372 )
373 .route(
374 "/xrpc/com.atproto.server.createAccount",
375 post(create_account).layer(GovernorLayer::new(create_account_governor_conf)),
376 );
377
378 if state.app_config.use_captcha {
379 app = app.route(
380 "/gate/signup",
381 get(get_gate).post(post_gate.layer(GovernorLayer::new(captcha_governor_conf))),
382 );
383 }
384
385 let app = app
386 .layer(CompressionLayer::new())
387 .layer(cors)
388 .with_state(state);
389
390 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
391 let port: u16 = env::var("GATEKEEPER_PORT")
392 .ok()
393 .and_then(|s| s.parse().ok())
394 .unwrap_or(8080);
395 let addr: SocketAddr = format!("{host}:{port}")
396 .parse()
397 .expect("valid socket address");
398
399 let listener = tokio::net::TcpListener::bind(addr).await?;
400
401 let server = axum::serve(
402 listener,
403 app.into_make_service_with_connect_info::<SocketAddr>(),
404 )
405 .with_graceful_shutdown(shutdown_signal());
406
407 if let Err(err) = server.await {
408 log::error!("server error:{err}");
409 }
410
411 Ok(())
412}
413
414fn setup_tracing() {
415 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
416 tracing_subscriber::registry()
417 .with(env_filter)
418 .with(fmt::layer())
419 .init();
420}
421
422async fn shutdown_signal() {
423 // Wait for Ctrl+C
424 let ctrl_c = async {
425 tokio::signal::ctrl_c()
426 .await
427 .expect("failed to install Ctrl+C handler");
428 };
429
430 #[cfg(unix)]
431 let terminate = async {
432 use tokio::signal::unix::{SignalKind, signal};
433
434 let mut sigterm =
435 signal(SignalKind::terminate()).expect("failed to install signal handler");
436 sigterm.recv().await;
437 };
438
439 #[cfg(not(unix))]
440 let terminate = std::future::pending::<()>();
441
442 tokio::select! {
443 _ = ctrl_c => {},
444 _ = terminate => {},
445 }
446}