Forked from baileytownsend.dev/pds-gatekeeper.
Microservice to bring 2FA to self-hosted PDSes (AT Protocol Personal Data Servers).
1#![warn(clippy::unwrap_used)]
2use crate::oauth_provider::sign_in;
3use crate::xrpc::com_atproto_server::{create_account, create_session, get_session, update_email};
4use axum::body::Body;
5use axum::handler::Handler;
6use axum::http::{Method, header};
7use axum::middleware as ax_middleware;
8use axum::routing::post;
9use axum::{Router, routing::get};
10use axum_template::engine::Engine;
11use handlebars::Handlebars;
12use hyper_util::client::legacy::connect::HttpConnector;
13use hyper_util::rt::TokioExecutor;
14use lettre::{AsyncSmtpTransport, Tokio1Executor};
15use rust_embed::RustEmbed;
16use sqlx::sqlite::{SqliteConnectOptions, SqliteJournalMode};
17use sqlx::{SqlitePool, sqlite::SqlitePoolOptions};
18use std::path::Path;
19use std::time::Duration;
20use std::{env, net::SocketAddr};
21use tower_governor::GovernorLayer;
22use tower_governor::governor::{GovernorConfig, GovernorConfigBuilder};
23use tower_http::compression::CompressionLayer;
24use tower_http::cors::{Any, CorsLayer};
25use tracing::log;
26use tracing_subscriber::{EnvFilter, fmt, prelude::*};
27
28pub mod helpers;
29mod middleware;
30mod oauth_provider;
31mod xrpc;
32
/// HTTP client type stored in `AppState::reverse_proxy_client`.
///
/// Built on hyper-util's legacy client with a plain `HttpConnector` and axum request bodies.
type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
34
/// The `*.hbs` files from the `email_templates/` folder, bundled via `rust_embed`.
///
/// `main` registers these with Handlebars unless `GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY`
/// is set, in which case `two_factor_code.hbs` is loaded from that directory instead.
#[derive(RustEmbed)]
#[folder = "email_templates"]
#[include = "*.hbs"]
struct EmailTemplates;
39
/// Shared state attached to the router with `Router::with_state` and cloned into handlers.
#[derive(Clone)]
pub struct AppState {
    /// Pool over the PDS's `account.sqlite` inside `PDS_DATA_DIRECTORY`.
    account_pool: SqlitePool,
    /// Pool over the gatekeeper's own `pds_gatekeeper.sqlite`; migrated at startup.
    pds_gatekeeper_pool: SqlitePool,
    /// HTTP client for forwarding requests; presumably targets `pds_base_url` (confirm in handlers).
    reverse_proxy_client: HyperUtilClient,
    /// Base URL of the PDS (`PDS_BASE_URL`, defaulting to `http://localhost:3000`).
    pds_base_url: String,
    /// SMTP transport built from `PDS_EMAIL_SMTP_URL`.
    mailer: AsyncSmtpTransport<Tokio1Executor>,
    /// Sender address for outgoing email (`PDS_EMAIL_FROM_ADDRESS`).
    mailer_from: String,
    /// Handlebars engine holding the email templates.
    template_engine: Engine<Handlebars<'static>>,
}
50
/// Handler for `GET /`: an ASCII-art banner followed by a short description of the
/// service and a link to its source code, served as `text/plain; charset=utf-8`.
async fn root_handler() -> impl axum::response::IntoResponse {
    // Whitespace inside this raw string is significant: it is the banner art itself.
    let body = r"

 ...oO _.--X~~OO~~X--._ ...oOO
 _.-~ / \ II / \ ~-._
 [].-~ \ / \||/ \ / ~-.[] ...o
 ...o _ ||/ \ / || \ / \|| _
 (_) |X X || X X| (_)
 _-~-_ ||\ / \ || / \ /|| _-~-_
 ||||| || \ / \ /||\ / \ / || |||||
 | |_|| \ / \ / || \ / \ / ||_| |
 | |~|| X X || X X ||~| |
==============| | || / \ / \ || / \ / \ || | |==============
______________| | || / \ / \||/ \ / \ || | |______________
 . . | | ||/ \ / || \ / \|| | | . .
 / | | |X X || X X| | | / /
 / . | | ||\ / \ || / \ /|| | | . / .
. / | | || \ / \ /||\ / \ / || | | . .
 . . | | || \ / \ / || \ / \ / || | | .
 / | | || X X || X X || | | . / . /
 / . | | || / \ / \ || / \ / \ || | | /
 / | | || / \ / \||/ \ / \ || | | . /
. . . | | ||/ \ / /||\ \ / \|| | | /. .
 | |_|X X / II \ X X|_| | . . /
==============| |~II~~~~~~~~~~~~~~OO~~~~~~~~~~~~~~II~| |==============
 ";

    let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";

    // Banner = a leading space, the art, a newline, then the intro text.
    let banner = format!(" {body}\n{intro}");

    (
        [(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
        banner,
    )
}
87
88#[tokio::main]
89async fn main() -> Result<(), Box<dyn std::error::Error>> {
90 setup_tracing();
91 let pds_env_location =
92 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string());
93
94 let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location));
95 if let Err(e) = result_of_finding_pds_env {
96 log::error!(
97 "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}"
98 );
99 }
100 let pds_root = env::var("PDS_DATA_DIRECTORY")?;
101 let account_db_url = format!("{pds_root}/account.sqlite");
102
103 let account_options = SqliteConnectOptions::new()
104 .filename(account_db_url)
105 .busy_timeout(Duration::from_secs(5));
106
107 let account_pool = SqlitePoolOptions::new()
108 .max_connections(5)
109 .connect_with(account_options)
110 .await?;
111
112 let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite");
113 let options = SqliteConnectOptions::new()
114 .journal_mode(SqliteJournalMode::Wal)
115 .filename(bells_db_url)
116 .create_if_missing(true)
117 .busy_timeout(Duration::from_secs(5));
118 let pds_gatekeeper_pool = SqlitePoolOptions::new()
119 .max_connections(5)
120 .connect_with(options)
121 .await?;
122
123 // Run migrations for the extra database
124 // Note: the migrations are embedded at compile time from the given directory
125 // sqlx
126 sqlx::migrate!("./migrations")
127 .run(&pds_gatekeeper_pool)
128 .await?;
129
130 let client: HyperUtilClient =
131 hyper_util::client::legacy::Client::<(), ()>::builder(TokioExecutor::new())
132 .build(HttpConnector::new());
133
134 //Emailer set up
135 let smtp_url =
136 env::var("PDS_EMAIL_SMTP_URL").expect("PDS_EMAIL_SMTP_URL is not set in your pds.env file");
137 let sent_from = env::var("PDS_EMAIL_FROM_ADDRESS")
138 .expect("PDS_EMAIL_FROM_ADDRESS is not set in your pds.env file");
139
140 let mailer: AsyncSmtpTransport<Tokio1Executor> =
141 AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build();
142 //Email templates setup
143 let mut hbs = Handlebars::new();
144
145 let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY");
146 if let Ok(users_email_directory) = users_email_directory {
147 hbs.register_template_file(
148 "two_factor_code.hbs",
149 format!("{users_email_directory}/two_factor_code.hbs"),
150 )?;
151 } else {
152 let _ = hbs.register_embed_templates::<EmailTemplates>();
153 }
154
155 let pds_base_url =
156 env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string());
157
158 let state = AppState {
159 account_pool,
160 pds_gatekeeper_pool,
161 reverse_proxy_client: client,
162 pds_base_url,
163 mailer,
164 mailer_from: sent_from,
165 template_engine: Engine::from(hbs),
166 };
167
168 // Rate limiting
169 //Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
170 let create_session_governor_conf = GovernorConfigBuilder::default()
171 .per_second(60)
172 .burst_size(5)
173 .finish()
174 .expect("failed to create governor config for create session. this should not happen and is a bug");
175
176 // Create a second config with the same settings for the other endpoint
177 let sign_in_governor_conf = GovernorConfigBuilder::default()
178 .per_second(60)
179 .burst_size(5)
180 .finish()
181 .expect(
182 "failed to create governor config for sign in. this should not happen and is a bug",
183 );
184
185 let create_account_limiter_time: Option<String> =
186 env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok();
187 let create_account_limiter_burst: Option<String> =
188 env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok();
189
190 //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally
191 let mut create_account_governor_conf = GovernorConfigBuilder::default();
192 if create_account_limiter_time.is_some() {
193 let time = create_account_limiter_time
194 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set")
195 .parse::<u64>()
196 .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer");
197 create_account_governor_conf.per_second(time);
198 }
199
200 if create_account_limiter_burst.is_some() {
201 let burst = create_account_limiter_burst
202 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set")
203 .parse::<u32>()
204 .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer");
205 create_account_governor_conf.burst_size(burst);
206 }
207
208 let create_account_governor_conf = create_account_governor_conf.finish().expect(
209 "failed to create governor config for create account. this should not happen and is a bug",
210 );
211
212 let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
213 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
214 let create_account_governor_limiter = create_account_governor_conf.limiter().clone();
215
216 let interval = Duration::from_secs(60);
217 // a separate background task to clean up
218 std::thread::spawn(move || {
219 loop {
220 std::thread::sleep(interval);
221 create_session_governor_limiter.retain_recent();
222 sign_in_governor_limiter.retain_recent();
223 create_account_governor_limiter.retain_recent();
224 }
225 });
226
227 let cors = CorsLayer::new()
228 .allow_origin(Any)
229 .allow_methods([Method::GET, Method::OPTIONS, Method::POST])
230 .allow_headers(Any);
231
232 let app = Router::new()
233 .route("/", get(root_handler))
234 .route("/xrpc/com.atproto.server.getSession", get(get_session))
235 .route(
236 "/xrpc/com.atproto.server.updateEmail",
237 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
238 )
239 .route(
240 "/@atproto/oauth-provider/~api/sign-in",
241 post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)),
242 )
243 .route(
244 "/xrpc/com.atproto.server.createSession",
245 post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
246 )
247 .route(
248 "/xrpc/com.atproto.server.createAccount",
249 post(create_account).layer(GovernorLayer::new(create_account_governor_conf)),
250 )
251 .layer(CompressionLayer::new())
252 .layer(cors)
253 .with_state(state);
254
255 let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
256 let port: u16 = env::var("GATEKEEPER_PORT")
257 .ok()
258 .and_then(|s| s.parse().ok())
259 .unwrap_or(8080);
260 let addr: SocketAddr = format!("{host}:{port}")
261 .parse()
262 .expect("valid socket address");
263
264 let listener = tokio::net::TcpListener::bind(addr).await?;
265
266 let server = axum::serve(
267 listener,
268 app.into_make_service_with_connect_info::<SocketAddr>(),
269 )
270 .with_graceful_shutdown(shutdown_signal());
271
272 if let Err(err) = server.await {
273 log::error!("server error:{err}");
274 }
275
276 Ok(())
277}
278
279fn setup_tracing() {
280 let env_filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
281 tracing_subscriber::registry()
282 .with(env_filter)
283 .with(fmt::layer())
284 .init();
285}
286
287async fn shutdown_signal() {
288 // Wait for Ctrl+C
289 let ctrl_c = async {
290 tokio::signal::ctrl_c()
291 .await
292 .expect("failed to install Ctrl+C handler");
293 };
294
295 #[cfg(unix)]
296 let terminate = async {
297 use tokio::signal::unix::{SignalKind, signal};
298
299 let mut sigterm =
300 signal(SignalKind::terminate()).expect("failed to install signal handler");
301 sigterm.recv().await;
302 };
303
304 #[cfg(not(unix))]
305 let terminate = std::future::pending::<()>();
306
307 tokio::select! {
308 _ = ctrl_c => {},
309 _ = terminate => {},
310 }
311}