forked from
baileytownsend.dev/pds-gatekeeper
Microservice to bring 2FA to self-hosted PDSes
1#![warn(clippy::unwrap_used)]
2use crate::oauth_provider::sign_in;
3use crate::xrpc::com_atproto_server::{create_account, create_session, get_session, update_email};
4use axum::body::Body;
5use axum::handler::Handler;
6use axum::http::{Method, header};
7use axum::middleware as ax_middleware;
8use axum::routing::post;
9use axum::{Router, routing::get};
10use axum_template::engine::Engine;
11use handlebars::Handlebars;
12use hyper_util::client::legacy::connect::HttpConnector;
13use hyper_util::rt::TokioExecutor;
14use lettre::{AsyncSmtpTransport, Tokio1Executor};
15use rust_embed::RustEmbed;
16use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode};
17use sqlx::{SqlitePool, sqlite::SqlitePoolOptions};
18use std::path::Path;
19use std::time::Duration;
20use std::{env, net::SocketAddr};
21use tower_governor::GovernorLayer;
22use tower_governor::governor::GovernorConfigBuilder;
23use tower_governor::key_extractor::SmartIpKeyExtractor;
24use tower_http::compression::CompressionLayer;
25use tower_http::cors::{Any, CorsLayer};
26use tracing::log;
27use tracing_subscriber::{EnvFilter, fmt, prelude::*};
28
29pub mod helpers;
30mod middleware;
31mod oauth_provider;
32mod xrpc;
33
34type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
35
36#[derive(RustEmbed)]
37#[folder = "email_templates"]
38#[include = "*.hbs"]
39struct EmailTemplates;
40
41#[derive(Clone)]
42pub struct AppState {
43 account_pool: SqlitePool,
44 pds_gatekeeper_pool: SqlitePool,
45 reverse_proxy_client: HyperUtilClient,
46 pds_base_url: String,
47 mailer: AsyncSmtpTransport<Tokio1Executor>,
48 mailer_from: String,
49 template_engine: Engine<Handlebars<'static>>,
50}
51
52async fn root_handler() -> impl axum::response::IntoResponse {
53 let body = r"
54
55 ...oO _.--X~~OO~~X--._ ...oOO
56 _.-~ / \ II / \ ~-._
57 [].-~ \ / \||/ \ / ~-.[] ...o
58 ...o _ ||/ \ / || \ / \|| _
59 (_) |X X || X X| (_)
60 _-~-_ ||\ / \ || / \ /|| _-~-_
61 ||||| || \ / \ /||\ / \ / || |||||
62 | |_|| \ / \ / || \ / \ / ||_| |
63 | |~|| X X || X X ||~| |
64==============| | || / \ / \ || / \ / \ || | |==============
65______________| | || / \ / \||/ \ / \ || | |______________
66 . . | | ||/ \ / || \ / \|| | | . .
67 / | | |X X || X X| | | / /
68 / . | | ||\ / \ || / \ /|| | | . / .
69. / | | || \ / \ /||\ / \ / || | | . .
70 . . | | || \ / \ / || \ / \ / || | | .
71 / | | || X X || X X || | | . / . /
72 / . | | || / \ / \ || / \ / \ || | | /
73 / | | || / \ / \||/ \ / \ || | | . /
74. . . | | ||/ \ / /||\ \ / \|| | | /. .
75 | |_|X X / II \ X X|_| | . . /
76==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |==============
77 ";
78
79 let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";
80
81 let banner = format!(" {body}\n{intro}");
82
83 (
84 [(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
85 banner,
86 )
87}
88
89#[tokio::main]
90async fn main() -> Result<(), Box<dyn std::error::Error>> {
91 setup_tracing();
92 let pds_env_location =
93 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string());
94
95 let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location));
96 if let Err(e) = result_of_finding_pds_env {
97 log::error!(
98 "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}"
99 );
100 }
101
102 let pds_root =
103 env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file");
104 let account_db_url = format!("{pds_root}/account.sqlite");
105
106 let account_options = SqliteConnectOptions::new()
107 .filename(account_db_url)
108 .busy_timeout(Duration::from_secs(5));
109
110 let account_pool = SqlitePoolOptions::new()
111 .max_connections(5)
112 .connect_with(account_options)
113 .await?;
114
115 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite");
116 let options = SqliteConnectOptions::new()
117 .journal_mode(SqliteJournalMode::Wal)
118 .filename(bells_db_url)
119 .create_if_missing(true)
120 .busy_timeout(Duration::from_secs(5));
121 let pds_gatekeeper_pool = SqlitePoolOptions::new()
122 .max_connections(5)
123 .connect_with(options)
124 .await?;
125
126 // Run migrations for the extra database
127 // Note: the migrations are embedded at compile time from the given directory
128 // sqlx
129 sqlx::migrate!("./migrations")
130 .run(&pds_gatekeeper_pool)
131 .await?;
132
133 let client: HyperUtilClient =
134 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
135 .build(HttpConnector::new());
136
137 //Emailer set up
138 let smtp_url =
139 env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file");
140 let sent_from = env::var("PDS_EMAIL_FROM_ADDRESS")
141 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file");
142
143 let mailer: AsyncSmtpTransport<Tokio1Executor> =
144 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build();
145 //Email templates setup
146 let mut hbs = Handlebars::new();
147
148 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY");
149 if let Ok(users_email_directory) = users_email_directory {
150 hbs.register_template_file(
151 "two_factor_code.hbs",
152 format!("{users_email_directory}/two_factor_code.hbs"),
153 )?;
154 } else {
155 let _ = hbs.register_embed_templates::<EmailTemplates>();
156 }
157
158 let pds_base_url =
159 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string());
160
161 let state = AppState {
162 account_pool,
163 pds_gatekeeper_pool,
164 reverse_proxy_client: client,
165 pds_base_url,
166 mailer,
167 mailer_from: sent_from,
168 template_engine: Engine::from(hbs),
169 };
170
171 // Rate limiting
172 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
173 let create_session_governor_conf = GovernorConfigBuilder::default()
174 .per_second(60)
175 .burst_size(5)
176 .key_extractor(SmartIpKeyExtractor)
177 .finish()
178 .expect("failed to create governor config for create session. this should not happen and is a bug");
179
180 // Create a second config with the same settings for the other endpoint
181 let sign_in_governor_conf = GovernorConfigBuilder::default()
182 .per_second(60)
183 .burst_size(5)
184 .key_extractor(SmartIpKeyExtractor)
185 .finish()
186 .expect(
187 "failed to create governor config for sign in. this should not happen and is a bug",
188 );
189
190 let create_account_limiter_time: Option<String> =
191 env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok();
192 let create_account_limiter_burst: Option<String> =
193 env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok();
194
195 //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally
196 let mut create_account_governor_conf = GovernorConfigBuilder::default();
197 if create_account_limiter_time.is_some() {
198 let time = create_account_limiter_time
199 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set")
200 .parse::<u64>()
201 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer");
202 create_account_governor_conf.per_second(time);
203 }
204
205 if create_account_limiter_burst.is_some() {
206 let burst = create_account_limiter_burst
207 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set")
208 .parse::<u32>()
209 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer");
210 create_account_governor_conf.burst_size(burst);
211 }
212
213 let create_account_governor_conf = create_account_governor_conf
214 .key_extractor(SmartIpKeyExtractor)
215 .finish().expect(
216 "failed to create governor config for create account. this should not happen and is a bug",
217 );
218
219 let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
220 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
221 let create_account_governor_limiter = create_account_governor_conf.limiter().clone();
222
223 let interval = Duration::from_secs(60);
224 // a separate background task to clean up
225 std::thread::spawn(move || {
226 loop {
227 std::thread::sleep(interval);
228 create_session_governor_limiter.retain_recent();
229 sign_in_governor_limiter.retain_recent();
230 create_account_governor_limiter.retain_recent();
231 }
232 });
233
234 let cors = CorsLayer::new()
235 .allow_origin(Any)
236 .allow_methods([Method::GET, Method::OPTIONS, Method::POST])
237 .allow_headers(Any);
238
239 let app = Router::new()
240 .route("/", get(root_handler))
241 .route("/xrpc/com.atproto.server.getSession", get(get_session))
242 .route(
243 "/xrpc/com.atproto.server.updateEmail",
244 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
245 )
246 .route(
247 "/@atproto/oauth-provider/~api/sign-in",
248 post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)),
249 )
250 .route(
251 "/xrpc/com.atproto.server.createSession",
252 post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
253 )
254 .route(
255 "/xrpc/com.atproto.server.createAccount",
256 post(create_account).layer(GovernorLayer::new(create_account_governor_conf)),
257 )
258 .layer(CompressionLayer::new())
259 .layer(cors)
260 .with_state(state);
261
262 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
263 let port: u16 = env::var("GATEKEEPER_PORT")
264 .ok()
265 .and_then(|s| s.parse().ok())
266 .unwrap_or(8080);
267 let addr: SocketAddr = format!("{host}:{port}")
268 .parse()
269 .expect("valid socket address");
270
271 let listener = tokio::net::TcpListener::bind(addr).await?;
272
273 let server = axum::serve(
274 listener,
275 app.into_make_service_with_connect_info::<SocketAddr>(),
276 )
277 .with_graceful_shutdown(shutdown_signal());
278
279 if let Err(err) = server.await {
280 log::error!("server error:{err}");
281 }
282
283 Ok(())
284}
285
286fn setup_tracing() {
287 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
288 tracing_subscriber::registry()
289 .with(env_filter)
290 .with(fmt::layer())
291 .init();
292}
293
294async fn shutdown_signal() {
295 // Wait for Ctrl+C
296 let ctrl_c = async {
297 tokio::signal::ctrl_c()
298 .await
299 .expect("failed to install Ctrl+C handler");
300 };
301
302 #[cfg(unix)]
303 let terminate = async {
304 use tokio::signal::unix::{SignalKind, signal};
305
306 let mut sigterm =
307 signal(SignalKind::terminate()).expect("failed to install signal handler");
308 sigterm.recv().await;
309 };
310
311 #[cfg(not(unix))]
312 let terminate = std::future::pending::<()>();
313
314 tokio::select! {
315 _ = ctrl_c => {},
316 _ = terminate => {},
317 }
318}