Microservice to bring 2FA to self-hosted PDSes
1use crate::AppState; 2use crate::helpers::{AuthResult, oauth_json_error_response, preauth_check}; 3use axum::body::Body; 4use axum::extract::State; 5use axum::http::header::CONTENT_TYPE; 6use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode}; 7use axum::response::{IntoResponse, Response}; 8use axum::{Json, extract}; 9use serde::{Deserialize, Serialize}; 10use tracing::log; 11 12#[derive(Serialize, Deserialize, Clone)] 13pub struct SignInRequest { 14 pub username: String, 15 pub password: String, 16 pub remember: bool, 17 pub locale: String, 18 #[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")] 19 pub email_otp: Option<String>, 20} 21 22pub async fn sign_in( 23 State(state): State<AppState>, 24 headers: HeaderMap, 25 Json(mut payload): extract::Json<SignInRequest>, 26) -> Result<Response<Body>, StatusCode> { 27 let identifier = payload.username.clone(); 28 let password = payload.password.clone(); 29 let auth_factor_token = payload.email_otp.clone(); 30 31 match preauth_check(&state, &identifier, &password, auth_factor_token, true).await { 32 Ok(result) => match result { 33 AuthResult::WrongIdentityOrPassword => oauth_json_error_response( 34 StatusCode::BAD_REQUEST, 35 "invalid_request", 36 "Invalid identifier or password", 37 ), 38 AuthResult::TwoFactorRequired(masked_email) => { 39 let body_str = match serde_json::to_string(&serde_json::json!({ 40 "error": "second_authentication_factor_required", 41 "error_description": format!("emailOtp authentication factor required (hint: {})", masked_email), 42 "type": "emailOtp", 43 "hint": masked_email, 44 })) { 45 Ok(s) => s, 46 Err(_) => return Err(StatusCode::BAD_REQUEST), 47 }; 48 49 Response::builder() 50 .status(StatusCode::BAD_REQUEST) 51 .header(CONTENT_TYPE, "application/json") 52 .body(Body::from(body_str)) 53 .map_err(|_| StatusCode::BAD_REQUEST) 54 } 55 AuthResult::ProxyThrough => { 56 //No 2FA or already passed 57 let uri = format!( 58 "{}{}", 59 state.pds_base_url, 
"/@atproto/oauth-provider/~api/sign-in" 60 ); 61 62 let mut req = axum::http::Request::post(uri); 63 if let Some(req_headers) = req.headers_mut() { 64 // Copy headers but remove problematic ones. There was an issue with the PDS not parsing the body fully if i forwarded all headers 65 copy_filtered_headers(&headers, req_headers); 66 //Setting the content type to application/json manually 67 req_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); 68 } 69 70 //Clears the email_otp because the pds will reject a request with it. 71 payload.email_otp = None; 72 let payload_bytes = 73 serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 74 75 let req = req 76 .body(Body::from(payload_bytes)) 77 .map_err(|_| StatusCode::BAD_REQUEST)?; 78 79 let proxied = state 80 .reverse_proxy_client 81 .request(req) 82 .await 83 .map_err(|_| StatusCode::BAD_REQUEST)? 84 .into_response(); 85 86 Ok(proxied) 87 } 88 //Ignoring the type of token check failure. Looks like oauth on the entry treads them the same. 89 AuthResult::TokenCheckFailed(_) => oauth_json_error_response( 90 StatusCode::BAD_REQUEST, 91 "invalid_request", 92 "Unable to sign-in due to an unexpected server error", 93 ), 94 }, 95 Err(err) => { 96 log::error!( 97 "Error during pre-auth check. This happens on the oauth signin endpoint when trying to decide if the user has access:\n {err}" 98 ); 99 oauth_json_error_response( 100 StatusCode::BAD_REQUEST, 101 "pds_gatekeeper_error", 102 "This error was not generated by the PDS, but PDS Gatekeeper. 
Please contact your PDS administrator for help and for them to review the server logs.", 103 ) 104 } 105 } 106} 107 108fn is_disallowed_header(name: &HeaderName) -> bool { 109 // possible problematic headers with proxying 110 matches!( 111 name.as_str(), 112 "connection" 113 | "keep-alive" 114 | "proxy-authenticate" 115 | "proxy-authorization" 116 | "te" 117 | "trailer" 118 | "transfer-encoding" 119 | "upgrade" 120 | "host" 121 | "content-length" 122 | "content-encoding" 123 | "expect" 124 | "accept-encoding" 125 ) 126} 127 128fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) { 129 for (name, value) in src.iter() { 130 if is_disallowed_header(name) { 131 continue; 132 } 133 // Only copy valid headers 134 if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) { 135 dst.insert(name.clone(), hv); 136 } 137 } 138}