Alternative ATProto PDS implementation
1//! Based on https://github.com/blacksky-algorithms/rsky/blob/main/rsky-pds/src/pipethrough.rs
2//! blacksky-algorithms/rsky is licensed under the Apache License 2.0
3//!
4//! Modified for Axum instead of Rocket
5
6use anyhow::{Result, bail};
7use axum::extract::{FromRequestParts, State};
8use rsky_identity::IdResolver;
9use rsky_pds::apis::ApiError;
10use rsky_pds::auth_verifier::{AccessOutput, AccessStandard};
11use rsky_pds::config::{ServerConfig, ServiceConfig, env_to_cfg};
12use rsky_pds::pipethrough::{OverrideOpts, ProxyHeader, UrlAndAud};
13use rsky_pds::xrpc_server::types::{HandlerPipeThrough, InvalidRequestError, XRPCError};
14use rsky_pds::{APP_USER_AGENT, SharedIdResolver, context};
15// use lazy_static::lazy_static;
16use reqwest::header::{CONTENT_TYPE, HeaderValue};
17use reqwest::{Client, Method, RequestBuilder, Response};
18// use rocket::data::ToByteUnit;
19// use rocket::http::{Method, Status};
20// use rocket::request::{FromRequest, Outcome, Request};
21// use rocket::{Data, State};
22use axum::{
23 body::Bytes,
24 http::{self, HeaderMap},
25};
26use rsky_common::{GetServiceEndpointOpts, get_service_endpoint};
27use rsky_repo::types::Ids;
28use serde::de::DeserializeOwned;
29use serde_json::Value as JsonValue;
30use std::collections::{BTreeMap, HashSet};
31use std::str::FromStr;
32use std::sync::Arc;
33use std::time::Duration;
34use ubyte::ToByteUnit as _;
35use url::Url;
36
37use crate::serve::AppState;
38
39// pub struct OverrideOpts {
40// pub aud: Option<String>,
41// pub lxm: Option<String>,
42// }
43
44// pub struct UrlAndAud {
45// pub url: Url,
46// pub aud: String,
47// pub lxm: String,
48// }
49
50// pub struct ProxyHeader {
51// pub did: String,
52// pub service_url: String,
53// }
54
/// An incoming XRPC request captured so it can be forwarded ("piped through")
/// to an upstream service.
pub struct ProxyRequest {
    /// Incoming request headers, one value per header name.
    pub headers: BTreeMap<String, String>,
    /// Query string of the request URI (without the leading `?`), if any.
    pub query: Option<String>,
    /// Path of the request URI, e.g. `/xrpc/<nsid>`.
    pub path: String,
    /// HTTP method; the request builders in this module only forward GET, HEAD and POST.
    pub method: Method,
    /// Shared identity resolver used to resolve the DID named in `atproto-proxy`.
    pub id_resolver: Arc<tokio::sync::RwLock<rsky_identity::IdResolver>>,
    /// Server configuration (default upstream services, dev mode, ...).
    pub cfg: ServerConfig,
}
63impl FromRequestParts<AppState> for ProxyRequest {
64 // type Rejection = ApiError;
65 type Rejection = axum::response::Response;
66
67 async fn from_request_parts(
68 parts: &mut axum::http::request::Parts,
69 state: &AppState,
70 ) -> Result<Self, Self::Rejection> {
71 let headers = parts
72 .headers
73 .iter()
74 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
75 .collect::<BTreeMap<String, String>>();
76 let query = parts.uri.query().map(|s| s.to_string());
77 let path = parts.uri.path().to_string();
78 let method = parts.method.clone();
79 let id_resolver = state.id_resolver.clone();
80 // let cfg = state.cfg.clone();
81 let cfg = env_to_cfg(); // TODO: use state.cfg.clone();
82
83 Ok(Self {
84 headers,
85 query,
86 path,
87 method,
88 id_resolver,
89 cfg,
90 })
91 }
92}
93
94// #[rocket::async_trait]
95// impl<'r> FromRequest<'r> for HandlerPipeThrough {
96// type Error = anyhow::Error;
97
98// #[tracing::instrument(skip_all)]
99// async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
100// match AccessStandard::from_request(req).await {
101// Outcome::Success(output) => {
102// let AccessOutput { credentials, .. } = output.access;
103// let requester: Option<String> = match credentials {
104// None => None,
105// Some(credentials) => credentials.did,
106// };
107// let headers = req.headers().clone().into_iter().fold(
108// BTreeMap::new(),
109// |mut acc: BTreeMap<String, String>, cur| {
110// let _ = acc.insert(cur.name().to_string(), cur.value().to_string());
111// acc
112// },
113// );
114// let proxy_req = ProxyRequest {
115// headers,
116// query: match req.uri().query() {
117// None => None,
118// Some(query) => Some(query.to_string()),
119// },
120// path: req.uri().path().to_string(),
121// method: req.method(),
122// id_resolver: req.guard::<&State<SharedIdResolver>>().await.unwrap(),
123// cfg: req.guard::<&State<ServerConfig>>().await.unwrap(),
124// };
125// match pipethrough(
126// &proxy_req,
127// requester,
128// OverrideOpts {
129// aud: None,
130// lxm: None,
131// },
132// )
133// .await
134// {
135// Ok(res) => Outcome::Success(res),
136// Err(error) => match error.downcast_ref() {
137// Some(InvalidRequestError::XRPCError(xrpc)) => {
138// if let XRPCError::FailedResponse {
139// status,
140// error,
141// message,
142// headers,
143// } = xrpc
144// {
145// tracing::error!(
146// "@LOG: XRPC ERROR Status:{status}; Message: {message:?}; Error: {error:?}; Headers: {headers:?}"
147// );
148// }
149// req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
150// Outcome::Error((Status::BadRequest, error))
151// }
152// _ => {
153// req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
154// Outcome::Error((Status::BadRequest, error))
155// }
156// },
157// }
158// }
159// Outcome::Error(err) => {
160// req.local_cache(|| Some(ApiError::RuntimeError));
161// Outcome::Error((
162// Status::BadRequest,
163// anyhow::Error::new(InvalidRequestError::AuthError(err.1)),
164// ))
165// }
166// _ => panic!("Unexpected outcome during Pipethrough"),
167// }
168// }
169// }
170
171// #[rocket::async_trait]
172// impl<'r> FromRequest<'r> for ProxyRequest<'r> {
173// type Error = anyhow::Error;
174
175// async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
176// let headers = req.headers().clone().into_iter().fold(
177// BTreeMap::new(),
178// |mut acc: BTreeMap<String, String>, cur| {
179// let _ = acc.insert(cur.name().to_string(), cur.value().to_string());
180// acc
181// },
182// );
183// Outcome::Success(Self {
184// headers,
185// query: match req.uri().query() {
186// None => None,
187// Some(query) => Some(query.to_string()),
188// },
189// path: req.uri().path().to_string(),
190// method: req.method(),
191// id_resolver: req.guard::<&State<SharedIdResolver>>().await.unwrap(),
192// cfg: req.guard::<&State<ServerConfig>>().await.unwrap(),
193// })
194// }
195// }
196
197pub async fn pipethrough(
198 req: &ProxyRequest,
199 requester: Option<String>,
200 override_opts: OverrideOpts,
201) -> Result<HandlerPipeThrough> {
202 let UrlAndAud {
203 url,
204 aud,
205 lxm: nsid,
206 } = format_url_and_aud(req, override_opts.aud).await?;
207 let lxm = override_opts.lxm.unwrap_or(nsid);
208 let headers = format_headers(req, aud, lxm, requester).await?;
209 let req_init = format_req_init(req, url, headers, None)?;
210 let res = make_request(req_init).await?;
211 parse_proxy_res(res).await
212}
213
214pub async fn pipethrough_procedure<T: serde::Serialize>(
215 req: &ProxyRequest,
216 requester: Option<String>,
217 body: Option<T>,
218) -> Result<HandlerPipeThrough> {
219 let UrlAndAud {
220 url,
221 aud,
222 lxm: nsid,
223 } = format_url_and_aud(req, None).await?;
224 let headers = format_headers(req, aud, nsid, requester).await?;
225 let encoded_body: Option<Vec<u8>> = match body {
226 None => None,
227 Some(body) => Some(serde_json::to_string(&body)?.into_bytes()),
228 };
229 let req_init = format_req_init(req, url, headers, encoded_body)?;
230 let res = make_request(req_init).await?;
231 parse_proxy_res(res).await
232}
233
234#[tracing::instrument(skip_all)]
235pub async fn pipethrough_procedure_post(
236 req: &ProxyRequest,
237 requester: Option<String>,
238 body: Option<Bytes>,
239) -> Result<HandlerPipeThrough, ApiError> {
240 let UrlAndAud {
241 url,
242 aud,
243 lxm: nsid,
244 } = format_url_and_aud(req, None).await?;
245 let headers = format_headers(req, aud, nsid, requester).await?;
246 let encoded_body: Option<JsonValue>;
247 match body {
248 None => encoded_body = None,
249 Some(body) => {
250 // let res = match body.open(50.megabytes()).into_string().await {
251 // Ok(res1) => {
252 // tracing::info!(res1.value);
253 // res1.value
254 // }
255 // Err(error) => {
256 // tracing::error!("{error}");
257 // return Err(ApiError::RuntimeError);
258 // }
259 // };
260 let res = String::from_utf8(body.to_vec()).expect("Invalid UTF-8");
261
262 match serde_json::from_str(res.as_str()) {
263 Ok(res) => {
264 encoded_body = Some(res);
265 }
266 Err(error) => {
267 tracing::error!("{error}");
268 return Err(ApiError::RuntimeError);
269 }
270 }
271 }
272 };
273 let req_init = format_req_init_with_value(req, url, headers, encoded_body)?;
274 let res = make_request(req_init).await?;
275 Ok(parse_proxy_res(res).await?)
276}
277
278// Request setup/formatting
279// -------------------
280
/// Incoming request headers that are copied onto the proxied upstream request.
const REQ_HEADERS_TO_FORWARD: [&'static str; 4] =
    ["accept-language", "content-type", "atproto-accept-labelers", "x-bsky-topics"];
287
288#[tracing::instrument(skip_all)]
289pub async fn format_url_and_aud(
290 req: &ProxyRequest,
291 aud_override: Option<String>,
292) -> Result<UrlAndAud> {
293 let proxy_to = parse_proxy_header(req).await?;
294 let nsid = parse_req_nsid(req);
295 let default_proxy = default_service(req, &nsid).await;
296 let service_url = match proxy_to {
297 Some(ref proxy_to) => {
298 tracing::info!(
299 "@LOG: format_url_and_aud() proxy_to: {:?}",
300 proxy_to.service_url
301 );
302 Some(proxy_to.service_url.clone())
303 }
304 None => match default_proxy {
305 Some(ref default_proxy) => Some(default_proxy.url.clone()),
306 None => None,
307 },
308 };
309 let aud = match aud_override {
310 Some(_) => aud_override,
311 None => match proxy_to {
312 Some(proxy_to) => Some(proxy_to.did),
313 None => match default_proxy {
314 Some(default_proxy) => Some(default_proxy.did),
315 None => None,
316 },
317 },
318 };
319 match (service_url, aud) {
320 (Some(service_url), Some(aud)) => {
321 let mut url = Url::parse(format!("{0}{1}", service_url, req.path).as_str())?;
322 if let Some(ref params) = req.query {
323 url.set_query(Some(params.as_str()));
324 }
325 if !req.cfg.service.dev_mode && !is_safe_url(url.clone()) {
326 bail!(InvalidRequestError::InvalidServiceUrl(url.to_string()));
327 }
328 Ok(UrlAndAud {
329 url,
330 aud,
331 lxm: nsid,
332 })
333 }
334 _ => bail!(InvalidRequestError::NoServiceConfigured(req.path.clone())),
335 }
336}
337
338pub async fn format_headers(
339 req: &ProxyRequest,
340 aud: String,
341 lxm: String,
342 requester: Option<String>,
343) -> Result<HeaderMap> {
344 let mut headers: HeaderMap = match requester {
345 Some(requester) => context::service_auth_headers(&requester, &aud, &lxm).await?,
346 None => HeaderMap::new(),
347 };
348 // forward select headers to upstream services
349 for header in REQ_HEADERS_TO_FORWARD {
350 let val = req.headers.get(header);
351 if let Some(val) = val {
352 headers.insert(header, HeaderValue::from_str(val)?);
353 }
354 }
355 Ok(headers)
356}
357
358pub fn format_req_init(
359 req: &ProxyRequest,
360 url: Url,
361 headers: HeaderMap,
362 body: Option<Vec<u8>>,
363) -> Result<RequestBuilder> {
364 match req.method {
365 Method::GET => {
366 let client = Client::builder()
367 .user_agent(APP_USER_AGENT)
368 .http2_keep_alive_while_idle(true)
369 .http2_keep_alive_timeout(Duration::from_secs(5))
370 .default_headers(headers)
371 .build()?;
372 Ok(client.get(url))
373 }
374 Method::HEAD => {
375 let client = Client::builder()
376 .user_agent(APP_USER_AGENT)
377 .http2_keep_alive_while_idle(true)
378 .http2_keep_alive_timeout(Duration::from_secs(5))
379 .default_headers(headers)
380 .build()?;
381 Ok(client.head(url))
382 }
383 Method::POST => {
384 let client = Client::builder()
385 .user_agent(APP_USER_AGENT)
386 .http2_keep_alive_while_idle(true)
387 .http2_keep_alive_timeout(Duration::from_secs(5))
388 .default_headers(headers)
389 .build()?;
390 Ok(client.post(url).body(body.unwrap()))
391 }
392 _ => bail!(InvalidRequestError::MethodNotFound),
393 }
394}
395
396pub fn format_req_init_with_value(
397 req: &ProxyRequest,
398 url: Url,
399 headers: HeaderMap,
400 body: Option<JsonValue>,
401) -> Result<RequestBuilder> {
402 match req.method {
403 Method::GET => {
404 let client = Client::builder()
405 .user_agent(APP_USER_AGENT)
406 .http2_keep_alive_while_idle(true)
407 .http2_keep_alive_timeout(Duration::from_secs(5))
408 .default_headers(headers)
409 .build()?;
410 Ok(client.get(url))
411 }
412 Method::HEAD => {
413 let client = Client::builder()
414 .user_agent(APP_USER_AGENT)
415 .http2_keep_alive_while_idle(true)
416 .http2_keep_alive_timeout(Duration::from_secs(5))
417 .default_headers(headers)
418 .build()?;
419 Ok(client.head(url))
420 }
421 Method::POST => {
422 let client = Client::builder()
423 .user_agent(APP_USER_AGENT)
424 .http2_keep_alive_while_idle(true)
425 .http2_keep_alive_timeout(Duration::from_secs(5))
426 .default_headers(headers)
427 .build()?;
428 Ok(client.post(url).json(&body.unwrap()))
429 }
430 _ => bail!(InvalidRequestError::MethodNotFound),
431 }
432}
433
434pub async fn parse_proxy_header(req: &ProxyRequest) -> Result<Option<ProxyHeader>> {
435 let headers = &req.headers;
436 let proxy_to: Option<&String> = headers.get("atproto-proxy");
437 match proxy_to {
438 None => Ok(None),
439 Some(proxy_to) => {
440 let parts: Vec<&str> = proxy_to.split("#").collect::<Vec<&str>>();
441 match (parts.get(0), parts.get(1), parts.get(2)) {
442 (Some(did), Some(service_id), None) => {
443 let did = did.to_string();
444 let mut lock = req.id_resolver.write().await;
445 match lock.did.resolve(did.clone(), None).await? {
446 None => bail!(InvalidRequestError::CannotResolveProxyDid),
447 Some(did_doc) => {
448 match get_service_endpoint(
449 did_doc,
450 GetServiceEndpointOpts {
451 id: format!("#{service_id}"),
452 r#type: None,
453 },
454 ) {
455 None => bail!(InvalidRequestError::CannotResolveServiceUrl),
456 Some(service_url) => Ok(Some(ProxyHeader { did, service_url })),
457 }
458 }
459 }
460 }
461 (_, None, _) => bail!(InvalidRequestError::NoServiceId),
462 _ => bail!("error parsing atproto-proxy header"),
463 }
464 }
465 }
466}
467
468pub fn parse_req_nsid(req: &ProxyRequest) -> String {
469 let nsid = req.path.as_str().replace("/xrpc/", "");
470 match nsid.ends_with("/") {
471 false => nsid,
472 true => nsid
473 .trim_end_matches(|c| c == nsid.chars().last().unwrap())
474 .to_string(),
475 }
476}
477
478// Sending request
479// -------------------
480#[tracing::instrument(skip_all)]
481pub async fn make_request(req_init: RequestBuilder) -> Result<Response> {
482 let res = req_init.send().await;
483 match res {
484 Err(e) => {
485 tracing::error!("@LOG WARN: pipethrough network error {}", e.to_string());
486 bail!(InvalidRequestError::XRPCError(XRPCError::UpstreamFailure))
487 }
488 Ok(res) => match res.error_for_status_ref() {
489 Ok(_) => Ok(res),
490 Err(_) => {
491 let status = res.status().to_string();
492 let headers = res.headers().clone();
493 let error_body = res.json::<JsonValue>().await?;
494 bail!(InvalidRequestError::XRPCError(XRPCError::FailedResponse {
495 status,
496 headers,
497 error: match error_body["error"].as_str() {
498 None => None,
499 Some(error_body_error) => Some(error_body_error.to_string()),
500 },
501 message: match error_body["message"].as_str() {
502 None => None,
503 Some(error_body_message) => Some(error_body_message.to_string()),
504 }
505 }))
506 }
507 },
508 }
509}
510
511// Response parsing/forwarding
512// -------------------
513
/// Upstream response headers that are copied back to the client.
const RES_HEADERS_TO_FORWARD: [&'static str; 4] =
    ["content-type", "content-language", "atproto-repo-rev", "atproto-content-labelers"];
520
521pub async fn parse_proxy_res(res: Response) -> Result<HandlerPipeThrough> {
522 let encoding = match res.headers().get(CONTENT_TYPE) {
523 Some(content_type) => content_type.to_str()?,
524 None => "application/json",
525 };
526 // Release borrow
527 let encoding = encoding.to_string();
528 let res_headers = RES_HEADERS_TO_FORWARD.into_iter().fold(
529 BTreeMap::new(),
530 |mut acc: BTreeMap<String, String>, cur| {
531 let _ = match res.headers().get(cur) {
532 Some(res_header_val) => acc.insert(
533 cur.to_string(),
534 res_header_val.clone().to_str().unwrap().to_string(),
535 ),
536 None => None,
537 };
538 acc
539 },
540 );
541 let buffer = read_array_buffer_res(res).await?;
542 Ok(HandlerPipeThrough {
543 encoding,
544 buffer,
545 headers: Some(res_headers),
546 })
547}
548
549// Utils
550// -------------------
551
552pub async fn default_service(req: &ProxyRequest, nsid: &str) -> Option<ServiceConfig> {
553 let cfg = req.cfg.clone();
554 match Ids::from_str(nsid) {
555 Ok(Ids::ToolsOzoneTeamAddMember) => cfg.mod_service,
556 Ok(Ids::ToolsOzoneTeamDeleteMember) => cfg.mod_service,
557 Ok(Ids::ToolsOzoneTeamUpdateMember) => cfg.mod_service,
558 Ok(Ids::ToolsOzoneTeamListMembers) => cfg.mod_service,
559 Ok(Ids::ToolsOzoneCommunicationCreateTemplate) => cfg.mod_service,
560 Ok(Ids::ToolsOzoneCommunicationDeleteTemplate) => cfg.mod_service,
561 Ok(Ids::ToolsOzoneCommunicationUpdateTemplate) => cfg.mod_service,
562 Ok(Ids::ToolsOzoneCommunicationListTemplates) => cfg.mod_service,
563 Ok(Ids::ToolsOzoneModerationEmitEvent) => cfg.mod_service,
564 Ok(Ids::ToolsOzoneModerationGetEvent) => cfg.mod_service,
565 Ok(Ids::ToolsOzoneModerationGetRecord) => cfg.mod_service,
566 Ok(Ids::ToolsOzoneModerationGetRepo) => cfg.mod_service,
567 Ok(Ids::ToolsOzoneModerationQueryEvents) => cfg.mod_service,
568 Ok(Ids::ToolsOzoneModerationQueryStatuses) => cfg.mod_service,
569 Ok(Ids::ToolsOzoneModerationSearchRepos) => cfg.mod_service,
570 Ok(Ids::ComAtprotoModerationCreateReport) => cfg.report_service,
571 _ => cfg.bsky_app_view,
572 }
573}
574
575pub fn parse_res<T: DeserializeOwned>(_nsid: String, res: HandlerPipeThrough) -> Result<T> {
576 let buffer = res.buffer;
577 let record = serde_json::from_slice::<T>(buffer.as_slice())?;
578 Ok(record)
579}
580
581#[tracing::instrument(skip_all)]
582pub async fn read_array_buffer_res(res: Response) -> Result<Vec<u8>> {
583 match res.bytes().await {
584 Ok(bytes) => Ok(bytes.to_vec()),
585 Err(err) => {
586 tracing::error!("@LOG WARN: pipethrough network error {}", err.to_string());
587 bail!("UpstreamFailure")
588 }
589 }
590}
591
592pub fn is_safe_url(url: Url) -> bool {
593 if url.scheme() != "https" {
594 return false;
595 }
596 match url.host_str() {
597 None => false,
598 Some(hostname) if hostname == "localhost" => false,
599 Some(hostname) => {
600 if std::net::IpAddr::from_str(hostname).is_ok() {
601 return false;
602 }
603 true
604 }
605 }
606}