Alternative ATProto PDS implementation
at oauth 22 kB view raw
1//! Based on https://github.com/blacksky-algorithms/rsky/blob/main/rsky-pds/src/pipethrough.rs 2//! blacksky-algorithms/rsky is licensed under the Apache License 2.0 3//! 4//! Modified for Axum instead of Rocket 5 6use anyhow::{Result, bail}; 7use axum::extract::{FromRequestParts, State}; 8use rsky_identity::IdResolver; 9use rsky_pds::apis::ApiError; 10use rsky_pds::auth_verifier::{AccessOutput, AccessStandard}; 11use rsky_pds::config::{ServerConfig, ServiceConfig, env_to_cfg}; 12use rsky_pds::pipethrough::{OverrideOpts, ProxyHeader, UrlAndAud}; 13use rsky_pds::xrpc_server::types::{HandlerPipeThrough, InvalidRequestError, XRPCError}; 14use rsky_pds::{APP_USER_AGENT, SharedIdResolver, context}; 15// use lazy_static::lazy_static; 16use reqwest::header::{CONTENT_TYPE, HeaderValue}; 17use reqwest::{Client, Method, RequestBuilder, Response}; 18// use rocket::data::ToByteUnit; 19// use rocket::http::{Method, Status}; 20// use rocket::request::{FromRequest, Outcome, Request}; 21// use rocket::{Data, State}; 22use axum::{ 23 body::Bytes, 24 http::{self, HeaderMap}, 25}; 26use rsky_common::{GetServiceEndpointOpts, get_service_endpoint}; 27use rsky_repo::types::Ids; 28use serde::de::DeserializeOwned; 29use serde_json::Value as JsonValue; 30use std::collections::{BTreeMap, HashSet}; 31use std::str::FromStr; 32use std::sync::Arc; 33use std::time::Duration; 34use ubyte::ToByteUnit as _; 35use url::Url; 36 37use crate::serve::AppState; 38 39// pub struct OverrideOpts { 40// pub aud: Option<String>, 41// pub lxm: Option<String>, 42// } 43 44// pub struct UrlAndAud { 45// pub url: Url, 46// pub aud: String, 47// pub lxm: String, 48// } 49 50// pub struct ProxyHeader { 51// pub did: String, 52// pub service_url: String, 53// } 54 55pub struct ProxyRequest { 56 pub headers: BTreeMap<String, String>, 57 pub query: Option<String>, 58 pub path: String, 59 pub method: Method, 60 pub id_resolver: Arc<tokio::sync::RwLock<rsky_identity::IdResolver>>, 61 pub cfg: ServerConfig, 62} 63impl FromRequestParts<AppState> for ProxyRequest { 64 // type Rejection = ApiError; 65 type Rejection = axum::response::Response; 66 67 async fn from_request_parts( 68 parts: &mut axum::http::request::Parts, 69 state: &AppState, 70 ) -> Result<Self, Self::Rejection> { 71 let headers = parts 72 .headers 73 .iter() 74 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) 75 .collect::<BTreeMap<String, String>>(); 76 let query = parts.uri.query().map(|s| s.to_string()); 77 let path = parts.uri.path().to_string(); 78 let method = parts.method.clone(); 79 let id_resolver = state.id_resolver.clone(); 80 // let cfg = state.cfg.clone(); 81 let cfg = env_to_cfg(); // TODO: use state.cfg.clone(); 82 83 Ok(Self { 84 headers, 85 query, 86 path, 87 method, 88 id_resolver, 89 cfg, 90 }) 91 } 92} 93 94// #[rocket::async_trait] 95// impl<'r> FromRequest<'r> for HandlerPipeThrough { 96// type Error = anyhow::Error; 97 98// #[tracing::instrument(skip_all)] 99// async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> { 100// match AccessStandard::from_request(req).await { 101// Outcome::Success(output) => { 102// let AccessOutput { credentials, .. } = output.access; 103// let requester: Option<String> = match credentials { 104// None => None, 105// Some(credentials) => credentials.did, 106// }; 107// let headers = req.headers().clone().into_iter().fold( 108// BTreeMap::new(), 109// |mut acc: BTreeMap<String, String>, cur| { 110// let _ = acc.insert(cur.name().to_string(), cur.value().to_string()); 111// acc 112// }, 113// ); 114// let proxy_req = ProxyRequest { 115// headers, 116// query: match req.uri().query() { 117// None => None, 118// Some(query) => Some(query.to_string()), 119// }, 120// path: req.uri().path().to_string(), 121// method: req.method(), 122// id_resolver: req.guard::<&State<SharedIdResolver>>().await.unwrap(), 123// cfg: req.guard::<&State<ServerConfig>>().await.unwrap(), 124// }; 125// match pipethrough( 126// &proxy_req, 127// requester, 128// OverrideOpts { 129// aud: None, 130// lxm: None, 131// }, 132// ) 133// .await 134// { 135// Ok(res) => Outcome::Success(res), 136// Err(error) => match error.downcast_ref() { 137// Some(InvalidRequestError::XRPCError(xrpc)) => { 138// if let XRPCError::FailedResponse { 139// status, 140// error, 141// message, 142// headers, 143// } = xrpc 144// { 145// tracing::error!( 146// "@LOG: XRPC ERROR Status:{status}; Message: {message:?}; Error: {error:?}; Headers: {headers:?}" 147// ); 148// } 149// req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string()))); 150// Outcome::Error((Status::BadRequest, error)) 151// } 152// _ => { 153// req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string()))); 154// Outcome::Error((Status::BadRequest, error)) 155// } 156// }, 157// } 158// } 159// Outcome::Error(err) => { 160// req.local_cache(|| Some(ApiError::RuntimeError)); 161// Outcome::Error(( 162// Status::BadRequest, 163// anyhow::Error::new(InvalidRequestError::AuthError(err.1)), 164// )) 165// } 166// _ => panic!("Unexpected outcome during Pipethrough"), 167// } 168// } 169// } 170 171// #[rocket::async_trait] 172// impl<'r> FromRequest<'r> for ProxyRequest<'r> { 173// type Error = anyhow::Error; 174 175// async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> { 176// let headers = req.headers().clone().into_iter().fold( 177// BTreeMap::new(), 178// |mut acc: BTreeMap<String, String>, cur| { 179// let _ = acc.insert(cur.name().to_string(), cur.value().to_string()); 180// acc 181// }, 182// ); 183// Outcome::Success(Self { 184// headers, 185// query: match req.uri().query() { 186// None => None, 187// Some(query) => Some(query.to_string()), 188// }, 189// path: req.uri().path().to_string(), 190// method: req.method(), 191// id_resolver: req.guard::<&State<SharedIdResolver>>().await.unwrap(), 192// cfg: req.guard::<&State<ServerConfig>>().await.unwrap(), 193// }) 194// } 195// } 196 197pub async fn pipethrough( 198 req: &ProxyRequest, 199 requester: Option<String>, 200 override_opts: OverrideOpts, 201) -> Result<HandlerPipeThrough> { 202 let UrlAndAud { 203 url, 204 aud, 205 lxm: nsid, 206 } = format_url_and_aud(req, override_opts.aud).await?; 207 let lxm = override_opts.lxm.unwrap_or(nsid); 208 let headers = format_headers(req, aud, lxm, requester).await?; 209 let req_init = format_req_init(req, url, headers, None)?; 210 let res = make_request(req_init).await?; 211 parse_proxy_res(res).await 212} 213 214pub async fn pipethrough_procedure<T: serde::Serialize>( 215 req: &ProxyRequest, 216 requester: Option<String>, 217 body: Option<T>, 218) -> Result<HandlerPipeThrough> { 219 let UrlAndAud { 220 url, 221 aud, 222 lxm: nsid, 223 } = format_url_and_aud(req, None).await?; 224 let headers = format_headers(req, aud, nsid, requester).await?; 225 let encoded_body: Option<Vec<u8>> = match body { 226 None => None, 227 Some(body) => Some(serde_json::to_string(&body)?.into_bytes()), 228 }; 229 let req_init = format_req_init(req, url, headers, encoded_body)?; 230 let res = make_request(req_init).await?; 231 parse_proxy_res(res).await 232} 233 234#[tracing::instrument(skip_all)] 235pub async fn pipethrough_procedure_post( 236 req: &ProxyRequest, 237 requester: Option<String>, 238 body: Option<Bytes>, 239) -> Result<HandlerPipeThrough, ApiError> { 240 let UrlAndAud { 241 url, 242 aud, 243 lxm: nsid, 244 } = format_url_and_aud(req, None).await?; 245 let headers = format_headers(req, aud, nsid, requester).await?; 246 let encoded_body: Option<JsonValue>; 247 match body { 248 None => encoded_body = None, 249 Some(body) => { 250 // let res = match body.open(50.megabytes()).into_string().await { 251 // Ok(res1) => { 252 // tracing::info!(res1.value); 253 // res1.value 254 // } 255 // Err(error) => { 256 // tracing::error!("{error}"); 257 // return Err(ApiError::RuntimeError); 258 // } 259 // }; 260 let res = String::from_utf8(body.to_vec()).expect("Invalid UTF-8"); 261 262 match serde_json::from_str(res.as_str()) { 263 Ok(res) => { 264 encoded_body = Some(res); 265 } 266 Err(error) => { 267 tracing::error!("{error}"); 268 return Err(ApiError::RuntimeError); 269 } 270 } 271 } 272 }; 273 let req_init = format_req_init_with_value(req, url, headers, encoded_body)?; 274 let res = make_request(req_init).await?; 275 Ok(parse_proxy_res(res).await?) 276} 277 278// Request setup/formatting 279// ------------------- 280 281const REQ_HEADERS_TO_FORWARD: [&str; 4] = [ 282 "accept-language", 283 "content-type", 284 "atproto-accept-labelers", 285 "x-bsky-topics", 286]; 287 288#[tracing::instrument(skip_all)] 289pub async fn format_url_and_aud( 290 req: &ProxyRequest, 291 aud_override: Option<String>, 292) -> Result<UrlAndAud> { 293 let proxy_to = parse_proxy_header(req).await?; 294 let nsid = parse_req_nsid(req); 295 let default_proxy = default_service(req, &nsid).await; 296 let service_url = match proxy_to { 297 Some(ref proxy_to) => { 298 tracing::info!( 299 "@LOG: format_url_and_aud() proxy_to: {:?}", 300 proxy_to.service_url 301 ); 302 Some(proxy_to.service_url.clone()) 303 } 304 None => match default_proxy { 305 Some(ref default_proxy) => Some(default_proxy.url.clone()), 306 None => None, 307 }, 308 }; 309 let aud = match aud_override { 310 Some(_) => aud_override, 311 None => match proxy_to { 312 Some(proxy_to) => Some(proxy_to.did), 313 None => match default_proxy { 314 Some(default_proxy) => Some(default_proxy.did), 315 None => None, 316 }, 317 }, 318 }; 319 match (service_url, aud) { 320 (Some(service_url), Some(aud)) => { 321 let mut url = Url::parse(format!("{0}{1}", service_url, req.path).as_str())?; 322 if let Some(ref params) = req.query { 323 url.set_query(Some(params.as_str())); 324 } 325 if !req.cfg.service.dev_mode && !is_safe_url(url.clone()) { 326 bail!(InvalidRequestError::InvalidServiceUrl(url.to_string())); 327 } 328 Ok(UrlAndAud { 329 url, 330 aud, 331 lxm: nsid, 332 }) 333 } 334 _ => bail!(InvalidRequestError::NoServiceConfigured(req.path.clone())), 335 } 336} 337 338pub async fn format_headers( 339 req: &ProxyRequest, 340 aud: String, 341 lxm: String, 342 requester: Option<String>, 343) -> Result<HeaderMap> { 344 let mut headers: HeaderMap = match requester { 345 Some(requester) => context::service_auth_headers(&requester, &aud, &lxm).await?, 346 None => HeaderMap::new(), 347 }; 348 // forward select headers to upstream services 349 for header in REQ_HEADERS_TO_FORWARD { 350 let val = req.headers.get(header); 351 if let Some(val) = val { 352 headers.insert(header, HeaderValue::from_str(val)?); 353 } 354 } 355 Ok(headers) 356} 357 358pub fn format_req_init( 359 req: &ProxyRequest, 360 url: Url, 361 headers: HeaderMap, 362 body: Option<Vec<u8>>, 363) -> Result<RequestBuilder> { 364 match req.method { 365 Method::GET => { 366 let client = Client::builder() 367 .user_agent(APP_USER_AGENT) 368 .http2_keep_alive_while_idle(true) 369 .http2_keep_alive_timeout(Duration::from_secs(5)) 370 .default_headers(headers) 371 .build()?; 372 Ok(client.get(url)) 373 } 374 Method::HEAD => { 375 let client = Client::builder() 376 .user_agent(APP_USER_AGENT) 377 .http2_keep_alive_while_idle(true) 378 .http2_keep_alive_timeout(Duration::from_secs(5)) 379 .default_headers(headers) 380 .build()?; 381 Ok(client.head(url)) 382 } 383 Method::POST => { 384 let client = Client::builder() 385 .user_agent(APP_USER_AGENT) 386 .http2_keep_alive_while_idle(true) 387 .http2_keep_alive_timeout(Duration::from_secs(5)) 388 .default_headers(headers) 389 .build()?; 390 Ok(client.post(url).body(body.unwrap())) 391 } 392 _ => bail!(InvalidRequestError::MethodNotFound), 393 } 394} 395 396pub fn format_req_init_with_value( 397 req: &ProxyRequest, 398 url: Url, 399 headers: HeaderMap, 400 body: Option<JsonValue>, 401) -> Result<RequestBuilder> { 402 match req.method { 403 Method::GET => { 404 let client = Client::builder() 405 .user_agent(APP_USER_AGENT) 406 .http2_keep_alive_while_idle(true) 407 .http2_keep_alive_timeout(Duration::from_secs(5)) 408 .default_headers(headers) 409 .build()?; 410 Ok(client.get(url)) 411 } 412 Method::HEAD => { 413 let client = Client::builder() 414 .user_agent(APP_USER_AGENT) 415 .http2_keep_alive_while_idle(true) 416 .http2_keep_alive_timeout(Duration::from_secs(5)) 417 .default_headers(headers) 418 .build()?; 419 Ok(client.head(url)) 420 } 421 Method::POST => { 422 let client = Client::builder() 423 .user_agent(APP_USER_AGENT) 424 .http2_keep_alive_while_idle(true) 425 .http2_keep_alive_timeout(Duration::from_secs(5)) 426 .default_headers(headers) 427 .build()?; 428 Ok(client.post(url).json(&body.unwrap())) 429 } 430 _ => bail!(InvalidRequestError::MethodNotFound), 431 } 432} 433 434pub async fn parse_proxy_header(req: &ProxyRequest) -> Result<Option<ProxyHeader>> { 435 let headers = &req.headers; 436 let proxy_to: Option<&String> = headers.get("atproto-proxy"); 437 match proxy_to { 438 None => Ok(None), 439 Some(proxy_to) => { 440 let parts: Vec<&str> = proxy_to.split("#").collect::<Vec<&str>>(); 441 match (parts.get(0), parts.get(1), parts.get(2)) { 442 (Some(did), Some(service_id), None) => { 443 let did = did.to_string(); 444 let mut lock = req.id_resolver.write().await; 445 match lock.did.resolve(did.clone(), None).await? { 446 None => bail!(InvalidRequestError::CannotResolveProxyDid), 447 Some(did_doc) => { 448 match get_service_endpoint( 449 did_doc, 450 GetServiceEndpointOpts { 451 id: format!("#{service_id}"), 452 r#type: None, 453 }, 454 ) { 455 None => bail!(InvalidRequestError::CannotResolveServiceUrl), 456 Some(service_url) => Ok(Some(ProxyHeader { did, service_url })), 457 } 458 } 459 } 460 } 461 (_, None, _) => bail!(InvalidRequestError::NoServiceId), 462 _ => bail!("error parsing atproto-proxy header"), 463 } 464 } 465 } 466} 467 468pub fn parse_req_nsid(req: &ProxyRequest) -> String { 469 let nsid = req.path.as_str().replace("/xrpc/", ""); 470 match nsid.ends_with("/") { 471 false => nsid, 472 true => nsid 473 .trim_end_matches(|c| c == nsid.chars().last().unwrap()) 474 .to_string(), 475 } 476} 477 478// Sending request 479// ------------------- 480#[tracing::instrument(skip_all)] 481pub async fn make_request(req_init: RequestBuilder) -> Result<Response> { 482 let res = req_init.send().await; 483 match res { 484 Err(e) => { 485 tracing::error!("@LOG WARN: pipethrough network error {}", e.to_string()); 486 bail!(InvalidRequestError::XRPCError(XRPCError::UpstreamFailure)) 487 } 488 Ok(res) => match res.error_for_status_ref() { 489 Ok(_) => Ok(res), 490 Err(_) => { 491 let status = res.status().to_string(); 492 let headers = res.headers().clone(); 493 let error_body = res.json::<JsonValue>().await?; 494 bail!(InvalidRequestError::XRPCError(XRPCError::FailedResponse { 495 status, 496 headers, 497 error: match error_body["error"].as_str() { 498 None => None, 499 Some(error_body_error) => Some(error_body_error.to_string()), 500 }, 501 message: match error_body["message"].as_str() { 502 None => None, 503 Some(error_body_message) => Some(error_body_message.to_string()), 504 } 505 })) 506 } 507 }, 508 } 509} 510 511// Response parsing/forwarding 512// ------------------- 513 514const RES_HEADERS_TO_FORWARD: [&str; 4] = [ 515 "content-type", 516 "content-language", 517 "atproto-repo-rev", 518 "atproto-content-labelers", 519]; 520 521pub async fn parse_proxy_res(res: Response) -> Result<HandlerPipeThrough> { 522 let encoding = match res.headers().get(CONTENT_TYPE) { 523 Some(content_type) => content_type.to_str()?, 524 None => "application/json", 525 }; 526 // Release borrow 527 let encoding = encoding.to_string(); 528 let res_headers = RES_HEADERS_TO_FORWARD.into_iter().fold( 529 BTreeMap::new(), 530 |mut acc: BTreeMap<String, String>, cur| { 531 let _ = match res.headers().get(cur) { 532 Some(res_header_val) => acc.insert( 533 cur.to_string(), 534 res_header_val.clone().to_str().unwrap().to_string(), 535 ), 536 None => None, 537 }; 538 acc 539 }, 540 ); 541 let buffer = read_array_buffer_res(res).await?; 542 Ok(HandlerPipeThrough { 543 encoding, 544 buffer, 545 headers: Some(res_headers), 546 }) 547} 548 549// Utils 550// ------------------- 551 552pub async fn default_service(req: &ProxyRequest, nsid: &str) -> Option<ServiceConfig> { 553 let cfg = req.cfg.clone(); 554 match Ids::from_str(nsid) { 555 Ok(Ids::ToolsOzoneTeamAddMember) => cfg.mod_service, 556 Ok(Ids::ToolsOzoneTeamDeleteMember) => cfg.mod_service, 557 Ok(Ids::ToolsOzoneTeamUpdateMember) => cfg.mod_service, 558 Ok(Ids::ToolsOzoneTeamListMembers) => cfg.mod_service, 559 Ok(Ids::ToolsOzoneCommunicationCreateTemplate) => cfg.mod_service, 560 Ok(Ids::ToolsOzoneCommunicationDeleteTemplate) => cfg.mod_service, 561 Ok(Ids::ToolsOzoneCommunicationUpdateTemplate) => cfg.mod_service, 562 Ok(Ids::ToolsOzoneCommunicationListTemplates) => cfg.mod_service, 563 Ok(Ids::ToolsOzoneModerationEmitEvent) => cfg.mod_service, 564 Ok(Ids::ToolsOzoneModerationGetEvent) => cfg.mod_service, 565 Ok(Ids::ToolsOzoneModerationGetRecord) => cfg.mod_service, 566 Ok(Ids::ToolsOzoneModerationGetRepo) => cfg.mod_service, 567 Ok(Ids::ToolsOzoneModerationQueryEvents) => cfg.mod_service, 568 Ok(Ids::ToolsOzoneModerationQueryStatuses) => cfg.mod_service, 569 Ok(Ids::ToolsOzoneModerationSearchRepos) => cfg.mod_service, 570 Ok(Ids::ComAtprotoModerationCreateReport) => cfg.report_service, 571 _ => cfg.bsky_app_view, 572 } 573} 574 575pub fn parse_res<T: DeserializeOwned>(_nsid: String, res: HandlerPipeThrough) -> Result<T> { 576 let buffer = res.buffer; 577 let record = serde_json::from_slice::<T>(buffer.as_slice())?; 578 Ok(record) 579} 580 581#[tracing::instrument(skip_all)] 582pub async fn read_array_buffer_res(res: Response) -> Result<Vec<u8>> { 583 match res.bytes().await { 584 Ok(bytes) => Ok(bytes.to_vec()), 585 Err(err) => { 586 tracing::error!("@LOG WARN: pipethrough network error {}", err.to_string()); 587 bail!("UpstreamFailure") 588 } 589 } 590} 591 592pub fn is_safe_url(url: Url) -> bool { 593 if url.scheme() != "https" { 594 return false; 595 } 596 match url.host_str() { 597 None => false, 598 Some(hostname) if hostname == "localhost" => false, 599 Some(hostname) => { 600 if std::net::IpAddr::from_str(hostname).is_ok() { 601 return false; 602 } 603 true 604 } 605 } 606}