forked from
baileytownsend.dev/pds-gatekeeper
Microservice to bring 2FA to self-hosted PDSes (atproto Personal Data Servers)
1use crate::AppState;
2use crate::helpers::{
3 AuthResult, ProxiedResult, TokenCheckError, json_error_response, preauth_check, proxy_get_json,
4};
5use crate::middleware::Did;
6use axum::body::Body;
7use axum::extract::State;
8use axum::http::{HeaderMap, StatusCode};
9use axum::response::{IntoResponse, Response};
10use axum::{Extension, Json, debug_handler, extract, extract::Request};
11use serde::{Deserialize, Serialize};
12use serde_json;
13use tracing::log;
14
15#[derive(Serialize, Deserialize, Debug, Clone)]
16#[serde(rename_all = "camelCase")]
17enum AccountStatus {
18 Takendown,
19 Suspended,
20 Deactivated,
21}
22
23#[derive(Serialize, Deserialize, Debug, Clone)]
24#[serde(rename_all = "camelCase")]
25struct GetSessionResponse {
26 handle: String,
27 did: String,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 email: Option<String>,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 email_confirmed: Option<bool>,
32 #[serde(skip_serializing_if = "Option::is_none")]
33 email_auth_factor: Option<bool>,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 did_doc: Option<String>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 active: Option<bool>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 status: Option<AccountStatus>,
40}
41
42#[derive(Serialize, Deserialize, Debug, Clone)]
43#[serde(rename_all = "camelCase")]
44pub struct UpdateEmailResponse {
45 email: String,
46 #[serde(skip_serializing_if = "Option::is_none")]
47 email_auth_factor: Option<bool>,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 token: Option<String>,
50}
51
52#[allow(dead_code)]
53#[derive(Deserialize, Serialize)]
54#[serde(rename_all = "camelCase")]
55pub struct CreateSessionRequest {
56 identifier: String,
57 password: String,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 auth_factor_token: Option<String>,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 allow_takendown: Option<bool>,
62}
63
64pub async fn create_session(
65 State(state): State<AppState>,
66 headers: HeaderMap,
67 Json(payload): extract::Json<CreateSessionRequest>,
68) -> Result<Response<Body>, StatusCode> {
69 let identifier = payload.identifier.clone();
70 let password = payload.password.clone();
71 let auth_factor_token = payload.auth_factor_token.clone();
72
73 // Run the shared pre-auth logic to validate and check 2FA requirement
74 match preauth_check(&state, &identifier, &password, auth_factor_token, false).await {
75 Ok(result) => match result {
76 AuthResult::WrongIdentityOrPassword => json_error_response(
77 StatusCode::UNAUTHORIZED,
78 "AuthenticationRequired",
79 "Invalid identifier or password",
80 ),
81 AuthResult::TwoFactorRequired(_) => {
82 // Email sending step can be handled here if needed in the future.
83 json_error_response(
84 StatusCode::UNAUTHORIZED,
85 "AuthFactorTokenRequired",
86 "A sign in code has been sent to your email address",
87 )
88 }
89 AuthResult::ProxyThrough => {
90 //No 2FA or already passed
91 let uri = format!(
92 "{}{}",
93 state.pds_base_url, "/xrpc/com.atproto.server.createSession"
94 );
95
96 let mut req = axum::http::Request::post(uri);
97 if let Some(req_headers) = req.headers_mut() {
98 req_headers.extend(headers.clone());
99 }
100
101 let payload_bytes =
102 serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
103 let req = req
104 .body(Body::from(payload_bytes))
105 .map_err(|_| StatusCode::BAD_REQUEST)?;
106
107 let proxied = state
108 .reverse_proxy_client
109 .request(req)
110 .await
111 .map_err(|_| StatusCode::BAD_REQUEST)?
112 .into_response();
113
114 Ok(proxied)
115 }
116 AuthResult::TokenCheckFailed(err) => match err {
117 TokenCheckError::InvalidToken => {
118 json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "Token is invalid")
119 }
120 TokenCheckError::ExpiredToken => {
121 json_error_response(StatusCode::BAD_REQUEST, "ExpiredToken", "Token is expired")
122 }
123 },
124 },
125 Err(err) => {
126 log::error!(
127 "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}"
128 );
129 json_error_response(
130 StatusCode::INTERNAL_SERVER_ERROR,
131 "InternalServerError",
132 "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.",
133 )
134 }
135 }
136}
137
138#[debug_handler]
139pub async fn update_email(
140 State(state): State<AppState>,
141 Extension(did): Extension<Did>,
142 headers: HeaderMap,
143 Json(payload): extract::Json<UpdateEmailResponse>,
144) -> Result<Response<Body>, StatusCode> {
145 //If email auth is not set at all it is a update email address
146 let email_auth_not_set = payload.email_auth_factor.is_none();
147 //If email auth is set it is to either turn on or off 2fa
148 let email_auth_update = payload.email_auth_factor.unwrap_or(false);
149
150 //This means the middleware successfully extracted a did from the request, if not it just needs to be forward to the PDS
151 //This is also empty if it is an oauth request, which is not supported by gatekeeper turning on 2fa since the dpop stuff needs to be implemented
152 let did_is_not_empty = did.0.is_some();
153
154 if did_is_not_empty {
155 // Email update asked for
156 if email_auth_update {
157 let email = payload.email.clone();
158 let email_confirmed = match sqlx::query_as::<_, (String,)>(
159 "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?",
160 )
161 .bind(&email)
162 .fetch_optional(&state.account_pool)
163 .await
164 {
165 Ok(row) => row,
166 Err(err) => {
167 log::error!("Error checking if email is confirmed: {err}");
168 return Err(StatusCode::BAD_REQUEST);
169 }
170 };
171
172 //Since the email is already confirmed we can enable 2fa
173 return match email_confirmed {
174 None => Err(StatusCode::BAD_REQUEST),
175 Some(did_row) => {
176 let _ = sqlx::query(
177 "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1",
178 )
179 .bind(&did_row.0)
180 .execute(&state.pds_gatekeeper_pool)
181 .await
182 .map_err(|_| StatusCode::BAD_REQUEST)?;
183
184 Ok(StatusCode::OK.into_response())
185 }
186 };
187 }
188
189 // User wants auth turned off
190 if !email_auth_update && !email_auth_not_set {
191 //User wants auth turned off and has a token
192 if let Some(token) = &payload.token {
193 let token_found = match sqlx::query_as::<_, (String,)>(
194 "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'",
195 )
196 .bind(token)
197 .bind(&did.0)
198 .fetch_optional(&state.account_pool)
199 .await{
200 Ok(token) => token,
201 Err(err) => {
202 log::error!("Error checking if token is valid: {err}");
203 return Err(StatusCode::BAD_REQUEST);
204 }
205 };
206
207 return if token_found.is_some() {
208 //TODO I think there may be a bug here and need to do some retry logic
209 // First try was erroring, seconds was allowing
210 match sqlx::query(
211 "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0",
212 )
213 .bind(&did.0)
214 .execute(&state.pds_gatekeeper_pool)
215 .await {
216 Ok(_) => {}
217 Err(err) => {
218 log::error!("Error updating email auth: {err}");
219 return Err(StatusCode::BAD_REQUEST);
220 }
221 }
222
223 Ok(StatusCode::OK.into_response())
224 } else {
225 Err(StatusCode::BAD_REQUEST)
226 };
227 }
228 }
229 }
230 // Updating the actual email address by sending it on to the PDS
231 let uri = format!(
232 "{}{}",
233 state.pds_base_url, "/xrpc/com.atproto.server.updateEmail"
234 );
235 let mut req = axum::http::Request::post(uri);
236 if let Some(req_headers) = req.headers_mut() {
237 req_headers.extend(headers.clone());
238 }
239
240 let payload_bytes = serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
241 let req = req
242 .body(Body::from(payload_bytes))
243 .map_err(|_| StatusCode::BAD_REQUEST)?;
244
245 let proxied = state
246 .reverse_proxy_client
247 .request(req)
248 .await
249 .map_err(|_| StatusCode::BAD_REQUEST)?
250 .into_response();
251
252 Ok(proxied)
253}
254
255pub async fn get_session(
256 State(state): State<AppState>,
257 req: Request,
258) -> Result<Response<Body>, StatusCode> {
259 match proxy_get_json::<GetSessionResponse>(&state, req, "/xrpc/com.atproto.server.getSession")
260 .await?
261 {
262 ProxiedResult::Parsed {
263 value: mut session, ..
264 } => {
265 let did = session.did.clone();
266 let required_opt = sqlx::query_as::<_, (u8,)>(
267 "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1",
268 )
269 .bind(&did)
270 .fetch_optional(&state.pds_gatekeeper_pool)
271 .await
272 .map_err(|_| StatusCode::BAD_REQUEST)?;
273
274 let email_auth_factor = match required_opt {
275 Some(row) => row.0 != 0,
276 None => false,
277 };
278
279 session.email_auth_factor = Some(email_auth_factor);
280 Ok(Json(session).into_response())
281 }
282 ProxiedResult::Passthrough(resp) => Ok(resp),
283 }
284}
285
286pub async fn create_account(
287 State(state): State<AppState>,
288 mut req: Request,
289) -> Result<Response<Body>, StatusCode> {
290 //TODO if I add the block of only accounts authenticated just take the body as json here and grab the lxm token. No middle ware is needed
291
292 let uri = format!(
293 "{}{}",
294 state.pds_base_url, "/xrpc/com.atproto.server.createAccount"
295 );
296
297 // Rewrite the URI to point at the upstream PDS; keep headers, method, and body intact
298 *req.uri_mut() = uri.parse().map_err(|_| StatusCode::BAD_REQUEST)?;
299
300 let proxied = state
301 .reverse_proxy_client
302 .request(req)
303 .await
304 .map_err(|_| StatusCode::BAD_REQUEST)?
305 .into_response();
306
307 Ok(proxied)
308}