Microservice to bring 2FA to self hosted PDSes
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

at feature/2faCodeGeneration 262 lines 9.4 kB view raw
1use crate::AppState; 2use crate::helpers::{ 3 AuthResult, ProxiedResult, TokenCheckError, json_error_response, preauth_check, proxy_get_json, 4}; 5use crate::middleware::Did; 6use axum::body::Body; 7use axum::extract::State; 8use axum::http::{HeaderMap, StatusCode}; 9use axum::response::{IntoResponse, Response}; 10use axum::{Extension, Json, debug_handler, extract, extract::Request}; 11use serde::{Deserialize, Serialize}; 12use serde_json; 13use tracing::log; 14 15#[derive(Serialize, Deserialize, Debug, Clone)] 16#[serde(rename_all = "camelCase")] 17enum AccountStatus { 18 Takendown, 19 Suspended, 20 Deactivated, 21} 22 23#[derive(Serialize, Deserialize, Debug, Clone)] 24#[serde(rename_all = "camelCase")] 25struct GetSessionResponse { 26 handle: String, 27 did: String, 28 #[serde(skip_serializing_if = "Option::is_none")] 29 email: Option<String>, 30 #[serde(skip_serializing_if = "Option::is_none")] 31 email_confirmed: Option<bool>, 32 #[serde(skip_serializing_if = "Option::is_none")] 33 email_auth_factor: Option<bool>, 34 #[serde(skip_serializing_if = "Option::is_none")] 35 did_doc: Option<String>, 36 #[serde(skip_serializing_if = "Option::is_none")] 37 active: Option<bool>, 38 #[serde(skip_serializing_if = "Option::is_none")] 39 status: Option<AccountStatus>, 40} 41 42#[derive(Serialize, Deserialize, Debug, Clone)] 43#[serde(rename_all = "camelCase")] 44pub struct UpdateEmailResponse { 45 email: String, 46 #[serde(skip_serializing_if = "Option::is_none")] 47 email_auth_factor: Option<bool>, 48 #[serde(skip_serializing_if = "Option::is_none")] 49 token: Option<String>, 50} 51 52#[allow(dead_code)] 53#[derive(Deserialize, Serialize)] 54#[serde(rename_all = "camelCase")] 55pub struct CreateSessionRequest { 56 identifier: String, 57 password: String, 58 #[serde(skip_serializing_if = "Option::is_none")] 59 auth_factor_token: Option<String>, 60 #[serde(skip_serializing_if = "Option::is_none")] 61 allow_takendown: Option<bool>, 62} 63 64pub async fn create_session( 65 State(state): State<AppState>, 66 headers: HeaderMap, 67 Json(payload): extract::Json<CreateSessionRequest>, 68) -> Result<Response<Body>, StatusCode> { 69 let identifier = payload.identifier.clone(); 70 let password = payload.password.clone(); 71 let auth_factor_token = payload.auth_factor_token.clone(); 72 73 // Run the shared pre-auth logic to validate and check 2FA requirement 74 match preauth_check(&state, &identifier, &password, auth_factor_token, false).await { 75 Ok(result) => match result { 76 AuthResult::WrongIdentityOrPassword => json_error_response( 77 StatusCode::UNAUTHORIZED, 78 "AuthenticationRequired", 79 "Invalid identifier or password", 80 ), 81 AuthResult::TwoFactorRequired(_) => { 82 // Email sending step can be handled here if needed in the future. 83 json_error_response( 84 StatusCode::UNAUTHORIZED, 85 "AuthFactorTokenRequired", 86 "A sign in code has been sent to your email address", 87 ) 88 } 89 AuthResult::ProxyThrough => { 90 log::info!("Proxying through"); 91 //No 2FA or already passed 92 let uri = format!( 93 "{}{}", 94 state.pds_base_url, "/xrpc/com.atproto.server.createSession" 95 ); 96 97 let mut req = axum::http::Request::post(uri); 98 if let Some(req_headers) = req.headers_mut() { 99 req_headers.extend(headers.clone()); 100 } 101 102 let payload_bytes = 103 serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 104 let req = req 105 .body(Body::from(payload_bytes)) 106 .map_err(|_| StatusCode::BAD_REQUEST)?; 107 108 let proxied = state 109 .reverse_proxy_client 110 .request(req) 111 .await 112 .map_err(|_| StatusCode::BAD_REQUEST)? 113 .into_response(); 114 115 Ok(proxied) 116 } 117 AuthResult::TokenCheckFailed(err) => match err { 118 TokenCheckError::InvalidToken => { 119 json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "Token is invalid") 120 } 121 TokenCheckError::ExpiredToken => { 122 json_error_response(StatusCode::BAD_REQUEST, "ExpiredToken", "Token is expired") 123 } 124 }, 125 }, 126 Err(err) => { 127 log::error!( 128 "Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}" 129 ); 130 json_error_response( 131 StatusCode::INTERNAL_SERVER_ERROR, 132 "InternalServerError", 133 "This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.", 134 ) 135 } 136 } 137} 138 139#[debug_handler] 140pub async fn update_email( 141 State(state): State<AppState>, 142 Extension(did): Extension<Did>, 143 headers: HeaderMap, 144 Json(payload): extract::Json<UpdateEmailResponse>, 145) -> Result<Response<Body>, StatusCode> { 146 //If email auth is not set at all it is a update email address 147 let email_auth_not_set = payload.email_auth_factor.is_none(); 148 //If email auth is set it is to either turn on or off 2fa 149 let email_auth_update = payload.email_auth_factor.unwrap_or(false); 150 151 // Email update asked for 152 if email_auth_update { 153 let email = payload.email.clone(); 154 let email_confirmed = sqlx::query_as::<_, (String,)>( 155 "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?", 156 ) 157 .bind(&email) 158 .fetch_optional(&state.account_pool) 159 .await 160 .map_err(|_| StatusCode::BAD_REQUEST)?; 161 162 //Since the email is already confirmed we can enable 2fa 163 return match email_confirmed { 164 None => Err(StatusCode::BAD_REQUEST), 165 Some(did_row) => { 166 let _ = sqlx::query( 167 "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1", 168 ) 169 .bind(&did_row.0) 170 .execute(&state.pds_gatekeeper_pool) 171 .await 172 .map_err(|_| StatusCode::BAD_REQUEST)?; 173 174 Ok(StatusCode::OK.into_response()) 175 } 176 }; 177 } 178 179 // User wants auth turned off 180 if !email_auth_update && !email_auth_not_set { 181 //User wants auth turned off and has a token 182 if let Some(token) = &payload.token { 183 let token_found = sqlx::query_as::<_, (String,)>( 184 "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'", 185 ) 186 .bind(token) 187 .bind(&did.0) 188 .fetch_optional(&state.account_pool) 189 .await 190 .map_err(|_| StatusCode::BAD_REQUEST)?; 191 192 if token_found.is_some() { 193 let _ = sqlx::query( 194 "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0", 195 ) 196 .bind(&did.0) 197 .execute(&state.pds_gatekeeper_pool) 198 .await 199 .map_err(|_| StatusCode::BAD_REQUEST)?; 200 201 return Ok(StatusCode::OK.into_response()); 202 } else { 203 return Err(StatusCode::BAD_REQUEST); 204 } 205 } 206 } 207 208 // Updating the actual email address by sending it on to the PDS 209 let uri = format!( 210 "{}{}", 211 state.pds_base_url, "/xrpc/com.atproto.server.updateEmail" 212 ); 213 let mut req = axum::http::Request::post(uri); 214 if let Some(req_headers) = req.headers_mut() { 215 req_headers.extend(headers.clone()); 216 } 217 218 let payload_bytes = serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?; 219 let req = req 220 .body(Body::from(payload_bytes)) 221 .map_err(|_| StatusCode::BAD_REQUEST)?; 222 223 let proxied = state 224 .reverse_proxy_client 225 .request(req) 226 .await 227 .map_err(|_| StatusCode::BAD_REQUEST)? 228 .into_response(); 229 230 Ok(proxied) 231} 232 233pub async fn get_session( 234 State(state): State<AppState>, 235 req: Request, 236) -> Result<Response<Body>, StatusCode> { 237 match proxy_get_json::<GetSessionResponse>(&state, req, "/xrpc/com.atproto.server.getSession") 238 .await? 239 { 240 ProxiedResult::Parsed { 241 value: mut session, .. 242 } => { 243 let did = session.did.clone(); 244 let required_opt = sqlx::query_as::<_, (u8,)>( 245 "SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1", 246 ) 247 .bind(&did) 248 .fetch_optional(&state.pds_gatekeeper_pool) 249 .await 250 .map_err(|_| StatusCode::BAD_REQUEST)?; 251 252 let email_auth_factor = match required_opt { 253 Some(row) => row.0 != 0, 254 None => false, 255 }; 256 257 session.email_auth_factor = Some(email_auth_factor); 258 Ok(Json(session).into_response()) 259 } 260 ProxiedResult::Passthrough(resp) => Ok(resp), 261 } 262}