forked from
baileytownsend.dev/pds-gatekeeper
Microservice to bring 2FA to self-hosted PDSes
1use crate::AppState;
2use crate::helpers::{AuthResult, oauth_json_error_response, preauth_check};
3use axum::body::Body;
4use axum::extract::State;
5use axum::http::header::CONTENT_TYPE;
6use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
7use axum::response::{IntoResponse, Response};
8use axum::{Json, extract};
9use serde::{Deserialize, Serialize};
10use tracing::log;
11
12#[derive(Serialize, Deserialize, Clone)]
13pub struct SignInRequest {
14 pub username: String,
15 pub password: String,
16 pub remember: bool,
17 pub locale: String,
18 #[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")]
19 pub email_otp: Option<String>,
20}
21
22pub async fn sign_in(
23 State(state): State<AppState>,
24 headers: HeaderMap,
25 Json(mut payload): extract::Json<SignInRequest>,
26) -> Result<Response<Body>, StatusCode> {
27 let identifier = payload.username.clone();
28 let password = payload.password.clone();
29 let auth_factor_token = payload.email_otp.clone();
30
31 match preauth_check(&state, &identifier, &password, auth_factor_token, true).await {
32 Ok(result) => match result {
33 AuthResult::WrongIdentityOrPassword => oauth_json_error_response(
34 StatusCode::BAD_REQUEST,
35 "invalid_request",
36 "Invalid identifier or password",
37 ),
38 AuthResult::TwoFactorRequired(masked_email) => {
39 let body_str = match serde_json::to_string(&serde_json::json!({
40 "error": "second_authentication_factor_required",
41 "error_description": format!("emailOtp authentication factor required (hint: {})", masked_email),
42 "type": "emailOtp",
43 "hint": masked_email,
44 })) {
45 Ok(s) => s,
46 Err(_) => return Err(StatusCode::BAD_REQUEST),
47 };
48
49 Response::builder()
50 .status(StatusCode::BAD_REQUEST)
51 .header(CONTENT_TYPE, "application/json")
52 .body(Body::from(body_str))
53 .map_err(|_| StatusCode::BAD_REQUEST)
54 }
55 AuthResult::ProxyThrough => {
56 //No 2FA or already passed
57 let uri = format!(
58 "{}{}",
59 state.pds_base_url, "/@atproto/oauth-provider/~api/sign-in"
60 );
61
62 let mut req = axum::http::Request::post(uri);
63 if let Some(req_headers) = req.headers_mut() {
64 // Copy headers but remove problematic ones. There was an issue with the PDS not parsing the body fully if i forwarded all headers
65 copy_filtered_headers(&headers, req_headers);
66 //Setting the content type to application/json manually
67 req_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
68 }
69
70 //Clears the email_otp because the pds will reject a request with it.
71 payload.email_otp = None;
72 let payload_bytes =
73 serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
74
75 let req = req
76 .body(Body::from(payload_bytes))
77 .map_err(|_| StatusCode::BAD_REQUEST)?;
78
79 let proxied = state
80 .reverse_proxy_client
81 .request(req)
82 .await
83 .map_err(|_| StatusCode::BAD_REQUEST)?
84 .into_response();
85
86 Ok(proxied)
87 }
88 //Ignoring the type of token check failure. Looks like oauth on the entry treads them the same.
89 AuthResult::TokenCheckFailed(_) => oauth_json_error_response(
90 StatusCode::BAD_REQUEST,
91 "invalid_request",
92 "Unable to sign-in due to an unexpected server error",
93 ),
94 },
95 Err(err) => {
96 log::error!(
97 "Error during pre-auth check. This happens on the oauth signin endpoint when trying to decide if the user has access:\n {err}"
98 );
99 oauth_json_error_response(
100 StatusCode::BAD_REQUEST,
101 "pds_gatekeeper_error",
102 "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.",
103 )
104 }
105 }
106}
107
108fn is_disallowed_header(name: &HeaderName) -> bool {
109 // possible problematic headers with proxying
110 matches!(
111 name.as_str(),
112 "connection"
113 | "keep-alive"
114 | "proxy-authenticate"
115 | "proxy-authorization"
116 | "te"
117 | "trailer"
118 | "transfer-encoding"
119 | "upgrade"
120 | "host"
121 | "content-length"
122 | "content-encoding"
123 | "expect"
124 | "accept-encoding"
125 )
126}
127
128fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) {
129 for (name, value) in src.iter() {
130 if is_disallowed_header(name) {
131 continue;
132 }
133 // Only copy valid headers
134 if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) {
135 dst.insert(name.clone(), hv);
136 }
137 }
138}