PDS software with bells and whistles you didn't even know you needed. Will move this to its own account when ready.
at main 12 kB view raw
1use std::convert::Infallible; 2 3use crate::api::error::ApiError; 4use crate::api::proxy_client::proxy_client; 5use crate::state::AppState; 6use axum::{ 7 body::Bytes, 8 extract::{RawQuery, Request, State}, 9 handler::Handler, 10 http::{HeaderMap, Method, StatusCode}, 11 response::{IntoResponse, Response}, 12}; 13use futures_util::future::Either; 14use tower::{Service, util::BoxCloneSyncService}; 15use tracing::{error, info, warn}; 16 17const PROTECTED_METHODS: &[&str] = &[ 18 "app.bsky.actor.getPreferences", 19 "app.bsky.actor.putPreferences", 20 "com.atproto.admin.deleteAccount", 21 "com.atproto.admin.disableAccountInvites", 22 "com.atproto.admin.disableInviteCodes", 23 "com.atproto.admin.enableAccountInvites", 24 "com.atproto.admin.getAccountInfo", 25 "com.atproto.admin.getAccountInfos", 26 "com.atproto.admin.getInviteCodes", 27 "com.atproto.admin.getSubjectStatus", 28 "com.atproto.admin.searchAccounts", 29 "com.atproto.admin.sendEmail", 30 "com.atproto.admin.updateAccountEmail", 31 "com.atproto.admin.updateAccountHandle", 32 "com.atproto.admin.updateAccountPassword", 33 "com.atproto.admin.updateSubjectStatus", 34 "com.atproto.identity.getRecommendedDidCredentials", 35 "com.atproto.identity.requestPlcOperationSignature", 36 "com.atproto.identity.signPlcOperation", 37 "com.atproto.identity.submitPlcOperation", 38 "com.atproto.identity.updateHandle", 39 "com.atproto.repo.applyWrites", 40 "com.atproto.repo.createRecord", 41 "com.atproto.repo.deleteRecord", 42 "com.atproto.repo.importRepo", 43 "com.atproto.repo.putRecord", 44 "com.atproto.repo.uploadBlob", 45 "com.atproto.server.activateAccount", 46 "com.atproto.server.checkAccountStatus", 47 "com.atproto.server.confirmEmail", 48 "com.atproto.server.confirmSignup", 49 "com.atproto.server.createAccount", 50 "com.atproto.server.createAppPassword", 51 "com.atproto.server.createInviteCode", 52 "com.atproto.server.createInviteCodes", 53 "com.atproto.server.createSession", 54 "com.atproto.server.createTotpSecret", 55 
"com.atproto.server.deactivateAccount", 56 "com.atproto.server.deleteAccount", 57 "com.atproto.server.deletePasskey", 58 "com.atproto.server.deleteSession", 59 "com.atproto.server.describeServer", 60 "com.atproto.server.disableTotp", 61 "com.atproto.server.enableTotp", 62 "com.atproto.server.finishPasskeyRegistration", 63 "com.atproto.server.getAccountInviteCodes", 64 "com.atproto.server.getServiceAuth", 65 "com.atproto.server.getSession", 66 "com.atproto.server.getTotpStatus", 67 "com.atproto.server.listAppPasswords", 68 "com.atproto.server.listPasskeys", 69 "com.atproto.server.refreshSession", 70 "com.atproto.server.regenerateBackupCodes", 71 "com.atproto.server.requestAccountDelete", 72 "com.atproto.server.requestEmailConfirmation", 73 "com.atproto.server.requestEmailUpdate", 74 "com.atproto.server.requestPasswordReset", 75 "com.atproto.server.resendMigrationVerification", 76 "com.atproto.server.resendVerification", 77 "com.atproto.server.reserveSigningKey", 78 "com.atproto.server.resetPassword", 79 "com.atproto.server.revokeAppPassword", 80 "com.atproto.server.startPasskeyRegistration", 81 "com.atproto.server.updateEmail", 82 "com.atproto.server.updatePasskey", 83 "com.atproto.server.verifyMigrationEmail", 84 "com.atproto.sync.getBlob", 85 "com.atproto.sync.getBlocks", 86 "com.atproto.sync.getCheckout", 87 "com.atproto.sync.getHead", 88 "com.atproto.sync.getLatestCommit", 89 "com.atproto.sync.getRecord", 90 "com.atproto.sync.getRepo", 91 "com.atproto.sync.getRepoStatus", 92 "com.atproto.sync.listBlobs", 93 "com.atproto.sync.listRepos", 94 "com.atproto.sync.notifyOfUpdate", 95 "com.atproto.sync.requestCrawl", 96 "com.atproto.sync.subscribeRepos", 97 "com.atproto.temp.checkSignupQueue", 98 "com.atproto.temp.dereferenceScope", 99]; 100 101fn is_protected_method(method: &str) -> bool { 102 PROTECTED_METHODS.contains(&method) 103} 104 105pub struct XrpcProxyLayer { 106 state: AppState, 107} 108 109impl XrpcProxyLayer { 110 pub fn new(state: AppState) -> Self { 111 
XrpcProxyLayer { state } 112 } 113} 114 115impl<S> tower_layer::Layer<S> for XrpcProxyLayer { 116 type Service = XrpcProxyingService<S>; 117 118 fn layer(&self, inner: S) -> Self::Service { 119 XrpcProxyingService { 120 inner, 121 // TODO(nel): make our own service here instead of boxing a HandlerService 122 handler: BoxCloneSyncService::new(proxy_handler.with_state(self.state.clone())), 123 } 124 } 125} 126 127#[derive(Clone)] 128pub struct XrpcProxyingService<S> { 129 inner: S, 130 handler: BoxCloneSyncService<Request, Response, Infallible>, 131} 132 133impl<S: Service<Request, Response = Response, Error = Infallible>> Service<Request> 134 for XrpcProxyingService<S> 135{ 136 type Response = Response; 137 138 type Error = Infallible; 139 140 type Future = Either< 141 <BoxCloneSyncService<Request, Response, Infallible> as Service<Request>>::Future, 142 S::Future, 143 >; 144 145 fn poll_ready( 146 &mut self, 147 cx: &mut std::task::Context<'_>, 148 ) -> std::task::Poll<Result<(), Self::Error>> { 149 self.inner.poll_ready(cx) 150 } 151 152 fn call(&mut self, req: Request) -> Self::Future { 153 if req 154 .headers() 155 .contains_key(http::HeaderName::from(jacquard::xrpc::Header::AtprotoProxy)) 156 { 157 let path = req.uri().path(); 158 let method = path.trim_start_matches("/"); 159 160 if is_protected_method(method) { 161 return Either::Right(self.inner.call(req)); 162 } 163 164 // If the age assurance override is set and this is an age assurance call then we dont want to proxy even if the client requests it 165 if std::env::var("PDS_AGE_ASSURANCE_OVERRIDE").is_ok() 166 && (path.ends_with("app.bsky.ageassurance.getState") 167 || path.ends_with("app.bsky.unspecced.getAgeAssuranceState")) 168 { 169 return Either::Right(self.inner.call(req)); 170 } 171 172 Either::Left(self.handler.call(req)) 173 } else { 174 Either::Right(self.inner.call(req)) 175 } 176 } 177} 178 179async fn proxy_handler( 180 State(state): State<AppState>, 181 uri: http::Uri, 182 method_verb: Method, 
183 headers: HeaderMap, 184 RawQuery(query): RawQuery, 185 body: Bytes, 186) -> Response { 187 // This layer is nested under /xrpc in an axum router so the extracted uri will look like /<method> and thus we can just strip the / 188 let method = uri.path().trim_start_matches("/"); 189 if is_protected_method(method) { 190 warn!(method = %method, "Attempted to proxy protected method"); 191 return ApiError::InvalidRequest(format!("Cannot proxy protected method: {}", method)) 192 .into_response(); 193 } 194 195 let Some(proxy_header) = headers 196 .get("atproto-proxy") 197 .and_then(|h| h.to_str().ok()) 198 .map(String::from) 199 else { 200 return ApiError::InvalidRequest("Missing required atproto-proxy header".into()) 201 .into_response(); 202 }; 203 204 let did = proxy_header.split('#').next().unwrap_or(&proxy_header); 205 let Some(resolved) = state.did_resolver.resolve_did(did).await else { 206 error!(did = %did, "Could not resolve service DID"); 207 return ApiError::UpstreamFailure.into_response(); 208 }; 209 210 let target_url = match &query { 211 Some(q) => format!("{}/xrpc/{}?{}", resolved.url, method, q), 212 None => format!("{}/xrpc/{}", resolved.url, method), 213 }; 214 info!("Proxying {} request to {}", method_verb, target_url); 215 216 let client = proxy_client(); 217 let mut request_builder = client.request(method_verb.clone(), &target_url); 218 219 let mut auth_header_val = headers.get("Authorization").cloned(); 220 if let Some(extracted) = crate::auth::extract_auth_token_from_header( 221 headers.get("Authorization").and_then(|h| h.to_str().ok()), 222 ) { 223 let token = extracted.token; 224 let dpop_proof = headers.get("DPoP").and_then(|h| h.to_str().ok()); 225 let http_uri = uri.to_string(); 226 227 match crate::auth::validate_token_with_dpop( 228 &state.db, 229 &token, 230 extracted.is_dpop, 231 dpop_proof, 232 method_verb.as_str(), 233 &http_uri, 234 false, 235 false, 236 ) 237 .await 238 { 239 Ok(auth_user) => { 240 if let Err(e) = 
crate::auth::scope_check::check_rpc_scope( 241 auth_user.is_oauth, 242 auth_user.scope.as_deref(), 243 &resolved.did, 244 method, 245 ) { 246 return e; 247 } 248 249 if let Some(key_bytes) = auth_user.key_bytes { 250 match crate::auth::create_service_token( 251 &auth_user.did, 252 &resolved.did, 253 method, 254 &key_bytes, 255 ) { 256 Ok(new_token) => { 257 if let Ok(val) = 258 axum::http::HeaderValue::from_str(&format!("Bearer {}", new_token)) 259 { 260 auth_header_val = Some(val); 261 } 262 } 263 Err(e) => { 264 warn!("Failed to create service token: {:?}", e); 265 } 266 } 267 } 268 } 269 Err(e) => { 270 warn!("Token validation failed: {:?}", e); 271 if matches!(e, crate::auth::TokenValidationError::TokenExpired) { 272 let is_dpop = extracted.is_dpop; 273 let scheme = if is_dpop { "DPoP" } else { "Bearer" }; 274 let www_auth = format!( 275 "{} error=\"invalid_token\", error_description=\"Token has expired\"", 276 scheme 277 ); 278 let mut response = 279 ApiError::ExpiredToken(Some("Token has expired".into())).into_response(); 280 response 281 .headers_mut() 282 .insert("WWW-Authenticate", www_auth.parse().unwrap()); 283 if is_dpop { 284 let nonce = crate::oauth::verify::generate_dpop_nonce(); 285 response 286 .headers_mut() 287 .insert("DPoP-Nonce", nonce.parse().unwrap()); 288 } 289 return response; 290 } 291 } 292 } 293 } 294 295 if let Some(val) = auth_header_val { 296 request_builder = request_builder.header("Authorization", val); 297 } 298 for header_name in crate::api::proxy_client::HEADERS_TO_FORWARD { 299 if let Some(val) = headers.get(*header_name) { 300 request_builder = request_builder.header(*header_name, val); 301 } 302 } 303 if !body.is_empty() { 304 request_builder = request_builder.body(body); 305 } 306 307 match request_builder.send().await { 308 Ok(resp) => { 309 let status = resp.status(); 310 let headers = resp.headers().clone(); 311 let body = match resp.bytes().await { 312 Ok(b) => b, 313 Err(e) => { 314 error!("Error reading proxy response 
body: {:?}", e); 315 return (StatusCode::BAD_GATEWAY, "Error reading upstream response") 316 .into_response(); 317 } 318 }; 319 let mut response_builder = Response::builder().status(status); 320 for header_name in crate::api::proxy_client::RESPONSE_HEADERS_TO_FORWARD { 321 if let Some(val) = headers.get(*header_name) { 322 response_builder = response_builder.header(*header_name, val); 323 } 324 } 325 match response_builder.body(axum::body::Body::from(body)) { 326 Ok(r) => r, 327 Err(e) => { 328 error!("Error building proxy response: {:?}", e); 329 (StatusCode::INTERNAL_SERVER_ERROR, "Internal Server Error").into_response() 330 } 331 } 332 } 333 Err(e) => { 334 error!("Error sending proxy request: {:?}", e); 335 if e.is_timeout() { 336 (StatusCode::GATEWAY_TIMEOUT, "Upstream Timeout").into_response() 337 } else { 338 (StatusCode::BAD_GATEWAY, "Upstream Error").into_response() 339 } 340 } 341 } 342}