// Forked from baileytownsend.dev/pds-gatekeeper:
// a microservice that brings 2FA to self-hosted PDSes.
1use crate::AppState;
2use crate::helpers::{AuthResult, oauth_json_error_response, preauth_check};
3use axum::body::Body;
4use axum::extract::State;
5use axum::http::header::CONTENT_TYPE;
6use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
7use axum::response::{IntoResponse, Response};
8use axum::{Json, extract};
9use serde::{Deserialize, Serialize};
10use tracing::log;
11
12#[derive(Serialize, Deserialize, Clone)]
13pub struct SignInRequest {
14 pub username: String,
15 pub password: String,
16 #[serde(skip_serializing_if = "Option::is_none")]
17 pub remember: Option<bool>,
18 pub locale: String,
19 #[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")]
20 pub email_otp: Option<String>,
21}
22
23pub async fn sign_in(
24 State(state): State<AppState>,
25 headers: HeaderMap,
26 Json(mut payload): extract::Json<SignInRequest>,
27) -> Result<Response<Body>, StatusCode> {
28 let identifier = payload.username.clone();
29 let password = payload.password.clone();
30 let auth_factor_token = payload.email_otp.clone();
31
32 match preauth_check(&state, &identifier, &password, auth_factor_token, true).await {
33 Ok(result) => match result {
34 AuthResult::WrongIdentityOrPassword => oauth_json_error_response(
35 StatusCode::BAD_REQUEST,
36 "invalid_request",
37 "Invalid identifier or password",
38 ),
39 AuthResult::TwoFactorRequired(masked_email) => {
40 let body_str = match serde_json::to_string(&serde_json::json!({
41 "error": "second_authentication_factor_required",
42 "error_description": format!("emailOtp authentication factor required (hint: {})", masked_email),
43 "type": "emailOtp",
44 "hint": masked_email,
45 })) {
46 Ok(s) => s,
47 Err(_) => return Err(StatusCode::BAD_REQUEST),
48 };
49
50 Response::builder()
51 .status(StatusCode::BAD_REQUEST)
52 .header(CONTENT_TYPE, "application/json")
53 .body(Body::from(body_str))
54 .map_err(|_| StatusCode::BAD_REQUEST)
55 }
56 AuthResult::ProxyThrough => {
57 //No 2FA or already passed
58 let uri = format!(
59 "{}{}",
60 state.app_config.pds_base_url, "/@atproto/oauth-provider/~api/sign-in"
61 );
62
63 let mut req = axum::http::Request::post(uri);
64 if let Some(req_headers) = req.headers_mut() {
65 // Copy headers but remove problematic ones. There was an issue with the PDS not parsing the body fully if i forwarded all headers
66 copy_filtered_headers(&headers, req_headers);
67 //Setting the content type to application/json manually
68 req_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
69 }
70
71 //Clears the email_otp because the pds will reject a request with it.
72 payload.email_otp = None;
73 let payload_bytes =
74 serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
75
76 let req = req
77 .body(Body::from(payload_bytes))
78 .map_err(|_| StatusCode::BAD_REQUEST)?;
79
80 let proxied = state
81 .reverse_proxy_client
82 .request(req)
83 .await
84 .map_err(|_| StatusCode::BAD_REQUEST)?
85 .into_response();
86
87 Ok(proxied)
88 }
89 //Ignoring the type of token check failure. Looks like oauth on the entry treads them the same.
90 AuthResult::TokenCheckFailed(_) => oauth_json_error_response(
91 StatusCode::BAD_REQUEST,
92 "invalid_request",
93 "Unable to sign-in due to an unexpected server error",
94 ),
95 },
96 Err(err) => {
97 log::error!(
98 "Error during pre-auth check. This happens on the oauth signin endpoint when trying to decide if the user has access:\n {err}"
99 );
100 oauth_json_error_response(
101 StatusCode::BAD_REQUEST,
102 "pds_gatekeeper_error",
103 "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.",
104 )
105 }
106 }
107}
108
109fn is_disallowed_header(name: &HeaderName) -> bool {
110 // possible problematic headers with proxying
111 matches!(
112 name.as_str(),
113 "connection"
114 | "keep-alive"
115 | "proxy-authenticate"
116 | "proxy-authorization"
117 | "te"
118 | "trailer"
119 | "transfer-encoding"
120 | "upgrade"
121 | "host"
122 | "content-length"
123 | "content-encoding"
124 | "expect"
125 | "accept-encoding"
126 )
127}
128
129fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) {
130 for (name, value) in src.iter() {
131 if is_disallowed_header(name) {
132 continue;
133 }
134 // Only copy valid headers
135 if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) {
136 dst.insert(name.clone(), hv);
137 }
138 }
139}