Microservice that brings two-factor authentication (2FA) to self-hosted atproto PDSes (Personal Data Servers).
1use crate::AppState; 2use crate::helpers::{AuthResult, oauth_json_error_response, preauth_check}; 3use axum::body::Body; 4use axum::extract::State; 5use axum::http::header::CONTENT_TYPE; 6use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode}; 7use axum::response::{IntoResponse, Response}; 8use axum::{Json, extract}; 9use serde::{Deserialize, Serialize}; 10use tracing::log; 11 12#[derive(Serialize, Deserialize, Clone)] 13pub struct SignInRequest { 14 pub username: String, 15 pub password: String, 16 #[serde(skip_serializing_if = "Option::is_none")] 17 pub remember: Option<bool>, 18 pub locale: String, 19 #[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")] 20 pub email_otp: Option<String>, 21} 22 23pub async fn sign_in( 24 State(state): State<AppState>, 25 headers: HeaderMap, 26 Json(mut payload): extract::Json<SignInRequest>, 27) -> Result<Response<Body>, StatusCode> { 28 let identifier = payload.username.clone(); 29 let password = payload.password.clone(); 30 let auth_factor_token = payload.email_otp.clone(); 31 32 match preauth_check(&state, &identifier, &password, auth_factor_token, true).await { 33 Ok(result) => match result { 34 AuthResult::WrongIdentityOrPassword => oauth_json_error_response( 35 StatusCode::BAD_REQUEST, 36 "invalid_request", 37 "Invalid identifier or password", 38 ), 39 AuthResult::TwoFactorRequired(masked_email) => { 40 let body_str = match serde_json::to_string(&serde_json::json!({ 41 "error": "second_authentication_factor_required", 42 "error_description": format!("emailOtp authentication factor required (hint: {})", masked_email), 43 "type": "emailOtp", 44 "hint": masked_email, 45 })) { 46 Ok(s) => s, 47 Err(_) => return Err(StatusCode::BAD_REQUEST), 48 }; 49 50 Response::builder() 51 .status(StatusCode::BAD_REQUEST) 52 .header(CONTENT_TYPE, "application/json") 53 .body(Body::from(body_str)) 54 .map_err(|_| StatusCode::BAD_REQUEST) 55 } 56 AuthResult::ProxyThrough => { 57 //No 2FA or already passed 58 let uri 
= format!( 59 "{}{}", 60 state.app_config.pds_base_url, "/@atproto/oauth-provider/~api/sign-in" 61 ); 62 63 let mut req = axum::http::Request::post(uri); 64 if let Some(req_headers) = req.headers_mut() { 65 // Copy headers but remove problematic ones. There was an issue with the PDS not parsing the body fully if i forwarded all headers 66 copy_filtered_headers(&headers, req_headers); 67 //Setting the content type to application/json manually 68 req_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); 69 } 70 71 //Clears the email_otp because the pds will reject a request with it. 72 payload.email_otp = None; 73 let payload_bytes = 74 serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 75 76 let req = req 77 .body(Body::from(payload_bytes)) 78 .map_err(|_| StatusCode::BAD_REQUEST)?; 79 80 let proxied = state 81 .reverse_proxy_client 82 .request(req) 83 .await 84 .map_err(|_| StatusCode::BAD_REQUEST)? 85 .into_response(); 86 87 Ok(proxied) 88 } 89 //Ignoring the type of token check failure. Looks like oauth on the entry treads them the same. 90 AuthResult::TokenCheckFailed(_) => oauth_json_error_response( 91 StatusCode::BAD_REQUEST, 92 "invalid_request", 93 "Unable to sign-in due to an unexpected server error", 94 ), 95 }, 96 Err(err) => { 97 log::error!( 98 "Error during pre-auth check. This happens on the oauth signin endpoint when trying to decide if the user has access:\n {err}" 99 ); 100 oauth_json_error_response( 101 StatusCode::BAD_REQUEST, 102 "pds_gatekeeper_error", 103 "This error was not generated by the PDS, but PDS Gatekeeper. 
Please contact your PDS administrator for help and for them to review the server logs.", 104 ) 105 } 106 } 107} 108 109fn is_disallowed_header(name: &HeaderName) -> bool { 110 // possible problematic headers with proxying 111 matches!( 112 name.as_str(), 113 "connection" 114 | "keep-alive" 115 | "proxy-authenticate" 116 | "proxy-authorization" 117 | "te" 118 | "trailer" 119 | "transfer-encoding" 120 | "upgrade" 121 | "host" 122 | "content-length" 123 | "content-encoding" 124 | "expect" 125 | "accept-encoding" 126 ) 127} 128 129fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) { 130 for (name, value) in src.iter() { 131 if is_disallowed_header(name) { 132 continue; 133 } 134 // Only copy valid headers 135 if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) { 136 dst.insert(name.clone(), hv); 137 } 138 } 139}