Forked from
baileytownsend.dev/pds-gatekeeper
Microservice to bring 2FA to self-hosted PDSes.
1use crate::AppState;
2use crate::helpers::{
3 AuthResult, ProxiedResult, TokenCheckError, json_error_response, preauth_check, proxy_get_json,
4};
5use crate::middleware::Did;
6use axum::body::Body;
7use axum::extract::State;
8use axum::http::{HeaderMap, StatusCode};
9use axum::response::{IntoResponse, Response};
10use axum::{Extension, Json, debug_handler, extract, extract::Request};
11use serde::{Deserialize, Serialize};
12use serde_json;
13use tracing::log;
14
15#[derive(Serialize, Deserialize, Debug, Clone)]
16#[serde(rename_all = "camelCase")]
17enum AccountStatus {
18 Takendown,
19 Suspended,
20 Deactivated,
21}
22
23#[derive(Serialize, Deserialize, Debug, Clone)]
24#[serde(rename_all = "camelCase")]
25struct GetSessionResponse {
26 handle: String,
27 did: String,
28 #[serde(skip_serializing_if = "Option::is_none")]
29 email: Option<String>,
30 #[serde(skip_serializing_if = "Option::is_none")]
31 email_confirmed: Option<bool>,
32 #[serde(skip_serializing_if = "Option::is_none")]
33 email_auth_factor: Option<bool>,
34 #[serde(skip_serializing_if = "Option::is_none")]
35 did_doc: Option<String>,
36 #[serde(skip_serializing_if = "Option::is_none")]
37 active: Option<bool>,
38 #[serde(skip_serializing_if = "Option::is_none")]
39 status: Option<AccountStatus>,
40}
41
42#[derive(Serialize, Deserialize, Debug, Clone)]
43#[serde(rename_all = "camelCase")]
44pub struct UpdateEmailResponse {
45 email: String,
46 #[serde(skip_serializing_if = "Option::is_none")]
47 email_auth_factor: Option<bool>,
48 #[serde(skip_serializing_if = "Option::is_none")]
49 token: Option<String>,
50}
51
52#[allow(dead_code)]
53#[derive(Deserialize, Serialize)]
54#[serde(rename_all = "camelCase")]
55pub struct CreateSessionRequest {
56 identifier: String,
57 password: String,
58 #[serde(skip_serializing_if = "Option::is_none")]
59 auth_factor_token: Option<String>,
60 #[serde(skip_serializing_if = "Option::is_none")]
61 allow_takendown: Option<bool>,
62}
63
64pub async fn create_session(
65 State(state): State<AppState>,
66 headers: HeaderMap,
67 Json(payload): extract::Json<CreateSessionRequest>,
68) -> Result<Response<Body>, StatusCode> {
69 let identifier = payload.identifier.clone();
70 let password = payload.password.clone();
71 let auth_factor_token = payload.auth_factor_token.clone();
72
73 // Run the shared pre-auth logic to validate and check 2FA requirement
74 match preauth_check(&state, &identifier, &password, auth_factor_token, false).await {
75 Ok(result) => match result {
76 AuthResult::WrongIdentityOrPassword => json_error_response(
77 StatusCode::UNAUTHORIZED,
78 "AuthenticationRequired",
79 "Invalid identifier or password",
80 ),
81 AuthResult::TwoFactorRequired(_) => {
82 // Email sending step can be handled here if needed in the future.
83 json_error_response(
84 StatusCode::UNAUTHORIZED,
85 "AuthFactorTokenRequired",
86 "A sign in code has been sent to your email address",
87 )
88 }
89 AuthResult::ProxyThrough => {
90 log::info!("Proxying through");
91 //No 2FA or already passed
92 let uri = format!(
93 "{}{}",
94 state.pds_base_url, "/xrpc/com.atproto.server.createSession"
95 );
96
97 let mut req = axum::http::Request::post(uri);
98 if let Some(req_headers) = req.headers_mut() {
99 req_headers.extend(headers.clone());
100 }
101
102 let payload_bytes =
103 serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
104 let req = req
105 .body(Body::from(payload_bytes))
106 .map_err(|_| StatusCode::BAD_REQUEST)?;
107
108 let proxied = state
109 .reverse_proxy_client
110 .request(req)
111 .await
112 .map_err(|_| StatusCode::BAD_REQUEST)?
113 .into_response();
114
115 Ok(proxied)
116 }
117 AuthResult::TokenCheckFailed(err) => match err {
118 TokenCheckError::InvalidToken => {
119 json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "Token is invalid")
120 }
121 TokenCheckError::ExpiredToken => {
122 json_error_response(StatusCode::BAD_REQUEST, "ExpiredToken", "Token is expired")
123 }
124 },
125 },
126 Err(err) => {
127 log::error!(
128 "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}"
129 );
130 json_error_response(
131 StatusCode::INTERNAL_SERVER_ERROR,
132 "InternalServerError",
133 "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.",
134 )
135 }
136 }
137}
138
139#[debug_handler]
140pub async fn update_email(
141 State(state): State<AppState>,
142 Extension(did): Extension<Did>,
143 headers: HeaderMap,
144 Json(payload): extract::Json<UpdateEmailResponse>,
145) -> Result<Response<Body>, StatusCode> {
146 //If email auth is not set at all it is a update email address
147 let email_auth_not_set = payload.email_auth_factor.is_none();
148 //If email auth is set it is to either turn on or off 2fa
149 let email_auth_update = payload.email_auth_factor.unwrap_or(false);
150
151 // Email update asked for
152 if email_auth_update {
153 let email = payload.email.clone();
154 let email_confirmed = sqlx::query_as::<_, (String,)>(
155 "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?",
156 )
157 .bind(&email)
158 .fetch_optional(&state.account_pool)
159 .await
160 .map_err(|_| StatusCode::BAD_REQUEST)?;
161
162 //Since the email is already confirmed we can enable 2fa
163 return match email_confirmed {
164 None => Err(StatusCode::BAD_REQUEST),
165 Some(did_row) => {
166 let _ = sqlx::query(
167 "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1",
168 )
169 .bind(&did_row.0)
170 .execute(&state.pds_gatekeeper_pool)
171 .await
172 .map_err(|_| StatusCode::BAD_REQUEST)?;
173
174 Ok(StatusCode::OK.into_response())
175 }
176 };
177 }
178
179 // User wants auth turned off
180 if !email_auth_update && !email_auth_not_set {
181 //User wants auth turned off and has a token
182 if let Some(token) = &payload.token {
183 let token_found = sqlx::query_as::<_, (String,)>(
184 "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'",
185 )
186 .bind(token)
187 .bind(&did.0)
188 .fetch_optional(&state.account_pool)
189 .await
190 .map_err(|_| StatusCode::BAD_REQUEST)?;
191
192 if token_found.is_some() {
193 let _ = sqlx::query(
194 "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0",
195 )
196 .bind(&did.0)
197 .execute(&state.pds_gatekeeper_pool)
198 .await
199 .map_err(|_| StatusCode::BAD_REQUEST)?;
200
201 return Ok(StatusCode::OK.into_response());
202 } else {
203 return Err(StatusCode::BAD_REQUEST);
204 }
205 }
206 }
207
208 // Updating the actual email address by sending it on to the PDS
209 let uri = format!(
210 "{}{}",
211 state.pds_base_url, "/xrpc/com.atproto.server.updateEmail"
212 );
213 let mut req = axum::http::Request::post(uri);
214 if let Some(req_headers) = req.headers_mut() {
215 req_headers.extend(headers.clone());
216 }
217
218 let payload_bytes = serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
219 let req = req
220 .body(Body::from(payload_bytes))
221 .map_err(|_| StatusCode::BAD_REQUEST)?;
222
223 let proxied = state
224 .reverse_proxy_client
225 .request(req)
226 .await
227 .map_err(|_| StatusCode::BAD_REQUEST)?
228 .into_response();
229
230 Ok(proxied)
231}
232
233pub async fn get_session(
234 State(state): State<AppState>,
235 req: Request,
236) -> Result<Response<Body>, StatusCode> {
237 match proxy_get_json::<GetSessionResponse>(&state, req, "/xrpc/com.atproto.server.getSession")
238 .await?
239 {
240 ProxiedResult::Parsed {
241 value: mut session, ..
242 } => {
243 let did = session.did.clone();
244 let required_opt = sqlx::query_as::<_, (u8,)>(
245 "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1",
246 )
247 .bind(&did)
248 .fetch_optional(&state.pds_gatekeeper_pool)
249 .await
250 .map_err(|_| StatusCode::BAD_REQUEST)?;
251
252 let email_auth_factor = match required_opt {
253 Some(row) => row.0 != 0,
254 None => false,
255 };
256
257 session.email_auth_factor = Some(email_auth_factor);
258 Ok(Json(session).into_response())
259 }
260 ProxiedResult::Passthrough(resp) => Ok(resp),
261 }
262}