Constellation, Spacedust, Slingshot, UFOs: atproto crates and services for microcosm

+1076 -390
+1
Cargo.lock
··· 5101 name = "slingshot" 5102 version = "0.1.0" 5103 dependencies = [ 5104 "atrium-api 0.25.4 (git+https://github.com/uniphil/atrium.git?branch=fix%2Fresolve-handle-https-accept-whitespace)", 5105 "atrium-common 0.1.2 (git+https://github.com/uniphil/atrium.git?branch=fix%2Fresolve-handle-https-accept-whitespace)", 5106 "atrium-identity 0.1.5 (git+https://github.com/uniphil/atrium.git?branch=fix%2Fresolve-handle-https-accept-whitespace)",
··· 5101 name = "slingshot" 5102 version = "0.1.0" 5103 dependencies = [ 5104 + "async-stream", 5105 "atrium-api 0.25.4 (git+https://github.com/uniphil/atrium.git?branch=fix%2Fresolve-handle-https-accept-whitespace)", 5106 "atrium-common 0.1.2 (git+https://github.com/uniphil/atrium.git?branch=fix%2Fresolve-handle-https-accept-whitespace)", 5107 "atrium-identity 0.1.5 (git+https://github.com/uniphil/atrium.git?branch=fix%2Fresolve-handle-https-accept-whitespace)",
+1
slingshot/Cargo.toml
··· 4 edition = "2024" 5 6 [dependencies] 7 atrium-api = { git = "https://github.com/uniphil/atrium.git", branch = "fix/resolve-handle-https-accept-whitespace", default-features = false } 8 atrium-common = { git = "https://github.com/uniphil/atrium.git", branch = "fix/resolve-handle-https-accept-whitespace" } 9 atrium-identity = { git = "https://github.com/uniphil/atrium.git", branch = "fix/resolve-handle-https-accept-whitespace" }
··· 4 edition = "2024" 5 6 [dependencies] 7 + async-stream = "0.3.6" 8 atrium-api = { git = "https://github.com/uniphil/atrium.git", branch = "fix/resolve-handle-https-accept-whitespace", default-features = false } 9 atrium-common = { git = "https://github.com/uniphil/atrium.git", branch = "fix/resolve-handle-https-accept-whitespace" } 10 atrium-identity = { git = "https://github.com/uniphil/atrium.git", branch = "fix/resolve-handle-https-accept-whitespace" }
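Note: this compare view doesn't show where slingshot actually uses `async-stream`; the crate exists to build `Stream`s out of async loops without hand-writing a state machine. A minimal, hypothetical sketch of the pattern it enables (assumes `tokio` and `futures-util` are available on the consumer side — the former is already a workspace dependency, the latter is only for this demo):

```rust
// Hypothetical sketch: async-stream's stream! macro turns an async loop into
// an `impl Stream` by yielding values; `.await` is allowed inside the block.
use async_stream::stream;
use futures_util::{StreamExt, pin_mut};

#[tokio::main]
async fn main() {
    let numbers = stream! {
        for i in 0..3u32 {
            yield i;
        }
    };
    pin_mut!(numbers); // streams must be pinned before polling
    while let Some(n) = numbers.next().await {
        println!("got {n}");
    }
}
```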
+8
slingshot/src/error.rs
··· 100 UrlParseError(#[from] url::ParseError), 101 #[error(transparent)] 102 ReqwestError(#[from] reqwest::Error), 103 }
··· 100 UrlParseError(#[from] url::ParseError), 101 #[error(transparent)] 102 ReqwestError(#[from] reqwest::Error), 103 + #[error(transparent)] 104 + InvalidHeader(#[from] reqwest::header::InvalidHeaderValue), 105 + #[error(transparent)] 106 + IdentityError(#[from] IdentityError), 107 + #[error("upstream service could not be resolved")] 108 + ServiceNotFound, 109 + #[error("upstream service was found but no services matched")] 110 + ServiceNotMatched, 111 }
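A note on the new error variants: with `thiserror`, the `#[error(transparent)]` + `#[from]` pairs mean upstream failures (for example, building a `HeaderValue` for a proxied `Authorization` header) convert automatically through `?`. A minimal sketch of the mechanism — not slingshot's actual enum, just the same shape:

```rust
use thiserror::Error;

#[derive(Debug, Error)]
enum ProxyError {
    #[error(transparent)]
    InvalidHeader(#[from] reqwest::header::InvalidHeaderValue),
    #[error("upstream service could not be resolved")]
    ServiceNotFound,
}

fn auth_header(value: &str) -> Result<reqwest::header::HeaderValue, ProxyError> {
    // HeaderValue::try_from yields InvalidHeaderValue (e.g. on control chars);
    // `?` converts it into ProxyError::InvalidHeader via the derived From impl
    Ok(reqwest::header::HeaderValue::try_from(value)?)
}

fn main() {
    assert!(auth_header("Bearer abc").is_ok());
    assert!(auth_header("bad\nvalue").is_err());
}
```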
+208 -16
slingshot/src/identity.rs
··· 17 18 use crate::error::IdentityError; 19 use atrium_api::{ 20 - did_doc::DidDocument, 21 types::string::{Did, Handle}, 22 }; 23 use atrium_common::resolver::Resolver; ··· 41 pub enum IdentityKey { 42 Handle(Handle), 43 Did(Did), 44 } 45 46 impl IdentityKey { ··· 48 let s = match self { 49 IdentityKey::Handle(h) => h.as_str(), 50 IdentityKey::Did(d) => d.as_str(), 51 }; 52 std::mem::size_of::<Self>() + std::mem::size_of_val(s) 53 } ··· 59 #[derive(Debug, Serialize, Deserialize)] 60 enum IdentityData { 61 NotFound, 62 - Did(Did), 63 - Doc(PartialMiniDoc), 64 } 65 66 impl IdentityVal { ··· 71 IdentityData::Did(d) => std::mem::size_of_val(d.as_str()), 72 IdentityData::Doc(d) => { 73 std::mem::size_of_val(d.unverified_handle.as_str()) 74 - + std::mem::size_of_val(d.pds.as_str()) 75 - + std::mem::size_of_val(d.signing_key.as_str()) 76 } 77 }; 78 wrapping + inner ··· 168 } 169 } 170 171 /// multi-producer *single-consumer* queue structures (wrap in arc-mutex plz) 172 /// 173 /// the hashset allows testing for presense of items in the queue. ··· 296 let now = UtcDateTime::now(); 297 let IdentityVal(last_fetch, data) = entry.value(); 298 match data { 299 - IdentityData::Doc(_) => { 300 - log::error!("identity value mixup: got a doc from a handle key (should be a did)"); 301 - Err(IdentityError::IdentityValTypeMixup(handle.to_string())) 302 - } 303 IdentityData::NotFound => { 304 if (now - *last_fetch) >= MIN_NOT_FOUND_TTL { 305 metrics::counter!("identity_handle_refresh_queued", "reason" => "ttl", "found" => "false").increment(1); ··· 313 self.queue_refresh(key).await; 314 } 315 Ok(Some(did.clone())) 316 } 317 } 318 } ··· 362 let now = UtcDateTime::now(); 363 let IdentityVal(last_fetch, data) = entry.value(); 364 match data { 365 - IdentityData::Did(_) => { 366 - log::error!("identity value mixup: got a did from a did key (should be a doc)"); 367 - Err(IdentityError::IdentityValTypeMixup(did.to_string())) 368 - } 369 IdentityData::NotFound => { 370 if (now - *last_fetch) >= MIN_NOT_FOUND_TTL { 371 metrics::counter!("identity_did_refresh_queued", "reason" => "ttl", "found" => "false").increment(1); ··· 373 } 374 Ok(None) 375 } 376 - IdentityData::Doc(mini_did) => { 377 if (now - *last_fetch) >= MIN_TTL { 378 metrics::counter!("identity_did_refresh_queued", "reason" => "ttl", "found" => "true").increment(1); 379 self.queue_refresh(key).await; 380 } 381 - Ok(Some(mini_did.clone())) 382 } 383 } 384 } ··· 519 log::warn!( 520 "refreshed did doc failed: wrong did doc id. dropping refresh." 521 ); 522 continue; 523 } 524 let mini_doc = match did_doc.try_into() { ··· 526 Err(e) => { 527 metrics::counter!("identity_did_refresh", "success" => "false", "reason" => "bad doc").increment(1); 528 log::warn!( 529 - "converting mini doc failed: {e:?}. dropping refresh." 530 ); 531 continue; 532 } 533 }; ··· 554 } 555 556 self.complete_refresh(&task_key).await?; // failures are bugs, so break loop 557 } 558 } 559 }
··· 17 18 use crate::error::IdentityError; 19 use atrium_api::{ 20 + did_doc::{DidDocument, Service as DidDocServic}, 21 types::string::{Did, Handle}, 22 }; 23 use atrium_common::resolver::Resolver; ··· 41 pub enum IdentityKey { 42 Handle(Handle), 43 Did(Did), 44 + ServiceDid(Did), 45 } 46 47 impl IdentityKey { ··· 49 let s = match self { 50 IdentityKey::Handle(h) => h.as_str(), 51 IdentityKey::Did(d) => d.as_str(), 52 + IdentityKey::ServiceDid(d) => d.as_str(), 53 }; 54 std::mem::size_of::<Self>() + std::mem::size_of_val(s) 55 } ··· 61 #[derive(Debug, Serialize, Deserialize)] 62 enum IdentityData { 63 NotFound, 64 + Did(Did), // from handle 65 + Doc(PartialMiniDoc), // from did 66 + ServiceDoc(MiniServiceDoc), // from service did 67 } 68 69 impl IdentityVal { ··· 74 IdentityData::Did(d) => std::mem::size_of_val(d.as_str()), 75 IdentityData::Doc(d) => { 76 std::mem::size_of_val(d.unverified_handle.as_str()) 77 + + std::mem::size_of_val(&d.pds) 78 + + std::mem::size_of_val(&d.signing_key) 79 + } 80 + IdentityData::ServiceDoc(d) => { 81 + let mut s = std::mem::size_of::<MiniServiceDoc>(); 82 + s += std::mem::size_of_val(&d.services); 83 + for sv in &d.services { 84 + s += std::mem::size_of_val(&sv.full_id); 85 + s += std::mem::size_of_val(&sv.r#type); 86 + s += std::mem::size_of_val(&sv.endpoint); 87 + } 88 + s 89 } 90 }; 91 wrapping + inner ··· 181 } 182 } 183 184 + /// Simplified info from service DID docs 185 + #[derive(Debug, Clone, Serialize, Deserialize)] 186 + pub struct MiniServiceDoc { 187 + services: Vec<MiniService>, 188 + } 189 + 190 + impl MiniServiceDoc { 191 + pub fn get(&self, id_fragment: &str, service_type: Option<&str>) -> Option<&MiniService> { 192 + self.services.iter().find(|ms| { 193 + ms.full_id.ends_with(id_fragment) 194 + && service_type.map(|t| t == ms.r#type).unwrap_or(true) 195 + }) 196 + } 197 + } 198 + 199 + /// The corresponding service info 200 + #[derive(Debug, Clone, Serialize, Deserialize)] 201 + pub struct MiniService { 202 + /// The full id 203 + /// 204 + /// for informational purposes only -- services are deduplicated by id fragment 205 + full_id: String, 206 + r#type: String, 207 + /// HTTP endpoint for the actual service 208 + pub endpoint: String, 209 + } 210 + 211 + impl TryFrom<DidDocument> for MiniServiceDoc { 212 + type Error = String; 213 + fn try_from(did_doc: DidDocument) -> Result<Self, Self::Error> { 214 + let mut services = Vec::new(); 215 + let mut seen = HashSet::new(); 216 + 217 + for DidDocServic { 218 + id, 219 + r#type, 220 + service_endpoint, 221 + } in did_doc.service.unwrap_or(vec![]) 222 + { 223 + let Some((_, id_fragment)) = id.rsplit_once('#') else { 224 + continue; 225 + }; 226 + if !seen.insert((id_fragment.to_string(), r#type.clone())) { 227 + continue; 228 + } 229 + services.push(MiniService { 230 + full_id: id, 231 + r#type, 232 + endpoint: service_endpoint, 233 + }); 234 + } 235 + 236 + Ok(Self { services }) 237 + } 238 + } 239 + 240 /// multi-producer *single-consumer* queue structures (wrap in arc-mutex plz) 241 /// 242 /// the hashset allows testing for presense of items in the queue. 
··· 365 let now = UtcDateTime::now(); 366 let IdentityVal(last_fetch, data) = entry.value(); 367 match data { 368 IdentityData::NotFound => { 369 if (now - *last_fetch) >= MIN_NOT_FOUND_TTL { 370 metrics::counter!("identity_handle_refresh_queued", "reason" => "ttl", "found" => "false").increment(1); ··· 378 self.queue_refresh(key).await; 379 } 380 Ok(Some(did.clone())) 381 + } 382 + _ => { 383 + log::error!("identity value mixup: got a doc from a handle key (should be a did)"); 384 + Err(IdentityError::IdentityValTypeMixup(handle.to_string())) 385 } 386 } 387 } ··· 431 let now = UtcDateTime::now(); 432 let IdentityVal(last_fetch, data) = entry.value(); 433 match data { 434 IdentityData::NotFound => { 435 if (now - *last_fetch) >= MIN_NOT_FOUND_TTL { 436 metrics::counter!("identity_did_refresh_queued", "reason" => "ttl", "found" => "false").increment(1); ··· 438 } 439 Ok(None) 440 } 441 + IdentityData::Doc(mini_doc) => { 442 if (now - *last_fetch) >= MIN_TTL { 443 metrics::counter!("identity_did_refresh_queued", "reason" => "ttl", "found" => "true").increment(1); 444 self.queue_refresh(key).await; 445 } 446 + Ok(Some(mini_doc.clone())) 447 + } 448 + _ => { 449 + log::error!("identity value mixup: got a doc from a handle key (should be a did)"); 450 + Err(IdentityError::IdentityValTypeMixup(did.to_string())) 451 + } 452 + } 453 + } 454 + 455 + /// Fetch (and cache) a service mini doc from a did 456 + pub async fn did_to_mini_service_doc( 457 + &self, 458 + did: &Did, 459 + ) -> Result<Option<MiniServiceDoc>, IdentityError> { 460 + let key = IdentityKey::ServiceDid(did.clone()); 461 + metrics::counter!("slingshot_get_service_did_doc").increment(1); 462 + let entry = self 463 + .cache 464 + .get_or_fetch(&key, { 465 + let did = did.clone(); 466 + let resolver = self.did_resolver.clone(); 467 + || async move { 468 + let t0 = Instant::now(); 469 + let (res, success) = match resolver.resolve(&did).await { 470 + Ok(did_doc) if did_doc.id != did.to_string() => ( 471 + // TODO: fix in atrium: should verify id is did 472 + Err(IdentityError::BadDidDoc( 473 + "did doc's id did not match did".to_string(), 474 + )), 475 + "false", 476 + ), 477 + Ok(did_doc) => match did_doc.try_into() { 478 + Ok(mini_service_doc) => ( 479 + Ok(IdentityVal( 480 + UtcDateTime::now(), 481 + IdentityData::ServiceDoc(mini_service_doc), 482 + )), 483 + "true", 484 + ), 485 + Err(e) => (Err(IdentityError::BadDidDoc(e)), "false"), 486 + }, 487 + Err(atrium_identity::Error::NotFound) => ( 488 + Ok(IdentityVal(UtcDateTime::now(), IdentityData::NotFound)), 489 + "false", 490 + ), 491 + Err(other) => (Err(IdentityError::ResolutionFailed(other)), "false"), 492 + }; 493 + metrics::histogram!("slingshot_fetch_service_did_doc", "success" => success) 494 + .record(t0.elapsed()); 495 + res 496 + } 497 + }) 498 + .await?; 499 + 500 + let now = UtcDateTime::now(); 501 + let IdentityVal(last_fetch, data) = entry.value(); 502 + match data { 503 + IdentityData::NotFound => { 504 + if (now - *last_fetch) >= MIN_NOT_FOUND_TTL { 505 + metrics::counter!("identity_service_did_refresh_queued", "reason" => "ttl", "found" => "false").increment(1); 506 + self.queue_refresh(key).await; 507 + } 508 + Ok(None) 509 + } 510 + IdentityData::ServiceDoc(mini_service_doc) => { 511 + if (now - *last_fetch) >= MIN_TTL { 512 + metrics::counter!("identity_service_did_refresh_queued", "reason" => "ttl", "found" => "true").increment(1); 513 + self.queue_refresh(key).await; 514 + } 515 + Ok(Some(mini_service_doc.clone())) 516 + } 517 + _ => { 518 + log::error!( 519 + 
"identity value mixup: got a doc from a different key type (should be a service did)" 520 + ); 521 + Err(IdentityError::IdentityValTypeMixup(did.to_string())) 522 } 523 } 524 } ··· 659 log::warn!( 660 "refreshed did doc failed: wrong did doc id. dropping refresh." 661 ); 662 + self.complete_refresh(&task_key).await?; 663 continue; 664 } 665 let mini_doc = match did_doc.try_into() { ··· 667 Err(e) => { 668 metrics::counter!("identity_did_refresh", "success" => "false", "reason" => "bad doc").increment(1); 669 log::warn!( 670 + "converting mini doc for {did:?} failed: {e:?}. dropping refresh." 671 ); 672 + self.complete_refresh(&task_key).await?; 673 continue; 674 } 675 }; ··· 696 } 697 698 self.complete_refresh(&task_key).await?; // failures are bugs, so break loop 699 + } 700 + IdentityKey::ServiceDid(ref did) => { 701 + log::trace!("refreshing service did doc: {did:?}"); 702 + 703 + match self.did_resolver.resolve(did).await { 704 + Ok(did_doc) => { 705 + // TODO: fix in atrium: should verify id is did 706 + if did_doc.id != did.to_string() { 707 + metrics::counter!("identity_service_did_refresh", "success" => "false", "reason" => "wrong did").increment(1); 708 + log::warn!( 709 + "refreshed did doc failed: wrong did doc id. dropping refresh." 710 + ); 711 + self.complete_refresh(&task_key).await?; 712 + continue; 713 + } 714 + let mini_service_doc = match did_doc.try_into() { 715 + Ok(md) => md, 716 + Err(e) => { 717 + metrics::counter!("identity_service_did_refresh", "success" => "false", "reason" => "bad doc").increment(1); 718 + log::warn!( 719 + "converting mini service doc failed: {e:?}. dropping refresh." 720 + ); 721 + self.complete_refresh(&task_key).await?; 722 + continue; 723 + } 724 + }; 725 + metrics::counter!("identity_service_did_refresh", "success" => "true") 726 + .increment(1); 727 + self.cache.insert( 728 + task_key.clone(), 729 + IdentityVal( 730 + UtcDateTime::now(), 731 + IdentityData::ServiceDoc(mini_service_doc), 732 + ), 733 + ); 734 + } 735 + Err(atrium_identity::Error::NotFound) => { 736 + metrics::counter!("identity_service_did_refresh", "success" => "false", "reason" => "not found").increment(1); 737 + self.cache.insert( 738 + task_key.clone(), 739 + IdentityVal(UtcDateTime::now(), IdentityData::NotFound), 740 + ); 741 + } 742 + Err(err) => { 743 + metrics::counter!("identity_service_did_refresh", "success" => "false", "reason" => "other").increment(1); 744 + log::warn!( 745 + "failed to refresh did doc: {err:?}. leaving stale (should we eventually do something?)" 746 + ); 747 + } 748 + } 749 } 750 } 751 }
+61 -25
slingshot/src/main.rs
··· 1 - // use foyer::HybridCache; 2 - // use foyer::{Engine, DirectFsDeviceOptions, HybridCacheBuilder}; 3 use metrics_exporter_prometheus::PrometheusBuilder; 4 use slingshot::{ 5 Identity, Proxy, Repo, consume, error::MainTaskError, firehose_cache, healthcheck, serve, ··· 9 10 use clap::Parser; 11 use tokio_util::sync::CancellationToken; 12 13 /// Slingshot record edge cache 14 #[derive(Parser, Debug, Clone)] ··· 48 #[arg(long, env = "SLINGSHOT_IDENTITY_CACHE_DISK_DB")] 49 #[clap(default_value_t = 1)] 50 identity_cache_disk_gb: usize, 51 /// the domain pointing to this server 52 /// 53 /// if present: 54 /// - a did:web document will be served at /.well-known/did.json 55 - /// - an HTTPS certs will be automatically configured with Acme/letsencrypt 56 /// - TODO: a rate-limiter will be installed 57 #[arg( 58 long, 59 conflicts_with("bind"), 60 - requires("acme_cache_path"), 61 - env = "SLINGSHOT_ACME_DOMAIN" 62 )] 63 - acme_domain: Option<String>, 64 /// email address for letsencrypt contact 65 /// 66 /// recommended in production, i guess? 67 - #[arg(long, requires("acme_domain"), env = "SLINGSHOT_ACME_CONTACT")] 68 acme_contact: Option<String>, 69 - /// a location to cache acme https certs 70 /// 71 - /// required when (and only used when) --acme-domain is specified. 72 - /// 73 - /// recommended in production, but mind the file permissions. 74 - #[arg(long, requires("acme_domain"), env = "SLINGSHOT_ACME_CACHE_PATH")] 75 - acme_cache_path: Option<PathBuf>, 76 - /// listen for ipv6 when using acme 77 - /// 78 - /// you must also configure the relevant DNS records for this to work 79 - #[arg(long, action, requires("acme_domain"), env = "SLINGSHOT_ACME_IPV6")] 80 - acme_ipv6: bool, 81 /// an web address to send healtcheck pings to every ~51s or so 82 #[arg(long, env = "SLINGSHOT_HEALTHCHECK")] 83 healthcheck: Option<String>, ··· 101 102 let args = Args::parse(); 103 104 if args.collect_metrics { 105 log::trace!("installing metrics server..."); 106 if let Err(e) = install_metrics_server(args.bind_metrics) { ··· 152 log::info!("identity service ready."); 153 154 let repo = Repo::new(identity.clone()); 155 - let proxy = Proxy::new(repo.clone()); 156 157 let identity_for_server = identity.clone(); 158 let server_shutdown = shutdown.clone(); ··· 164 identity_for_server, 165 repo, 166 proxy, 167 - args.acme_domain, 168 args.acme_contact, 169 - args.acme_cache_path, 170 - args.acme_ipv6, 171 server_shutdown, 172 bind, 173 ) ··· 236 ) -> Result<(), metrics_exporter_prometheus::BuildError> { 237 log::info!("installing metrics server..."); 238 PrometheusBuilder::new() 239 - .set_quantiles(&[0.5, 0.9, 0.99, 1.0])? 240 - .set_bucket_duration(std::time::Duration::from_secs(300))? 241 - .set_bucket_count(std::num::NonZero::new(12).unwrap()) // count * duration = 60 mins. stuff doesn't happen that fast here. 242 .set_enable_unit_suffix(false) // this seemed buggy for constellation (sometimes wouldn't engage) 243 .with_http_listener(bind_metrics) 244 .install()?;
··· 1 use metrics_exporter_prometheus::PrometheusBuilder; 2 use slingshot::{ 3 Identity, Proxy, Repo, consume, error::MainTaskError, firehose_cache, healthcheck, serve, ··· 7 8 use clap::Parser; 9 use tokio_util::sync::CancellationToken; 10 + use url::Url; 11 12 /// Slingshot record edge cache 13 #[derive(Parser, Debug, Clone)] ··· 47 #[arg(long, env = "SLINGSHOT_IDENTITY_CACHE_DISK_DB")] 48 #[clap(default_value_t = 1)] 49 identity_cache_disk_gb: usize, 50 + /// the address of this server 51 + /// 52 + /// used if --acme-domain is not set, defaulting to `--bind` 53 + #[arg(long, conflicts_with("tls_domain"), env = "SLINGSHOT_PUBLIC_HOST")] 54 + base_url: Option<Url>, 55 /// the domain pointing to this server 56 /// 57 /// if present: 58 /// - a did:web document will be served at /.well-known/did.json 59 + /// - the server will bind on port 443 60 + /// - if `--acme-contact` is present, the server will bind port 80 for http 61 + /// challenges and attempt to auto-provision certs for `--tls-domain` 62 + /// - if `--acme-contact is absent, the server will load certs from the 63 + /// `--tls-certs` folder, and try to reload them twice daily, guarded by 64 + /// a lock file called `.cert-lock` in the `--tls-certs` folder. 65 /// - TODO: a rate-limiter will be installed 66 #[arg( 67 long, 68 conflicts_with("bind"), 69 + requires("tls_certs"), 70 + env = "SLINGSHOT_TLS_DOMAIN" 71 + )] 72 + tls_domain: Option<String>, 73 + /// a location to find/cache acme or other tls certs 74 + /// 75 + /// recommended in production, mind the file permissions. 76 + #[arg(long, env = "SLINGSHOT_TLS_CERTS_PATH")] 77 + tls_certs: Option<PathBuf>, 78 + /// listen for ipv6 when using acme or other tls 79 + /// 80 + /// you must also configure the relevant DNS records for this to work 81 + #[arg(long, action, requires("tls_domain"), env = "SLINGSHOT_TLS_IPV6")] 82 + tls_ipv6: bool, 83 + /// redirect acme http-01 challenges to this url 84 + /// 85 + /// useful if you're setting up a second instance that synchronizes its 86 + /// certs from a main instance doing acme. 87 + #[arg( 88 + long, 89 + conflicts_with("acme_contact"), 90 + requires("tls_domain"), 91 + env = "SLINGSHOT_ACME_CHALLENGE_REDIRECT" 92 )] 93 + acme_challenge_redirect: Option<String>, 94 /// email address for letsencrypt contact 95 /// 96 /// recommended in production, i guess? 97 + #[arg(long, requires("tls_domain"), env = "SLINGSHOT_ACME_CONTACT")] 98 acme_contact: Option<String>, 99 + /// use the staging environment for letsencrypt 100 /// 101 + /// recommended to initially test out new deployments with this to avoid 102 + /// letsencrypt rate limit problems. 
103 + #[arg(long, action, requires("acme_contact"), env = "SLINGSHOT_ACME_STAGING")] 104 + acme_staging: bool, 105 /// an web address to send healtcheck pings to every ~51s or so 106 #[arg(long, env = "SLINGSHOT_HEALTHCHECK")] 107 healthcheck: Option<String>, ··· 125 126 let args = Args::parse(); 127 128 + let base_url: Url = args 129 + .base_url 130 + .or_else(|| { 131 + args.tls_domain 132 + .as_ref() 133 + .map(|d| Url::parse(&format!("https://{d}")).unwrap()) 134 + }) 135 + .unwrap_or_else(|| Url::parse(&format!("http://{}", args.bind)).unwrap()); 136 + 137 if args.collect_metrics { 138 log::trace!("installing metrics server..."); 139 if let Err(e) = install_metrics_server(args.bind_metrics) { ··· 185 log::info!("identity service ready."); 186 187 let repo = Repo::new(identity.clone()); 188 + let proxy = Proxy::new(identity.clone()); 189 190 let identity_for_server = identity.clone(); 191 let server_shutdown = shutdown.clone(); ··· 197 identity_for_server, 198 repo, 199 proxy, 200 + base_url, 201 + args.tls_domain, 202 + args.tls_certs, 203 + args.tls_ipv6, 204 + args.acme_challenge_redirect, 205 args.acme_contact, 206 + args.acme_staging, 207 server_shutdown, 208 bind, 209 ) ··· 272 ) -> Result<(), metrics_exporter_prometheus::BuildError> { 273 log::info!("installing metrics server..."); 274 PrometheusBuilder::new() 275 + .set_buckets(&[0.001, 0.006, 0.036, 0.216, 1.296, 7.776, 45.656])? 276 + .set_bucket_duration(std::time::Duration::from_secs(15))? 277 + .set_bucket_count(std::num::NonZero::new(4).unwrap()) // count * duration = bucket lifetime 278 .set_enable_unit_suffix(false) // this seemed buggy for constellation (sometimes wouldn't engage) 279 .with_http_listener(bind_metrics) 280 .install()?;
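The replacement histogram buckets look like a geometric series with ratio 6 starting at 1 ms; the last hard-coded value (45.656) is slightly off the exact series term (46.656), which may or may not be intentional. A tiny sketch of how such boundaries could be generated — the diff hard-codes them rather than computing them:

```rust
// Generate geometric histogram bucket boundaries: start, start*ratio, start*ratio^2, ...
fn geometric_buckets(start: f64, ratio: f64, count: usize) -> Vec<f64> {
    (0..count).map(|i| start * ratio.powi(i as i32)).collect()
}

fn main() {
    // ≈ [0.001, 0.006, 0.036, 0.216, 1.296, 7.776, 46.656] seconds (modulo float rounding)
    println!("{:?}", geometric_buckets(0.001, 6.0, 7));
}
```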
+251 -155
slingshot/src/proxy.rs
··· 1 - use serde::Deserialize; 2 - use url::Url; 3 - use std::{collections::HashMap, time::Duration}; 4 - use crate::{Repo, server::HydrationSource, error::ProxyError}; 5 use reqwest::Client; 6 use serde_json::{Map, Value}; 7 8 pub enum ParamValue { 9 String(Vec<String>), ··· 13 pub struct Params(HashMap<String, ParamValue>); 14 15 impl TryFrom<Map<String, Value>> for Params { 16 - type Error = (); // TODO 17 fn try_from(val: Map<String, Value>) -> Result<Self, Self::Error> { 18 let mut out = HashMap::new(); 19 for (k, v) in val { ··· 70 71 #[derive(Clone)] 72 pub struct Proxy { 73 - repo: Repo, 74 client: Client, 75 } 76 77 impl Proxy { 78 - pub fn new(repo: Repo) -> Self { 79 let client = Client::builder() 80 .user_agent(format!( 81 "microcosm slingshot v{} (contact: @bad-example.com)", ··· 85 .timeout(Duration::from_secs(6)) 86 .build() 87 .unwrap(); 88 - Self { repo, client } 89 } 90 91 pub async fn proxy( 92 &self, 93 - xrpc: String, 94 - service: String, 95 params: Option<Map<String, Value>>, 96 ) -> Result<Value, ProxyError> { 97 - 98 - // hackin it to start 99 - 100 - // 1. assume did-web (TODO) and get the did doc 101 - #[derive(Debug, Deserialize)] 102 - struct ServiceDoc { 103 - id: String, 104 - service: Vec<ServiceItem>, 105 - } 106 - #[derive(Debug, Deserialize)] 107 - struct ServiceItem { 108 - id: String, 109 - #[expect(unused)] 110 - r#type: String, 111 - #[serde(rename = "serviceEndpoint")] 112 - service_endpoint: Url, 113 - } 114 - let dw = service.strip_prefix("did:web:").expect("a did web"); 115 - let (dw, service_id) = dw.split_once("#").expect("whatever"); 116 - let mut dw_url = Url::parse(&format!("https://{dw}"))?; 117 - dw_url.set_path("/.well-known/did.json"); 118 - let doc: ServiceDoc = self.client 119 - .get(dw_url) 120 - .send() 121 .await? 122 - .error_for_status()? 123 - .json() 124 - .await?; 125 126 - assert_eq!(doc.id, format!("did:web:{}", dw)); 127 - 128 - let mut upstream = None; 129 - for ServiceItem { id, service_endpoint, .. } in doc.service { 130 - let Some((_, id)) = id.split_once("#") else { continue; }; 131 - if id != service_id { continue; }; 132 - upstream = Some(service_endpoint); 133 - break; 134 - } 135 - 136 - // 2. proxy the request forward 137 - let mut upstream = upstream.expect("to find it"); 138 - upstream.set_path(&format!("/xrpc/{xrpc}")); // TODO: validate nsid 139 140 if let Some(params) = params { 141 let mut query = upstream.query_pairs_mut(); ··· 161 } 162 } 163 164 - // TODO: other headers to proxy 165 - Ok(self.client 166 .get(upstream) 167 .send() 168 - .await? 169 - .error_for_status()? 170 - .json() 171 - .await?) 172 } 173 } 174 ··· 188 while let Some((i, c)) = chars.next() { 189 match c { 190 '[' if in_bracket => return Err(format!("nested opening bracket not allowed, at {i}")), 191 - '[' if key_acc.is_empty() => return Err(format!("missing key before opening bracket, at {i}")), 192 '[' => in_bracket = true, 193 ']' if in_bracket => { 194 in_bracket = false; 195 let key = std::mem::take(&mut key_acc); 196 let r#type = std::mem::take(&mut type_acc); 197 - let t = if r#type.is_empty() { None } else { Some(r#type) }; 198 out.push(PathPart::Vector(key, t)); 199 // peek ahead because we need a dot after array if there's more and i don't want to add more loop state 200 let Some((i, c)) = chars.next() else { 201 break; 202 }; 203 if c != '.' { 204 - return Err(format!("expected dot after close bracket, found {c:?} at {i}")); 205 } 206 } 207 ']' => return Err(format!("unexpected close bracket at {i}")), 208 '.' 
if in_bracket => type_acc.push(c), 209 - '.' if key_acc.is_empty() => return Err(format!("missing key before next segment, at {i}")), 210 '.' => { 211 let key = std::mem::take(&mut key_acc); 212 assert!(type_acc.is_empty()); ··· 225 Ok(out) 226 } 227 228 - #[derive(Debug, Clone, PartialEq)] 229 pub enum RefShape { 230 StrongRef, 231 AtUri, ··· 233 Did, 234 Handle, 235 AtIdentifier, 236 - Blob, 237 - // TODO: blob with type? 238 } 239 240 impl TryFrom<&str> for RefShape { ··· 247 "did" => Ok(Self::Did), 248 "handle" => Ok(Self::Handle), 249 "at-identifier" => Ok(Self::AtIdentifier), 250 - "blob" => Ok(Self::Blob), 251 _ => Err(format!("unknown shape: {s}")), 252 } 253 } 254 } 255 256 #[derive(Debug, PartialEq)] 257 pub enum MatchedRef { 258 - AtUri { 259 - uri: String, 260 - cid: Option<String>, 261 - }, 262 - Identifier(String), 263 - Blob { 264 - link: String, 265 - mime: String, 266 - size: u64, 267 - } 268 } 269 270 - pub fn match_shape(shape: &RefShape, val: &Value) -> Option<MatchedRef> { 271 // TODO: actually validate at-uri format 272 // TODO: actually validate everything else also 273 // TODO: should this function normalize identifiers to DIDs probably? ··· 276 RefShape::StrongRef => { 277 let o = val.as_object()?; 278 let uri = o.get("uri")?.as_str()?.to_string(); 279 - let cid = o.get("cid")?.as_str()?.to_string(); 280 - Some(MatchedRef::AtUri { uri, cid: Some(cid) }) 281 } 282 RefShape::AtUri => { 283 let uri = val.as_str()?.to_string(); 284 - Some(MatchedRef::AtUri { uri, cid: None }) 285 } 286 RefShape::AtUriParts => { 287 let o = val.as_object()?; 288 - let identifier = o.get("repo").or(o.get("did"))?.as_str()?.to_string(); 289 - let collection = o.get("collection")?.as_str()?.to_string(); 290 - let rkey = o.get("rkey")?.as_str()?.to_string(); 291 - let uri = format!("at://{identifier}/{collection}/{rkey}"); 292 - let cid = o.get("cid").and_then(|v| v.as_str()).map(str::to_string); 293 - Some(MatchedRef::AtUri { uri, cid }) 294 } 295 RefShape::Did => { 296 - let id = val.as_str()?; 297 - if !id.starts_with("did:") { 298 - return None; 299 - } 300 - Some(MatchedRef::Identifier(id.to_string())) 301 } 302 RefShape::Handle => { 303 - let id = val.as_str()?; 304 - if id.contains(':') { 305 - return None; 306 - } 307 - Some(MatchedRef::Identifier(id.to_string())) 308 } 309 RefShape::AtIdentifier => { 310 - Some(MatchedRef::Identifier(val.as_str()?.to_string())) 311 - } 312 - RefShape::Blob => { 313 - let o = val.as_object()?; 314 - if o.get("$type")? 
!= "blob" { 315 - return None; 316 - } 317 - let link = o.get("ref")?.as_object()?.get("$link")?.as_str()?.to_string(); 318 - let mime = o.get("mimeType")?.as_str()?.to_string(); 319 - let size = o.get("size")?.as_u64()?; 320 - Some(MatchedRef::Blob { link, mime, size }) 321 } 322 } 323 } ··· 343 let mut out = Vec::new(); 344 for (path_parts, shape) in sources { 345 for val in PathWalker::new(&path_parts, skeleton) { 346 - if let Some(matched) = match_shape(&shape, val) { 347 out.push(matched); 348 } 349 } ··· 357 } 358 impl<'a> PathWalker<'a> { 359 fn new(path_parts: &'a [PathPart], skeleton: &'a Value) -> Self { 360 - Self { todo: vec![(path_parts, skeleton)] } 361 } 362 } 363 impl<'a> Iterator for PathWalker<'a> { ··· 382 let Some(a) = o.get(k).and_then(|v| v.as_array()) else { 383 continue; 384 }; 385 - for v in a 386 - .iter() 387 - .rev() 388 - .filter(|c| { 389 - let Some(t) = t else { return true }; 390 - c 391 - .as_object() 392 - .and_then(|o| o.get("$type")) 393 - .and_then(|v| v.as_str()) 394 - .map(|s| s == t) 395 - .unwrap_or(false) 396 - }) 397 - { 398 self.todo.push((rest, v)) 399 } 400 } ··· 403 } 404 } 405 406 - 407 #[cfg(test)] 408 mod tests { 409 use super::*; 410 use serde_json::json; 411 412 #[test] 413 fn test_parse_record_path() -> Result<(), Box<dyn std::error::Error>> { 414 let cases = [ 415 ("", vec![]), 416 ("subject", vec![PathPart::Scalar("subject".into())]), 417 ("authorDid", vec![PathPart::Scalar("authorDid".into())]), 418 - ("subject.uri", vec![PathPart::Scalar("subject".into()), PathPart::Scalar("uri".into())]), 419 ("members[]", vec![PathPart::Vector("members".into(), None)]), 420 - ("add[].key", vec![ 421 - PathPart::Vector("add".into(), None), 422 - PathPart::Scalar("key".into()), 423 - ]), 424 ("a[b]", vec![PathPart::Vector("a".into(), Some("b".into()))]), 425 - ("a[b.c]", vec![PathPart::Vector("a".into(), Some("b.c".into()))]), 426 - ("facets[app.bsky.richtext.facet].features[app.bsky.richtext.facet#mention].did", vec![ 427 - PathPart::Vector("facets".into(), Some("app.bsky.richtext.facet".into())), 428 - PathPart::Vector("features".into(), Some("app.bsky.richtext.facet#mention".into())), 429 - PathPart::Scalar("did".into()), 430 - ]), 431 ]; 432 433 for (path, expected) in cases { ··· 444 ("strong-ref", json!(""), None), 445 ("strong-ref", json!({}), None), 446 ("strong-ref", json!({ "uri": "abc" }), None), 447 - ("strong-ref", json!({ "cid": "def" }), None), 448 ( 449 "strong-ref", 450 - json!({ "uri": "abc", "cid": "def" }), 451 - Some(MatchedRef::AtUri { uri: "abc".to_string(), cid: Some("def".to_string()) }), 452 ), 453 ("at-uri", json!({ "uri": "abc" }), None), 454 - ("at-uri", json!({ "uri": "abc", "cid": "def" }), None), 455 ( 456 "at-uri", 457 - json!("abc"), 458 - Some(MatchedRef::AtUri { uri: "abc".to_string(), cid: None }), 459 ), 460 ("at-uri-parts", json!("abc"), None), 461 ("at-uri-parts", json!({}), None), 462 ( 463 "at-uri-parts", 464 - json!({"repo": "a", "collection": "b", "rkey": "c"}), 465 - Some(MatchedRef::AtUri { uri: "at://a/b/c".to_string(), cid: None }), 466 ), 467 ( 468 "at-uri-parts", 469 - json!({"did": "a", "collection": "b", "rkey": "c"}), 470 - Some(MatchedRef::AtUri { uri: "at://a/b/c".to_string(), cid: None }), 471 ), 472 ( 473 "at-uri-parts", 474 // 'repo' takes precedence over 'did' 475 - json!({"did": "a", "repo": "z", "collection": "b", "rkey": "c"}), 476 - Some(MatchedRef::AtUri { uri: "at://z/b/c".to_string(), cid: None }), 477 ), 478 ( 479 "at-uri-parts", 480 - json!({"repo": "a", "collection": "b", "rkey": 
"c", "cid": "def"}), 481 - Some(MatchedRef::AtUri { uri: "at://a/b/c".to_string(), cid: Some("def".to_string()) }), 482 ), 483 ( 484 "at-uri-parts", 485 - json!({"repo": "a", "collection": "b", "rkey": "c", "cid": {}}), 486 - Some(MatchedRef::AtUri { uri: "at://a/b/c".to_string(), cid: None }), 487 ), 488 ("did", json!({}), None), 489 ("did", json!(""), None), 490 ("did", json!("bad-example.com"), None), 491 - ("did", json!("did:plc:xyz"), Some(MatchedRef::Identifier("did:plc:xyz".to_string()))), 492 ("handle", json!({}), None), 493 - ("handle", json!("bad-example.com"), Some(MatchedRef::Identifier("bad-example.com".to_string()))), 494 ("handle", json!("did:plc:xyz"), None), 495 ("at-identifier", json!({}), None), 496 - ("at-identifier", json!("bad-example.com"), Some(MatchedRef::Identifier("bad-example.com".to_string()))), 497 - ("at-identifier", json!("did:plc:xyz"), Some(MatchedRef::Identifier("did:plc:xyz".to_string()))), 498 ]; 499 - for (shape, val, expected) in cases { 500 let s = shape.try_into().unwrap(); 501 - let matched = match_shape(&s, &val); 502 - assert_eq!(matched, expected, "shape: {shape:?}, val: {val:?}"); 503 } 504 } 505 }
··· 1 + use crate::{Identity, error::ProxyError, server::HydrationSource}; 2 + use atrium_api::types::string::{AtIdentifier, Cid, Did, Nsid, RecordKey}; 3 use reqwest::Client; 4 use serde_json::{Map, Value}; 5 + use std::{collections::HashMap, time::Duration}; 6 + use url::Url; 7 8 pub enum ParamValue { 9 String(Vec<String>), ··· 13 pub struct Params(HashMap<String, ParamValue>); 14 15 impl TryFrom<Map<String, Value>> for Params { 16 + type Error = (); // TODO 17 fn try_from(val: Map<String, Value>) -> Result<Self, Self::Error> { 18 let mut out = HashMap::new(); 19 for (k, v) in val { ··· 70 71 #[derive(Clone)] 72 pub struct Proxy { 73 + identity: Identity, 74 client: Client, 75 } 76 77 impl Proxy { 78 + pub fn new(identity: Identity) -> Self { 79 let client = Client::builder() 80 .user_agent(format!( 81 "microcosm slingshot v{} (contact: @bad-example.com)", ··· 85 .timeout(Duration::from_secs(6)) 86 .build() 87 .unwrap(); 88 + Self { client, identity } 89 } 90 91 pub async fn proxy( 92 &self, 93 + service_did: &Did, 94 + service_id: &str, 95 + xrpc: &Nsid, 96 + authorization: Option<&str>, 97 + atproto_accept_labelers: Option<&str>, 98 params: Option<Map<String, Value>>, 99 ) -> Result<Value, ProxyError> { 100 + let mut upstream: Url = self 101 + .identity 102 + .did_to_mini_service_doc(service_did) 103 .await? 104 + .ok_or(ProxyError::ServiceNotFound)? 105 + .get(service_id, None) 106 + .ok_or(ProxyError::ServiceNotMatched)? 107 + .endpoint 108 + .parse()?; 109 110 + upstream.set_path(&format!("/xrpc/{}", xrpc.as_str())); 111 112 if let Some(params) = params { 113 let mut query = upstream.query_pairs_mut(); ··· 133 } 134 } 135 136 + // TODO i mean maybe we should look for headers also in our headers but not obviously 137 + let mut headers = reqwest::header::HeaderMap::new(); 138 + // TODO: check the jwt aud against the upstream!!! 139 + if let Some(auth) = authorization { 140 + headers.insert("Authorization", auth.try_into()?); 141 + } 142 + if let Some(aal) = atproto_accept_labelers { 143 + headers.insert("atproto-accept-labelers", aal.try_into()?); 144 + } 145 + 146 + let t0 = std::time::Instant::now(); 147 + let res = self 148 + .client 149 .get(upstream) 150 + .headers(headers) 151 .send() 152 + .await 153 + .and_then(|r| r.error_for_status()); 154 + 155 + if res.is_ok() { 156 + metrics::histogram!("slingshot_proxy_upstream_request", "success" => "true") 157 + .record(t0.elapsed()); 158 + } else { 159 + metrics::histogram!("slingshot_proxy_upstream_request", "success" => "false") 160 + .record(t0.elapsed()); 161 + } 162 + 163 + Ok(res?.json().await?) 164 } 165 } 166 ··· 180 while let Some((i, c)) = chars.next() { 181 match c { 182 '[' if in_bracket => return Err(format!("nested opening bracket not allowed, at {i}")), 183 + '[' if key_acc.is_empty() => { 184 + return Err(format!("missing key before opening bracket, at {i}")); 185 + } 186 '[' => in_bracket = true, 187 ']' if in_bracket => { 188 in_bracket = false; 189 let key = std::mem::take(&mut key_acc); 190 let r#type = std::mem::take(&mut type_acc); 191 + let t = if r#type.is_empty() { 192 + None 193 + } else { 194 + Some(r#type) 195 + }; 196 out.push(PathPart::Vector(key, t)); 197 // peek ahead because we need a dot after array if there's more and i don't want to add more loop state 198 let Some((i, c)) = chars.next() else { 199 break; 200 }; 201 if c != '.' 
{ 202 + return Err(format!( 203 + "expected dot after close bracket, found {c:?} at {i}" 204 + )); 205 } 206 } 207 ']' => return Err(format!("unexpected close bracket at {i}")), 208 '.' if in_bracket => type_acc.push(c), 209 + '.' if key_acc.is_empty() => { 210 + return Err(format!("missing key before next segment, at {i}")); 211 + } 212 '.' => { 213 let key = std::mem::take(&mut key_acc); 214 assert!(type_acc.is_empty()); ··· 227 Ok(out) 228 } 229 230 + #[derive(Debug, Clone, Copy, PartialEq)] 231 pub enum RefShape { 232 StrongRef, 233 AtUri, ··· 235 Did, 236 Handle, 237 AtIdentifier, 238 } 239 240 impl TryFrom<&str> for RefShape { ··· 247 "did" => Ok(Self::Did), 248 "handle" => Ok(Self::Handle), 249 "at-identifier" => Ok(Self::AtIdentifier), 250 _ => Err(format!("unknown shape: {s}")), 251 } 252 } 253 } 254 255 #[derive(Debug, PartialEq)] 256 + pub struct FullAtUriParts { 257 + pub repo: AtIdentifier, 258 + pub collection: Nsid, 259 + pub rkey: RecordKey, 260 + pub cid: Option<Cid>, 261 + } 262 + 263 + impl FullAtUriParts { 264 + pub fn to_uri(&self) -> String { 265 + let repo: String = self.repo.clone().into(); // no as_str for AtIdentifier atrium??? 266 + let collection = self.collection.as_str(); 267 + let rkey = self.rkey.as_str(); 268 + format!("at://{repo}/{collection}/{rkey}") 269 + } 270 + } 271 + 272 + // TODO: move this to links 273 + pub fn split_uri(uri: &str) -> Option<(AtIdentifier, Nsid, RecordKey)> { 274 + let rest = uri.strip_prefix("at://")?; 275 + let (repo, rest) = rest.split_once("/")?; 276 + let repo = repo.parse().ok()?; 277 + let (collection, rkey) = rest.split_once("/")?; 278 + let collection = collection.parse().ok()?; 279 + let rkey = rkey.split_once('#').map(|(k, _)| k).unwrap_or(rkey); 280 + let rkey = rkey.split_once('?').map(|(k, _)| k).unwrap_or(rkey); 281 + let rkey = rkey.parse().ok()?; 282 + Some((repo, collection, rkey)) 283 + } 284 + 285 + #[derive(Debug, PartialEq)] 286 pub enum MatchedRef { 287 + AtUri(FullAtUriParts), 288 + Identifier(AtIdentifier), 289 } 290 291 + pub fn match_shape(shape: RefShape, val: &Value) -> Option<MatchedRef> { 292 // TODO: actually validate at-uri format 293 // TODO: actually validate everything else also 294 // TODO: should this function normalize identifiers to DIDs probably? 
··· 297 RefShape::StrongRef => { 298 let o = val.as_object()?; 299 let uri = o.get("uri")?.as_str()?.to_string(); 300 + let cid = o.get("cid")?.as_str()?.parse().ok()?; 301 + let (repo, collection, rkey) = split_uri(&uri)?; 302 + Some(MatchedRef::AtUri(FullAtUriParts { 303 + repo, 304 + collection, 305 + rkey, 306 + cid: Some(cid), 307 + })) 308 } 309 RefShape::AtUri => { 310 let uri = val.as_str()?.to_string(); 311 + let (repo, collection, rkey) = split_uri(&uri)?; 312 + Some(MatchedRef::AtUri(FullAtUriParts { 313 + repo, 314 + collection, 315 + rkey, 316 + cid: None, 317 + })) 318 } 319 RefShape::AtUriParts => { 320 let o = val.as_object()?; 321 + let repo = o.get("repo").or(o.get("did"))?.as_str()?.parse().ok()?; 322 + let collection = o.get("collection")?.as_str()?.parse().ok()?; 323 + let rkey = o.get("rkey")?.as_str()?.parse().ok()?; 324 + let cid = o 325 + .get("cid") 326 + .and_then(|v| v.as_str()) 327 + .and_then(|s| s.parse().ok()); 328 + Some(MatchedRef::AtUri(FullAtUriParts { 329 + repo, 330 + collection, 331 + rkey, 332 + cid, 333 + })) 334 } 335 RefShape::Did => { 336 + let did = val.as_str()?.parse().ok()?; 337 + Some(MatchedRef::Identifier(AtIdentifier::Did(did))) 338 } 339 RefShape::Handle => { 340 + let handle = val.as_str()?.parse().ok()?; 341 + Some(MatchedRef::Identifier(AtIdentifier::Handle(handle))) 342 } 343 RefShape::AtIdentifier => { 344 + let identifier = val.as_str()?.parse().ok()?; 345 + Some(MatchedRef::Identifier(identifier)) 346 } 347 } 348 } ··· 368 let mut out = Vec::new(); 369 for (path_parts, shape) in sources { 370 for val in PathWalker::new(&path_parts, skeleton) { 371 + if let Some(matched) = match_shape(shape, val) { 372 out.push(matched); 373 } 374 } ··· 382 } 383 impl<'a> PathWalker<'a> { 384 fn new(path_parts: &'a [PathPart], skeleton: &'a Value) -> Self { 385 + Self { 386 + todo: vec![(path_parts, skeleton)], 387 + } 388 } 389 } 390 impl<'a> Iterator for PathWalker<'a> { ··· 409 let Some(a) = o.get(k).and_then(|v| v.as_array()) else { 410 continue; 411 }; 412 + for v in a.iter().rev().filter(|c| { 413 + let Some(t) = t else { return true }; 414 + c.as_object() 415 + .and_then(|o| o.get("$type")) 416 + .and_then(|v| v.as_str()) 417 + .map(|s| s == t) 418 + .unwrap_or(false) 419 + }) { 420 self.todo.push((rest, v)) 421 } 422 } ··· 425 } 426 } 427 428 #[cfg(test)] 429 mod tests { 430 use super::*; 431 use serde_json::json; 432 433 + static TEST_CID: &str = "bafyreidffwk5wvh5l76yy7zefiqrovv6yaaegb4wg4zaq35w7nt3quix5a"; 434 + 435 #[test] 436 fn test_parse_record_path() -> Result<(), Box<dyn std::error::Error>> { 437 let cases = [ 438 ("", vec![]), 439 ("subject", vec![PathPart::Scalar("subject".into())]), 440 ("authorDid", vec![PathPart::Scalar("authorDid".into())]), 441 + ( 442 + "subject.uri", 443 + vec![ 444 + PathPart::Scalar("subject".into()), 445 + PathPart::Scalar("uri".into()), 446 + ], 447 + ), 448 ("members[]", vec![PathPart::Vector("members".into(), None)]), 449 + ( 450 + "add[].key", 451 + vec![ 452 + PathPart::Vector("add".into(), None), 453 + PathPart::Scalar("key".into()), 454 + ], 455 + ), 456 ("a[b]", vec![PathPart::Vector("a".into(), Some("b".into()))]), 457 + ( 458 + "a[b.c]", 459 + vec![PathPart::Vector("a".into(), Some("b.c".into()))], 460 + ), 461 + ( 462 + "facets[app.bsky.richtext.facet].features[app.bsky.richtext.facet#mention].did", 463 + vec![ 464 + PathPart::Vector("facets".into(), Some("app.bsky.richtext.facet".into())), 465 + PathPart::Vector( 466 + "features".into(), 467 + Some("app.bsky.richtext.facet#mention".into()), 
468 + ), 469 + PathPart::Scalar("did".into()), 470 + ], 471 + ), 472 ]; 473 474 for (path, expected) in cases { ··· 485 ("strong-ref", json!(""), None), 486 ("strong-ref", json!({}), None), 487 ("strong-ref", json!({ "uri": "abc" }), None), 488 + ("strong-ref", json!({ "cid": TEST_CID }), None), 489 ( 490 "strong-ref", 491 + json!({ "uri": "at://a.com/xx.yy.zz/1", "cid": TEST_CID }), 492 + Some(MatchedRef::AtUri(FullAtUriParts { 493 + repo: "a.com".parse().unwrap(), 494 + collection: "xx.yy.zz".parse().unwrap(), 495 + rkey: "1".parse().unwrap(), 496 + cid: Some(TEST_CID.parse().unwrap()), 497 + })), 498 ), 499 ("at-uri", json!({ "uri": "abc" }), None), 500 ( 501 "at-uri", 502 + json!({ "uri": "at://did:web:y.com/xx.yy.zz/1", "cid": TEST_CID }), 503 + None, 504 + ), 505 + ( 506 + "at-uri", 507 + json!("at://did:web:y.com/xx.yy.zz/1"), 508 + Some(MatchedRef::AtUri(FullAtUriParts { 509 + repo: "did:web:y.com".parse().unwrap(), 510 + collection: "xx.yy.zz".parse().unwrap(), 511 + rkey: "1".parse().unwrap(), 512 + cid: None, 513 + })), 514 ), 515 ("at-uri-parts", json!("abc"), None), 516 ("at-uri-parts", json!({}), None), 517 ( 518 "at-uri-parts", 519 + json!({"repo": "a.com", "collection": "xx.yy.zz", "rkey": "1", "cid": TEST_CID}), 520 + Some(MatchedRef::AtUri(FullAtUriParts { 521 + repo: "a.com".parse().unwrap(), 522 + collection: "xx.yy.zz".parse().unwrap(), 523 + rkey: "1".parse().unwrap(), 524 + cid: Some(TEST_CID.parse().unwrap()), 525 + })), 526 ), 527 ( 528 "at-uri-parts", 529 + json!({"did": "a.com", "collection": "xx.yy.zz", "rkey": "1"}), 530 + Some(MatchedRef::AtUri(FullAtUriParts { 531 + repo: "a.com".parse().unwrap(), 532 + collection: "xx.yy.zz".parse().unwrap(), 533 + rkey: "1".parse().unwrap(), 534 + cid: None, 535 + })), 536 ), 537 ( 538 "at-uri-parts", 539 // 'repo' takes precedence over 'did' 540 + json!({"did": "did:web:a.com", "repo": "z.com", "collection": "xx.yy.zz", "rkey": "1"}), 541 + Some(MatchedRef::AtUri(FullAtUriParts { 542 + repo: "z.com".parse().unwrap(), 543 + collection: "xx.yy.zz".parse().unwrap(), 544 + rkey: "1".parse().unwrap(), 545 + cid: None, 546 + })), 547 ), 548 ( 549 "at-uri-parts", 550 + json!({"repo": "a.com", "collection": "xx.yy.zz", "rkey": "1", "cid": TEST_CID}), 551 + Some(MatchedRef::AtUri(FullAtUriParts { 552 + repo: "a.com".parse().unwrap(), 553 + collection: "xx.yy.zz".parse().unwrap(), 554 + rkey: "1".parse().unwrap(), 555 + cid: Some(TEST_CID.parse().unwrap()), 556 + })), 557 ), 558 ( 559 "at-uri-parts", 560 + json!({"repo": "a.com", "collection": "xx.yy.zz", "rkey": "1", "cid": {}}), 561 + Some(MatchedRef::AtUri(FullAtUriParts { 562 + repo: "a.com".parse().unwrap(), 563 + collection: "xx.yy.zz".parse().unwrap(), 564 + rkey: "1".parse().unwrap(), 565 + cid: None, 566 + })), 567 ), 568 ("did", json!({}), None), 569 ("did", json!(""), None), 570 ("did", json!("bad-example.com"), None), 571 + ( 572 + "did", 573 + json!("did:plc:xyz"), 574 + Some(MatchedRef::Identifier("did:plc:xyz".parse().unwrap())), 575 + ), 576 ("handle", json!({}), None), 577 + ( 578 + "handle", 579 + json!("bad-example.com"), 580 + Some(MatchedRef::Identifier("bad-example.com".parse().unwrap())), 581 + ), 582 ("handle", json!("did:plc:xyz"), None), 583 ("at-identifier", json!({}), None), 584 + ( 585 + "at-identifier", 586 + json!("bad-example.com"), 587 + Some(MatchedRef::Identifier("bad-example.com".parse().unwrap())), 588 + ), 589 + ( 590 + "at-identifier", 591 + json!("did:plc:xyz"), 592 + Some(MatchedRef::Identifier("did:plc:xyz".parse().unwrap())), 593 + ), 594 
]; 595 + for (i, (shape, val, expected)) in cases.into_iter().enumerate() { 596 let s = shape.try_into().unwrap(); 597 + let matched = match_shape(s, &val); 598 + assert_eq!(matched, expected, "{i}: shape: {shape:?}, val: {val:?}"); 599 } 600 } 601 }
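For readers unfamiliar with the hydration-source syntax this proxy code parses: a source like `{ "path": "feed[].post", "shape": "at-uri" }` means "walk `feed` as an array, read each element's `post` field, and treat the string there as an at-uri to hydrate". A self-contained sketch of that walk over a `getFeedSkeleton`-style response, mirroring the intent of `parse_record_path` / `PathWalker` / `match_shape` without using their exact APIs:

```rust
use serde_json::{Value, json};

// Walk the "feed[].post" path: for every element of `feed`, take its `post` string.
fn walk_feed_posts(skeleton: &Value) -> Vec<String> {
    skeleton["feed"]
        .as_array()
        .into_iter()
        .flatten()
        .filter_map(|item| item["post"].as_str().map(str::to_string))
        .collect()
}

fn main() {
    let skeleton = json!({
        "feed": [
            { "post": "at://did:plc:abc/app.bsky.feed.post/3kaaa" },
            { "post": "at://did:plc:def/app.bsky.feed.post/3kbbb" }
        ]
    });
    assert_eq!(walk_feed_posts(&skeleton).len(), 2);
}
```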
+546 -194
slingshot/src/server.rs
··· 1 use crate::{ 2 CachedRecord, ErrorResponseObject, Identity, Proxy, Repo, 3 error::{RecordError, ServerError}, 4 - proxy::{extract_links, MatchedRef}, 5 record::RawRecord, 6 }; 7 - use atrium_api::types::string::{Cid, Did, Handle, Nsid, RecordKey}; 8 use foyer::HybridCache; 9 use links::at_uri::parse_at_uri as normalize_at_uri; 10 use serde::Serialize; 11 - use std::{path::PathBuf, str::FromStr, sync::Arc, time::Instant, collections::HashMap}; 12 use tokio::sync::mpsc; 13 use tokio_util::sync::CancellationToken; 14 15 use poem::{ 16 - Endpoint, EndpointExt, IntoResponse, Route, Server, 17 endpoint::{StaticFileEndpoint, make_sync}, 18 http::Method, 19 listener::{ 20 Listener, TcpListener, 21 - acme::{AutoCert, LETS_ENCRYPT_PRODUCTION}, 22 }, 23 middleware::{CatchPanic, Cors, Tracing}, 24 }; 25 use poem_openapi::{ 26 ApiResponse, ContactObject, ExternalDocumentObject, Object, OpenApi, OpenApiService, Tags, 27 - Union, 28 - param::Query, payload::Json, types::Example, 29 }; 30 31 fn example_handle() -> String { ··· 33 } 34 fn example_did() -> String { 35 "did:plc:hdhoaan3xa3jiuq4fg4mefid".to_string() 36 } 37 fn example_collection() -> String { 38 "app.bsky.feed.like".to_string() 39 } 40 fn example_rkey() -> String { 41 "3lv4ouczo2b2a".to_string() 42 } 43 fn example_uri() -> String { 44 format!( ··· 86 })) 87 } 88 89 - fn bad_request_handler_resolve_mini(err: poem::Error) -> ResolveMiniIDResponse { 90 - ResolveMiniIDResponse::BadRequest(Json(XrpcErrorResponseObject { 91 error: "InvalidRequest".to_string(), 92 message: format!("Bad request, here's some info that maybe should not be exposed: {err}"), 93 })) ··· 189 190 #[derive(ApiResponse)] 191 #[oai(bad_request_handler = "bad_request_handler_resolve_mini")] 192 - enum ResolveMiniIDResponse { 193 /// Identity resolved 194 #[oai(status = 200)] 195 Ok(Json<MiniDocResponseObject>), ··· 199 } 200 201 #[derive(Object)] 202 - struct ProxyHydrationError { 203 - reason: String, 204 } 205 - 206 - #[derive(Object)] 207 - struct ProxyHydrationPending { 208 - url: String, 209 } 210 211 - #[derive(Object)] 212 - struct ProxyHydrationRecordFound { 213 - record: serde_json::Value, 214 } 215 216 #[derive(Object)] 217 - struct ProxyHydrationIdentifierFound { 218 - mini_doc: MiniDocResponseObject, 219 } 220 221 #[derive(Object)] 222 #[oai(rename_all = "camelCase")] 223 - struct ProxyHydrationBlobFound { 224 - /// cdn url 225 - link: String, 226 - mime_type: String, 227 - size: u64, 228 } 229 230 // todo: there's gotta be a supertrait that collects these? 231 - use poem_openapi::types::{Type, ToJSON, ParseFromJSON, IsObjectType}; 232 233 #[derive(Union)] 234 #[oai(discriminator_name = "status", rename_all = "camelCase")] ··· 244 /// The original upstream response content 245 output: serde_json::Value, 246 /// Any hydrated records 247 - records: HashMap<String, Hydration<ProxyHydrationRecordFound>>, 248 /// Any hydrated identifiers 249 - /// 250 - /// TODO: "identifiers" feels wrong as the name, probably "identities"? 
251 - identifiers: HashMap<String, Hydration<ProxyHydrationIdentifierFound>>, 252 - /// Any hydrated blob CDN urls 253 - blobs: HashMap<String, Hydration<ProxyHydrationBlobFound>>, 254 } 255 impl Example for ProxyHydrateResponseObject { 256 fn example() -> Self { 257 Self { 258 output: serde_json::json!({}), 259 - records: HashMap::from([ 260 - ("asdf".into(), Hydration::Pending(ProxyHydrationPending { url: "todo".into() })), 261 - ]), 262 identifiers: HashMap::new(), 263 - blobs: HashMap::new(), 264 } 265 } 266 } ··· 271 #[oai(status = 200)] 272 Ok(Json<ProxyHydrateResponseObject>), 273 #[oai(status = 400)] 274 - BadRequest(XrpcError) 275 } 276 277 #[derive(Object)] ··· 296 xrpc: String, 297 /// The destination service the request will be forwarded to 298 atproto_proxy: String, 299 /// The `params` for the destination service XRPC endpoint 300 /// 301 /// Currently this will be passed along unchecked, but a future version of ··· 304 params: Option<serde_json::Value>, 305 /// Paths within the response to look for at-uris that can be hydrated 306 hydration_sources: Vec<HydrationSource>, 307 - // todo: deadline thing 308 - 309 } 310 impl Example for ProxyQueryPayload { 311 fn example() -> Self { 312 Self { 313 xrpc: "app.bsky.feed.getFeedSkeleton".to_string(), 314 atproto_proxy: "did:web:blue.mackuba.eu#bsky_fg".to_string(), 315 params: Some(serde_json::json!({ 316 "feed": "at://did:plc:oio4hkxaop4ao4wz2pp3f4cr/app.bsky.feed.generator/atproto", 317 })), 318 - hydration_sources: vec![ 319 - HydrationSource { 320 - path: "feed[].post".to_string(), 321 - shape: "at-uri".to_string(), 322 - } 323 - ], 324 } 325 } 326 } ··· 354 } 355 356 struct Xrpc { 357 cache: HybridCache<String, CachedRecord>, 358 identity: Identity, 359 proxy: Arc<Proxy>, ··· 422 /// only retains the most recent version of a record. 
423 Query(cid): Query<Option<String>>, 424 ) -> GetRecordResponse { 425 - self.get_record_impl(repo, collection, rkey, cid).await 426 } 427 428 /// blue.microcosm.repo.getRecordByUri ··· 492 return bad_at_uri(); 493 }; 494 495 - // TODO: move this to links 496 - let Some(rest) = normalized.strip_prefix("at://") else { 497 - return bad_at_uri(); 498 - }; 499 - let Some((repo, rest)) = rest.split_once('/') else { 500 return bad_at_uri(); 501 }; 502 - let Some((collection, rest)) = rest.split_once('/') else { 503 - return bad_at_uri(); 504 - }; 505 - let rkey = if let Some((rkey, _rest)) = rest.split_once('?') { 506 - rkey 507 - } else { 508 - rest 509 - }; 510 511 self.get_record_impl( 512 - repo.to_string(), 513 - collection.to_string(), 514 - rkey.to_string(), 515 - cid, 516 ) 517 .await 518 } ··· 592 /// Handle or DID to resolve 593 #[oai(example = "example_handle")] 594 Query(identifier): Query<String>, 595 - ) -> ResolveMiniIDResponse { 596 self.resolve_mini_id(Query(identifier)).await 597 } 598 ··· 610 /// Handle or DID to resolve 611 #[oai(example = "example_handle")] 612 Query(identifier): Query<String>, 613 - ) -> ResolveMiniIDResponse { 614 Self::resolve_mini_doc_impl(&identifier, self.identity.clone()).await 615 } 616 617 - async fn resolve_mini_doc_impl(identifier: &str, identity: Identity) -> ResolveMiniIDResponse { 618 let invalid = |reason: &'static str| { 619 - ResolveMiniIDResponse::BadRequest(xrpc_error("InvalidRequest", reason)) 620 }; 621 622 let mut unverified_handle = None; ··· 681 } 682 }; 683 684 - ResolveMiniIDResponse::Ok(Json(MiniDocResponseObject { 685 did: did.to_string(), 686 handle, 687 pds: partial_doc.pds, ··· 689 })) 690 } 691 692 /// com.bad-example.proxy.hydrateQueryResponse 693 /// 694 /// > [!important] ··· 704 &self, 705 Json(payload): Json<ProxyQueryPayload>, 706 ) -> ProxyHydrateResponse { 707 - // TODO: the Accept request header, if present, gotta be json 708 - // TODO: find any Authorization header and verify it. TBD about `aud`. 709 - 710 let params = if let Some(p) = payload.params { 711 let serde_json::Value::Object(map) = p else { 712 panic!("params have to be an object"); 713 }; 714 Some(map) 715 - } else { None }; 716 717 - match self.proxy.proxy( 718 - payload.xrpc, 719 - payload.atproto_proxy, 720 - params, 721 - ).await { 722 Ok(skeleton) => { 723 let links = match extract_links(payload.hydration_sources, &skeleton) { 724 Ok(l) => l, 725 Err(e) => { 726 log::warn!("problem extracting: {e:?}"); 727 - return ProxyHydrateResponse::BadRequest(xrpc_error("oop", "sorry, error extracting")) 728 } 729 }; 730 let mut records = HashMap::new(); 731 let mut identifiers = HashMap::new(); 732 - let mut blobs = HashMap::new(); 733 734 enum GetThing { 735 - Record(String, Hydration<ProxyHydrationRecordFound>), 736 - Identifier(String, Hydration<ProxyHydrationIdentifierFound>), 737 - Blob(String, Hydration<ProxyHydrationBlobFound>), 738 } 739 740 let (tx, mut rx) = mpsc::channel(1); 741 742 - for link in links { 743 match link { 744 - MatchedRef::AtUri { uri, cid } => { 745 - if records.contains_key(&uri) { 746 log::warn!("skipping duplicate record without checking cid"); 747 continue; 748 } 749 - let mut u = url::Url::parse("https://example.com").unwrap(); 750 - u.query_pairs_mut().append_pair("at_uri", &uri); // BLEH todo 751 - records.insert(uri.clone(), Hydration::Pending(ProxyHydrationPending { 752 - url: format!("/xrpc/blue.microcosm.repo.getRecordByUri?{}", u.query().unwrap()), // TODO better; with cid, etc. 
753 - })); 754 let tx = tx.clone(); 755 let identity = self.identity.clone(); 756 let repo = self.repo.clone(); 757 tokio::task::spawn(async move { 758 - let rest = uri.strip_prefix("at://").unwrap(); 759 - let (identifier, rest) = rest.split_once('/').unwrap(); 760 - let (collection, rkey) = rest.split_once('/').unwrap(); 761 - 762 - let did = if identifier.starts_with("did:") { 763 - Did::new(identifier.to_string()).unwrap() 764 - } else { 765 - let handle = Handle::new(identifier.to_string()).unwrap(); 766 - identity.handle_to_did(handle).await.unwrap().unwrap() 767 }; 768 769 - let res = match repo.get_record( 770 - &did, 771 - &Nsid::new(collection.to_string()).unwrap(), 772 - &RecordKey::new(rkey.to_string()).unwrap(), 773 - &cid.as_ref().map(|s| Cid::from_str(s).unwrap()), 774 - ).await { 775 - Ok(CachedRecord::Deleted) => 776 - Hydration::Error(ProxyHydrationError { 777 - reason: "record deleted".to_string(), 778 - }), 779 - Ok(CachedRecord::Found(RawRecord { cid: found_cid, record })) => { 780 - if let Some(c) = cid && found_cid.as_ref().to_string() != c { 781 - log::warn!("ignoring cid mismatch"); 782 } 783 - let value = serde_json::from_str(&record).unwrap(); 784 - Hydration::Found(ProxyHydrationRecordFound { 785 - record: value, 786 - }) 787 - } 788 - Err(e) => { 789 - log::warn!("finally oop {e:?}"); 790 - Hydration::Error(ProxyHydrationError { 791 - reason: "failed to fetch record".to_string(), 792 - }) 793 - } 794 - }; 795 - tx.send(GetThing::Record(uri, res)).await 796 }); 797 } 798 MatchedRef::Identifier(id) => { 799 - if identifiers.contains_key(&id) { 800 continue; 801 } 802 - let mut u = url::Url::parse("https://example.com").unwrap(); 803 - u.query_pairs_mut().append_pair("identifier", &id); 804 - identifiers.insert(id.clone(), Hydration::Pending(ProxyHydrationPending { 805 - url: format!("/xrpc/blue.microcosm.identity.resolveMiniDoc?{}", u.query().unwrap()), // gross 806 - })); 807 let tx = tx.clone(); 808 let identity = self.identity.clone(); 809 tokio::task::spawn(async move { 810 - let res = match Self::resolve_mini_doc_impl(&id, identity).await { 811 - ResolveMiniIDResponse::Ok(Json(mini_doc)) => Hydration::Found(ProxyHydrationIdentifierFound { 812 - mini_doc 813 - }), 814 - ResolveMiniIDResponse::BadRequest(e) => { 815 log::warn!("minidoc fail: {:?}", e.0); 816 Hydration::Error(ProxyHydrationError { 817 reason: "failed to resolve mini doc".to_string(), 818 }) 819 } 820 }; 821 - tx.send(GetThing::Identifier(id, res)).await 822 }); 823 } 824 - MatchedRef::Blob { link, mime, size: _ } => { 825 - if blobs.contains_key(&link) { 826 - continue; 827 - } 828 - if mime != "image/jpeg" { 829 - Hydration::<ProxyHydrationBlobFound>::Error(ProxyHydrationError { 830 - reason: "only image/jpeg supported for now".to_string(), 831 - }); 832 - } 833 - todo!("oops we need to know the account too") 834 - } 835 } 836 } 837 // so the channel can close when all are completed 838 // (we shoudl be doing a timeout...) 
839 drop(tx); 840 841 - while let Some(hydration) = rx.recv().await { 842 - match hydration { 843 - GetThing::Record(uri, h) => { records.insert(uri, h); } 844 - GetThing::Identifier(uri, md) => { identifiers.insert(uri, md); } 845 - GetThing::Blob(cid, asdf) => { blobs.insert(cid, asdf); } 846 - }; 847 } 848 849 ProxyHydrateResponse::Ok(Json(ProxyHydrateResponseObject { 850 output: skeleton, 851 records, 852 identifiers, 853 - blobs, 854 })) 855 } 856 Err(e) => { ··· 858 ProxyHydrateResponse::BadRequest(xrpc_error("oop", "sorry")) 859 } 860 } 861 - 862 } 863 864 async fn get_record_impl( 865 &self, 866 - repo: String, 867 - collection: String, 868 - rkey: String, 869 - cid: Option<String>, 870 ) -> GetRecordResponse { 871 - let did = match Did::new(repo.clone()) { 872 Ok(did) => did, 873 Err(_) => { 874 let Ok(handle) = Handle::new(repo.to_lowercase()) else { ··· 899 } 900 }; 901 902 - let Ok(collection) = Nsid::new(collection) else { 903 return GetRecordResponse::BadRequest(xrpc_error( 904 "InvalidRequest", 905 "Invalid NSID for collection", 906 )); 907 }; 908 909 - let Ok(rkey) = RecordKey::new(rkey) else { 910 return GetRecordResponse::BadRequest(xrpc_error("InvalidRequest", "Invalid rkey")); 911 }; 912 913 let cid: Option<Cid> = if let Some(cid) = cid { 914 - let Ok(cid) = Cid::from_str(&cid) else { 915 return GetRecordResponse::BadRequest(xrpc_error("InvalidRequest", "Invalid CID")); 916 }; 917 Some(cid) ··· 1060 identity: Identity, 1061 repo: Repo, 1062 proxy: Proxy, 1063 - acme_domain: Option<String>, 1064 acme_contact: Option<String>, 1065 - acme_cache_path: Option<PathBuf>, 1066 - acme_ipv6: bool, 1067 shutdown: CancellationToken, 1068 bind: std::net::SocketAddr, 1069 ) -> Result<(), ServerError> { ··· 1071 let proxy = Arc::new(proxy); 1072 let api_service = OpenApiService::new( 1073 Xrpc { 1074 cache, 1075 identity, 1076 proxy, ··· 1079 "Slingshot", 1080 env!("CARGO_PKG_VERSION"), 1081 ) 1082 - .server(if let Some(ref h) = acme_domain { 1083 format!("https://{h}") 1084 } else { 1085 format!("http://{bind}") // yeah should probably fix this for reverse-proxy scenarios but it's ok for dev for now ··· 1095 "https://microcosm.blue/slingshot", 1096 )); 1097 1098 - let mut app = Route::new() 1099 .at("/", StaticFileEndpoint::new("./static/index.html")) 1100 .nest("/openapi", api_service.spec_endpoint()) 1101 .nest("/xrpc/", api_service); 1102 1103 - if let Some(domain) = acme_domain { 1104 rustls::crypto::aws_lc_rs::default_provider() 1105 .install_default() 1106 .expect("alskfjalksdjf"); 1107 1108 - app = app.at("/.well-known/did.json", get_did_doc(&domain)); 1109 1110 - let mut auto_cert = AutoCert::builder() 1111 - .directory_url(LETS_ENCRYPT_PRODUCTION) 1112 - .domain(&domain); 1113 if let Some(contact) = acme_contact { 1114 - auto_cert = auto_cert.contact(contact); 1115 } 1116 - if let Some(cache_path) = acme_cache_path { 1117 - auto_cert = auto_cert.cache_path(cache_path); 1118 - } 1119 - let auto_cert = auto_cert.build().map_err(ServerError::AcmeBuildError)?; 1120 1121 - run( 1122 - TcpListener::bind(if acme_ipv6 { "[::]:443" } else { "0.0.0.0:443" }).acme(auto_cert), 1123 - app, 1124 - shutdown, 1125 - ) 1126 - .await 1127 } else { 1128 - run(TcpListener::bind(bind), app, shutdown).await 1129 } 1130 } 1131 1132 - async fn run<L>(listener: L, app: Route, shutdown: CancellationToken) -> Result<(), ServerError> 1133 where 1134 L: Listener + 'static, 1135 { 1136 let app = app 1137 - .with( 1138 - Cors::new() 1139 - .allow_origin_regex("*") 1140 - .allow_methods([Method::GET, 
Method::POST]) 1141 - .allow_credentials(false), 1142 - ) 1143 .with(CatchPanic::new()) 1144 .around(request_counter) 1145 .with(Tracing);
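The hand-rolled at-uri parsing being removed above (strip_prefix("at://") plus repeated split_once, with unwrap()s in the hydration path) is replaced in the new version of this file by a shared split_uri helper imported from crate::proxy, alongside the FullAtUriParts and MatchedRef types. That module is not part of this hunk, so only the call sites are visible; the sketch below is just an illustration of the shape such a helper could take, inferred from the call let Some((repo, collection, rkey)) = split_uri(&normalized), not the actual implementation.

use atrium_api::types::string::{AtIdentifier, Did, Handle, Nsid, RecordKey};

// Illustrative sketch only: the real split_uri lives in slingshot/src/proxy.rs,
// which this diff does not show. The return type is inferred from the call site
// in get_record_by_uri.
fn split_uri(uri: &str) -> Option<(AtIdentifier, Nsid, RecordKey)> {
    // at://<repo>/<collection>/<rkey>[?query]
    let rest = uri.strip_prefix("at://")?;
    let (repo, rest) = rest.split_once('/')?;
    let (collection, rest) = rest.split_once('/')?;
    // drop any query string, as the old inline parser did
    let rkey = rest.split_once('?').map(|(rkey, _)| rkey).unwrap_or(rest);

    let repo = if repo.starts_with("did:") {
        AtIdentifier::Did(Did::new(repo.to_string()).ok()?)
    } else {
        AtIdentifier::Handle(Handle::new(repo.to_string()).ok()?)
    };
    Some((
        repo,
        Nsid::new(collection.to_string()).ok()?,
        RecordKey::new(rkey.to_string()).ok()?,
    ))
}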
··· 1 use crate::{ 2 CachedRecord, ErrorResponseObject, Identity, Proxy, Repo, 3 error::{RecordError, ServerError}, 4 + proxy::{FullAtUriParts, MatchedRef, extract_links, split_uri}, 5 record::RawRecord, 6 }; 7 + use atrium_api::types::string::{AtIdentifier, Cid, Did, Handle, Nsid, RecordKey}; 8 use foyer::HybridCache; 9 use links::at_uri::parse_at_uri as normalize_at_uri; 10 use serde::Serialize; 11 + use std::{collections::HashMap, path::PathBuf, str::FromStr, sync::Arc, time::Instant}; 12 use tokio::sync::mpsc; 13 use tokio_util::sync::CancellationToken; 14 15 use poem::{ 16 + Endpoint, EndpointExt, IntoResponse, Route, RouteScheme, Server, 17 endpoint::{StaticFileEndpoint, make_sync}, 18 http::Method, 19 listener::{ 20 Listener, TcpListener, 21 + acme::{AutoCert, ChallengeType, LETS_ENCRYPT_PRODUCTION, LETS_ENCRYPT_STAGING}, 22 }, 23 middleware::{CatchPanic, Cors, Tracing}, 24 }; 25 use poem_openapi::{ 26 ApiResponse, ContactObject, ExternalDocumentObject, Object, OpenApi, OpenApiService, Tags, 27 + Union, param::Query, payload::Json, types::Example, 28 }; 29 30 fn example_handle() -> String { ··· 32 } 33 fn example_did() -> String { 34 "did:plc:hdhoaan3xa3jiuq4fg4mefid".to_string() 35 + } 36 + fn example_service_did() -> String { 37 + "did:web:constellation.microcosm.blue".to_string() 38 } 39 fn example_collection() -> String { 40 "app.bsky.feed.like".to_string() 41 } 42 fn example_rkey() -> String { 43 "3lv4ouczo2b2a".to_string() 44 + } 45 + fn example_id_fragment() -> String { 46 + "#constellation".to_string() 47 } 48 fn example_uri() -> String { 49 format!( ··· 91 })) 92 } 93 94 + fn bad_request_handler_resolve_mini(err: poem::Error) -> ResolveMiniDocResponse { 95 + ResolveMiniDocResponse::BadRequest(Json(XrpcErrorResponseObject { 96 + error: "InvalidRequest".to_string(), 97 + message: format!("Bad request, here's some info that maybe should not be exposed: {err}"), 98 + })) 99 + } 100 + 101 + fn bad_request_handler_resolve_service(err: poem::Error) -> ResolveServiceResponse { 102 + ResolveServiceResponse::BadRequest(Json(XrpcErrorResponseObject { 103 error: "InvalidRequest".to_string(), 104 message: format!("Bad request, here's some info that maybe should not be exposed: {err}"), 105 })) ··· 201 202 #[derive(ApiResponse)] 203 #[oai(bad_request_handler = "bad_request_handler_resolve_mini")] 204 + enum ResolveMiniDocResponse { 205 /// Identity resolved 206 #[oai(status = 200)] 207 Ok(Json<MiniDocResponseObject>), ··· 211 } 212 213 #[derive(Object)] 214 + #[oai(example = true)] 215 + struct ServiceResponseObject { 216 + /// The service endpoint URL, if found 217 + endpoint: String, 218 } 219 + impl Example for ServiceResponseObject { 220 + fn example() -> Self { 221 + Self { 222 + endpoint: "https://example.com".to_string(), 223 + } 224 + } 225 } 226 227 + #[derive(ApiResponse)] 228 + #[oai(bad_request_handler = "bad_request_handler_resolve_service")] 229 + enum ResolveServiceResponse { 230 + /// Service resolved 231 + #[oai(status = 200)] 232 + Ok(Json<ServiceResponseObject>), 233 + /// Bad request or service not resolved 234 + #[oai(status = 400)] 235 + BadRequest(XrpcError), 236 } 237 238 #[derive(Object)] 239 + #[oai(rename_all = "camelCase")] 240 + struct ProxyHydrationError { 241 + /// Short description of why the hydration failed 242 + reason: String, 243 + /// Whether or not it's recommended to retry requesting this item 244 + should_retry: bool, 245 + /// URL to follow up at if retrying 246 + follow_up: String, 247 } 248 249 #[derive(Object)] 250 #[oai(rename_all = 
"camelCase")] 251 + struct ProxyHydrationPending { 252 + /// URL you can request to finish hydrating this item 253 + follow_up: String, 254 + /// Why this item couldn't be hydrated: 'deadline' or 'limit' 255 + /// 256 + /// - `deadline`: the item fetch didn't complete before the response was 257 + /// due, but will continue on slingshot in the background -- `followUp` 258 + /// requests are coalesced into the original item fetch to be available as 259 + /// early as possible. 260 + /// 261 + /// - `limit`: slingshot only attempts to hydrate the first 100 items found 262 + /// in a proxied response, with the remaining marked `pending`. You can 263 + /// request `followUp` to fetch them. 264 + /// 265 + /// In the future, Slingshot may put pending links after `limit` into a low- 266 + /// priority fetch queue, so that these items become available sooner on 267 + /// follow-up request as well. 268 + reason: String, 269 } 270 271 // todo: there's gotta be a supertrait that collects these? 272 + use poem_openapi::types::{IsObjectType, ParseFromJSON, ToJSON, Type}; 273 274 #[derive(Union)] 275 #[oai(discriminator_name = "status", rename_all = "camelCase")] ··· 285 /// The original upstream response content 286 output: serde_json::Value, 287 /// Any hydrated records 288 + records: HashMap<String, Hydration<FoundRecordResponseObject>>, 289 /// Any hydrated identifiers 290 + identifiers: HashMap<String, Hydration<MiniDocResponseObject>>, 291 } 292 impl Example for ProxyHydrateResponseObject { 293 fn example() -> Self { 294 Self { 295 output: serde_json::json!({}), 296 + records: HashMap::from([( 297 + "asdf".into(), 298 + Hydration::Pending(ProxyHydrationPending { 299 + follow_up: "/xrpc/com.atproto.repo.getRecord?...".to_string(), 300 + reason: "deadline".to_string(), 301 + }), 302 + )]), 303 identifiers: HashMap::new(), 304 } 305 } 306 } ··· 311 #[oai(status = 200)] 312 Ok(Json<ProxyHydrateResponseObject>), 313 #[oai(status = 400)] 314 + BadRequest(XrpcError), 315 } 316 317 #[derive(Object)] ··· 336 xrpc: String, 337 /// The destination service the request will be forwarded to 338 atproto_proxy: String, 339 + /// An optional auth token to pass on 340 + /// 341 + /// the `aud` field must match the upstream atproto_proxy service 342 + authorization: Option<String>, 343 + /// An optional set of labelers to request be applied by the upstream 344 + atproto_accept_labelers: Option<String>, 345 /// The `params` for the destination service XRPC endpoint 346 /// 347 /// Currently this will be passed along unchecked, but a future version of ··· 350 params: Option<serde_json::Value>, 351 /// Paths within the response to look for at-uris that can be hydrated 352 hydration_sources: Vec<HydrationSource>, 353 + // todo: let clients pass a hydration deadline? 354 } 355 impl Example for ProxyQueryPayload { 356 fn example() -> Self { 357 Self { 358 xrpc: "app.bsky.feed.getFeedSkeleton".to_string(), 359 atproto_proxy: "did:web:blue.mackuba.eu#bsky_fg".to_string(), 360 + authorization: None, 361 + atproto_accept_labelers: None, 362 params: Some(serde_json::json!({ 363 "feed": "at://did:plc:oio4hkxaop4ao4wz2pp3f4cr/app.bsky.feed.generator/atproto", 364 })), 365 + hydration_sources: vec![HydrationSource { 366 + path: "feed[].post".to_string(), 367 + shape: "at-uri".to_string(), 368 + }], 369 } 370 } 371 } ··· 399 } 400 401 struct Xrpc { 402 + base_url: url::Url, 403 cache: HybridCache<String, CachedRecord>, 404 identity: Identity, 405 proxy: Arc<Proxy>, ··· 468 /// only retains the most recent version of a record. 
469 Query(cid): Query<Option<String>>, 470 ) -> GetRecordResponse { 471 + self.get_record_impl(&repo, &collection, &rkey, cid.as_deref()) 472 + .await 473 } 474 475 /// blue.microcosm.repo.getRecordByUri ··· 539 return bad_at_uri(); 540 }; 541 542 + let Some((repo, collection, rkey)) = split_uri(&normalized) else { 543 return bad_at_uri(); 544 }; 545 546 self.get_record_impl( 547 + Into::<String>::into(repo).as_str(), 548 + collection.as_str(), 549 + rkey.as_str(), 550 + cid.as_deref(), 551 ) 552 .await 553 } ··· 627 /// Handle or DID to resolve 628 #[oai(example = "example_handle")] 629 Query(identifier): Query<String>, 630 + ) -> ResolveMiniDocResponse { 631 self.resolve_mini_id(Query(identifier)).await 632 } 633 ··· 645 /// Handle or DID to resolve 646 #[oai(example = "example_handle")] 647 Query(identifier): Query<String>, 648 + ) -> ResolveMiniDocResponse { 649 Self::resolve_mini_doc_impl(&identifier, self.identity.clone()).await 650 } 651 652 + async fn resolve_mini_doc_impl(identifier: &str, identity: Identity) -> ResolveMiniDocResponse { 653 let invalid = |reason: &'static str| { 654 + ResolveMiniDocResponse::BadRequest(xrpc_error("InvalidRequest", reason)) 655 }; 656 657 let mut unverified_handle = None; ··· 716 } 717 }; 718 719 + ResolveMiniDocResponse::Ok(Json(MiniDocResponseObject { 720 did: did.to_string(), 721 handle, 722 pds: partial_doc.pds, ··· 724 })) 725 } 726 727 + /// com.bad-example.identity.resolveService 728 + /// 729 + /// resolve an atproto service did + id to its http endpoint 730 + /// 731 + /// > [!important] 732 + /// > this endpoint is experimental and may change 733 + #[oai( 734 + path = "/com.bad-example.identity.resolveService", 735 + method = "get", 736 + tag = "ApiTags::Custom" 737 + )] 738 + async fn resolve_service( 739 + &self, 740 + /// the service's did 741 + #[oai(example = "example_service_did")] 742 + Query(did): Query<String>, 743 + /// id fragment, starting with '#' 744 + /// 745 + /// must be url-encoded! 746 + #[oai(example = "example_id_fragment")] 747 + Query(id): Query<String>, 748 + /// optionally, the exact service type to filter 749 + /// 750 + /// resolving a pds requires matching the type as well as id. service 751 + /// proxying ignores the type. 
752 + Query(r#type): Query<Option<String>>, 753 + ) -> ResolveServiceResponse { 754 + let Ok(did) = Did::new(did) else { 755 + return ResolveServiceResponse::BadRequest(xrpc_error( 756 + "InvalidRequest", 757 + "could not parse 'did' into a DID", 758 + )); 759 + }; 760 + let identity = self.identity.clone(); 761 + Self::resolve_service_impl(&did, &id, r#type.as_deref(), identity).await 762 + } 763 + 764 + async fn resolve_service_impl( 765 + did: &Did, 766 + id_fragment: &str, 767 + service_type: Option<&str>, 768 + identity: Identity, 769 + ) -> ResolveServiceResponse { 770 + let invalid = |reason: &'static str| { 771 + ResolveServiceResponse::BadRequest(xrpc_error("InvalidRequest", reason)) 772 + }; 773 + let Ok(service_mini_doc) = identity.did_to_mini_service_doc(did).await else { 774 + return invalid("Failed to get DID doc"); 775 + }; 776 + let Some(service_mini_doc) = service_mini_doc else { 777 + return invalid("Failed to find DID doc"); 778 + }; 779 + 780 + let Some(matching) = service_mini_doc.get(id_fragment, service_type) else { 781 + return invalid("failed to match identity (and maybe type)"); 782 + }; 783 + 784 + ResolveServiceResponse::Ok(Json(ServiceResponseObject { 785 + endpoint: matching.endpoint.clone(), 786 + })) 787 + } 788 + 789 /// com.bad-example.proxy.hydrateQueryResponse 790 /// 791 /// > [!important] ··· 801 &self, 802 Json(payload): Json<ProxyQueryPayload>, 803 ) -> ProxyHydrateResponse { 804 let params = if let Some(p) = payload.params { 805 let serde_json::Value::Object(map) = p else { 806 panic!("params have to be an object"); 807 }; 808 Some(map) 809 + } else { 810 + None 811 + }; 812 + 813 + let Some((service_did, id_fragment)) = payload.atproto_proxy.rsplit_once("#") else { 814 + return ProxyHydrateResponse::BadRequest(xrpc_error( 815 + "BadParameter", 816 + "atproto_proxy could not be understood", 817 + )); 818 + }; 819 820 + let Ok(service_did) = service_did.parse() else { 821 + return ProxyHydrateResponse::BadRequest(xrpc_error( 822 + "BadParameter", 823 + "atproto_proxy service did could not be parsed", 824 + )); 825 + }; 826 + 827 + let Ok(xrpc) = payload.xrpc.parse() else { 828 + return ProxyHydrateResponse::BadRequest(xrpc_error( 829 + "BadParameter", 830 + "invalid NSID for xrpc param", 831 + )); 832 + }; 833 + 834 + match self 835 + .proxy 836 + .proxy( 837 + &service_did, 838 + &format!("#{id_fragment}"), 839 + &xrpc, 840 + payload.authorization.as_deref(), 841 + payload.atproto_accept_labelers.as_deref(), 842 + params, 843 + ) 844 + .await 845 + { 846 Ok(skeleton) => { 847 let links = match extract_links(payload.hydration_sources, &skeleton) { 848 Ok(l) => l, 849 Err(e) => { 850 log::warn!("problem extracting: {e:?}"); 851 + return ProxyHydrateResponse::BadRequest(xrpc_error( 852 + "oop", 853 + "sorry, error extracting", 854 + )); 855 } 856 }; 857 let mut records = HashMap::new(); 858 let mut identifiers = HashMap::new(); 859 860 enum GetThing { 861 + Record(String, Hydration<FoundRecordResponseObject>), 862 + Identifier(String, Hydration<MiniDocResponseObject>), 863 } 864 865 let (tx, mut rx) = mpsc::channel(1); 866 867 + let t0 = Instant::now(); 868 + 869 + for (i, link) in links.into_iter().enumerate() { 870 match link { 871 + MatchedRef::AtUri(parts) => { 872 + let non_canonical_url = parts.to_uri(); 873 + if records.contains_key(&non_canonical_url) { 874 log::warn!("skipping duplicate record without checking cid"); 875 continue; 876 } 877 + let mut follow_up = self.base_url.clone(); 878 + 
follow_up.set_path("/xrpc/com.atproto.repo.getRecord"); 879 + follow_up 880 + .query_pairs_mut() 881 + .append_pair("repo", &Into::<String>::into(parts.repo.clone())) 882 + .append_pair("collection", parts.collection.as_str()) 883 + .append_pair("rkey", parts.rkey.as_str()); 884 + if let Some(ref cid) = parts.cid { 885 + follow_up 886 + .query_pairs_mut() 887 + .append_pair("cid", &cid.as_ref().to_string()); 888 + } 889 + 890 + if i >= 100 { 891 + records.insert( 892 + non_canonical_url.clone(), 893 + Hydration::Pending(ProxyHydrationPending { 894 + reason: "limit".to_string(), 895 + follow_up: follow_up.to_string(), 896 + }), 897 + ); 898 + continue; 899 + } else { 900 + records.insert( 901 + non_canonical_url.clone(), 902 + Hydration::Pending(ProxyHydrationPending { 903 + reason: "deadline".to_string(), 904 + follow_up: follow_up.to_string(), 905 + }), 906 + ); 907 + } 908 + 909 let tx = tx.clone(); 910 let identity = self.identity.clone(); 911 let repo = self.repo.clone(); 912 tokio::task::spawn(async move { 913 + let FullAtUriParts { 914 + repo: ident, 915 + collection, 916 + rkey, 917 + cid, 918 + } = parts; 919 + let did = match ident { 920 + AtIdentifier::Did(did) => did, 921 + AtIdentifier::Handle(handle) => { 922 + let Ok(Some(did)) = identity.handle_to_did(handle).await 923 + else { 924 + let res = Hydration::Error(ProxyHydrationError { 925 + reason: "could not resolve handle".to_string(), 926 + should_retry: true, 927 + follow_up: follow_up.to_string(), 928 + }); 929 + return if tx 930 + .send(GetThing::Record(non_canonical_url, res)) 931 + .await 932 + .is_ok() 933 + { 934 + metrics::counter!("slingshot_hydrated_one", "type" => "record", "ontime" => "true").increment(1); 935 + } else { 936 + metrics::counter!("slingshot_hydrated_one", "type" => "record", "ontime" => "false").increment(1); 937 + }; 938 + }; 939 + did 940 + } 941 }; 942 943 + let res = 944 + match repo.get_record(&did, &collection, &rkey, &cid).await { 945 + Ok(CachedRecord::Deleted) => { 946 + Hydration::Error(ProxyHydrationError { 947 + reason: "record deleted".to_string(), 948 + should_retry: false, 949 + follow_up: follow_up.to_string(), 950 + }) 951 + } 952 + Ok(CachedRecord::Found(RawRecord { 953 + cid: found_cid, 954 + record, 955 + })) => { 956 + if cid 957 + .as_ref() 958 + .map(|expected| *expected != found_cid) 959 + .unwrap_or(false) 960 + { 961 + Hydration::Error(ProxyHydrationError { 962 + reason: "not found".to_string(), 963 + should_retry: false, 964 + follow_up: follow_up.to_string(), 965 + }) 966 + } else if let Ok(value) = serde_json::from_str(&record) 967 + { 968 + let canonical_uri = FullAtUriParts { 969 + repo: AtIdentifier::Did(did), 970 + collection, 971 + rkey, 972 + cid: None, // not used for .to_uri 973 + } 974 + .to_uri(); 975 + Hydration::Found(FoundRecordResponseObject { 976 + cid: Some(found_cid.as_ref().to_string()), 977 + uri: canonical_uri, 978 + value, 979 + }) 980 + } else { 981 + Hydration::Error(ProxyHydrationError { 982 + reason: "could not parse upstream response" 983 + .to_string(), 984 + should_retry: false, 985 + follow_up: follow_up.to_string(), 986 + }) 987 + } 988 } 989 + Err(e) => { 990 + log::warn!("finally oop {e:?}"); 991 + Hydration::Error(ProxyHydrationError { 992 + reason: "failed to fetch record".to_string(), 993 + should_retry: true, // TODO 994 + follow_up: follow_up.to_string(), 995 + }) 996 + } 997 + }; 998 + if tx 999 + .send(GetThing::Record(non_canonical_url, res)) 1000 + .await 1001 + .is_ok() 1002 + { 1003 + metrics::counter!("slingshot_hydrated_one", 
"type" => "record", "ontime" => "true").increment(1); 1004 + } else { 1005 + metrics::counter!("slingshot_hydrated_one", "type" => "record", "ontime" => "false").increment(1); 1006 + } 1007 }); 1008 } 1009 MatchedRef::Identifier(id) => { 1010 + let identifier: String = id.clone().into(); 1011 + if identifiers.contains_key(&identifier) { 1012 continue; 1013 } 1014 + 1015 + let mut follow_up = self.base_url.clone(); 1016 + follow_up.set_path("/xrpc/blue.microcosm.identity.resolveMiniDoc"); 1017 + 1018 + follow_up 1019 + .query_pairs_mut() 1020 + .append_pair("identifier", &identifier); 1021 + 1022 + if i >= 100 { 1023 + identifiers.insert( 1024 + identifier.clone(), 1025 + Hydration::Pending(ProxyHydrationPending { 1026 + reason: "limit".to_string(), 1027 + follow_up: follow_up.to_string(), 1028 + }), 1029 + ); 1030 + continue; 1031 + } else { 1032 + identifiers.insert( 1033 + identifier.clone(), 1034 + Hydration::Pending(ProxyHydrationPending { 1035 + reason: "deadline".to_string(), 1036 + follow_up: follow_up.to_string(), 1037 + }), 1038 + ); 1039 + } 1040 + 1041 let tx = tx.clone(); 1042 let identity = self.identity.clone(); 1043 tokio::task::spawn(async move { 1044 + let res = match Self::resolve_mini_doc_impl(&identifier, identity) 1045 + .await 1046 + { 1047 + ResolveMiniDocResponse::Ok(Json(mini_doc)) => { 1048 + Hydration::Found(mini_doc) 1049 + } 1050 + ResolveMiniDocResponse::BadRequest(e) => { 1051 log::warn!("minidoc fail: {:?}", e.0); 1052 Hydration::Error(ProxyHydrationError { 1053 reason: "failed to resolve mini doc".to_string(), 1054 + should_retry: false, 1055 + follow_up: follow_up.to_string(), 1056 }) 1057 } 1058 }; 1059 + if tx.send(GetThing::Identifier(identifier, res)).await.is_ok() { 1060 + metrics::counter!("slingshot_hydrated_one", "type" => "identity", "ontime" => "true").increment(1); 1061 + } else { 1062 + metrics::counter!("slingshot_hydrated_one", "type" => "identity", "ontime" => "false").increment(1); 1063 + } 1064 }); 1065 } 1066 } 1067 } 1068 // so the channel can close when all are completed 1069 // (we shoudl be doing a timeout...) 
1070 drop(tx); 1071 1072 + let deadline = t0 + std::time::Duration::from_secs_f64(1.6); 1073 + let res = tokio::time::timeout_at(deadline.into(), async { 1074 + while let Some(hydration) = rx.recv().await { 1075 + match hydration { 1076 + GetThing::Record(uri, h) => { 1077 + if let Some(r) = records.get_mut(&uri) { 1078 + match (&r, &h) { 1079 + (_, Hydration::Found(_)) => *r = h, // always replace if found 1080 + (Hydration::Pending(_), _) => *r = h, // or if it was pending 1081 + _ => {} // else leave it 1082 + } 1083 + } else { 1084 + records.insert(uri, h); 1085 + } 1086 + } 1087 + GetThing::Identifier(identifier, md) => { 1088 + identifiers.insert(identifier.to_string(), md); 1089 + } 1090 + }; 1091 + } 1092 + }) 1093 + .await; 1094 + 1095 + if res.is_ok() { 1096 + metrics::histogram!("slingshot_hydration_all_completed").record(t0.elapsed()); 1097 + } else { 1098 + metrics::counter!("slingshot_hydration_cut_off").increment(1); 1099 } 1100 1101 ProxyHydrateResponse::Ok(Json(ProxyHydrateResponseObject { 1102 output: skeleton, 1103 records, 1104 identifiers, 1105 })) 1106 } 1107 Err(e) => { ··· 1109 ProxyHydrateResponse::BadRequest(xrpc_error("oop", "sorry")) 1110 } 1111 } 1112 } 1113 1114 async fn get_record_impl( 1115 &self, 1116 + repo: &str, 1117 + collection: &str, 1118 + rkey: &str, 1119 + cid: Option<&str>, 1120 ) -> GetRecordResponse { 1121 + let did = match Did::new(repo.to_string()) { 1122 Ok(did) => did, 1123 Err(_) => { 1124 let Ok(handle) = Handle::new(repo.to_lowercase()) else { ··· 1149 } 1150 }; 1151 1152 + let Ok(collection) = Nsid::new(collection.to_string()) else { 1153 return GetRecordResponse::BadRequest(xrpc_error( 1154 "InvalidRequest", 1155 "Invalid NSID for collection", 1156 )); 1157 }; 1158 1159 + let Ok(rkey) = RecordKey::new(rkey.to_string()) else { 1160 return GetRecordResponse::BadRequest(xrpc_error("InvalidRequest", "Invalid rkey")); 1161 }; 1162 1163 let cid: Option<Cid> = if let Some(cid) = cid { 1164 + let Ok(cid) = Cid::from_str(cid) else { 1165 return GetRecordResponse::BadRequest(xrpc_error("InvalidRequest", "Invalid CID")); 1166 }; 1167 Some(cid) ··· 1310 identity: Identity, 1311 repo: Repo, 1312 proxy: Proxy, 1313 + base_url: url::Url, 1314 + tls_domain: Option<String>, 1315 + tls_certs: Option<PathBuf>, 1316 + tls_ipv6: bool, 1317 + acme_challenge_redirect: Option<String>, 1318 acme_contact: Option<String>, 1319 + acme_staging: bool, 1320 shutdown: CancellationToken, 1321 bind: std::net::SocketAddr, 1322 ) -> Result<(), ServerError> { ··· 1324 let proxy = Arc::new(proxy); 1325 let api_service = OpenApiService::new( 1326 Xrpc { 1327 + base_url, 1328 cache, 1329 identity, 1330 proxy, ··· 1333 "Slingshot", 1334 env!("CARGO_PKG_VERSION"), 1335 ) 1336 + .server(if let Some(ref h) = tls_domain { 1337 format!("https://{h}") 1338 } else { 1339 format!("http://{bind}") // yeah should probably fix this for reverse-proxy scenarios but it's ok for dev for now ··· 1349 "https://microcosm.blue/slingshot", 1350 )); 1351 1352 + let app = Route::new() 1353 .at("/", StaticFileEndpoint::new("./static/index.html")) 1354 .nest("/openapi", api_service.spec_endpoint()) 1355 .nest("/xrpc/", api_service); 1356 1357 + let cors = Cors::new() 1358 + .allow_origin_regex("*") 1359 + .allow_methods([Method::GET, Method::POST]) 1360 + .allow_credentials(false); 1361 + 1362 + if let Some(domain) = tls_domain { 1363 rustls::crypto::aws_lc_rs::default_provider() 1364 .install_default() 1365 .expect("alskfjalksdjf"); 1366 1367 + let app = app 1368 + .at("/.well-known/did.json", 
get_did_doc(&domain)) 1369 + .with(cors); 1370 1371 if let Some(contact) = acme_contact { 1372 + let (listener, app) = acmify(app, domain, tls_certs, tls_ipv6, contact, acme_staging)?; 1373 + run(listener, app, shutdown).await 1374 + } else { 1375 + let certs = tls_certs.expect("certs path must be set for non-acme tls"); 1376 + let (listener, app) = tlsify(app, domain, certs, tls_ipv6, acme_challenge_redirect)?; 1377 + run(listener, app, shutdown).await 1378 } 1379 + } else { 1380 + run(TcpListener::bind(bind), app.with(cors), shutdown).await 1381 + } 1382 + } 1383 1384 + fn acmify( 1385 + app: impl Endpoint + 'static, 1386 + domain: String, 1387 + tls_certs: Option<PathBuf>, 1388 + tls_ipv6: bool, 1389 + acme_contact: String, 1390 + acme_staging: bool, 1391 + ) -> Result<(impl Listener + 'static, impl Endpoint + 'static), ServerError> { 1392 + let mut auto_cert = AutoCert::builder() 1393 + .contact(acme_contact) 1394 + .directory_url(if acme_staging { 1395 + LETS_ENCRYPT_STAGING 1396 + } else { 1397 + LETS_ENCRYPT_PRODUCTION 1398 + }) 1399 + .domain(&domain) 1400 + .challenge_type(ChallengeType::Http01); 1401 + 1402 + if let Some(path) = tls_certs { 1403 + auto_cert = auto_cert.cache_path(path); 1404 } else { 1405 + log::warn!( 1406 + "provisioning acme certs without `--tls-certs` folder configured, you might hit letsencrypt rate limits." 1407 + ); 1408 } 1409 + 1410 + let auto_cert = auto_cert.build().map_err(ServerError::AcmeBuildError)?; 1411 + 1412 + let app = RouteScheme::new() 1413 + .https(app) 1414 + .http(auto_cert.http_01_endpoint()); 1415 + 1416 + let listener = TcpListener::bind(if tls_ipv6 { "[::]:443" } else { "0.0.0.0:443" }) 1417 + .acme(auto_cert) 1418 + .combine(TcpListener::bind(if tls_ipv6 { 1419 + "[::]:80" 1420 + } else { 1421 + "0.0.0.0:80" 1422 + })); 1423 + 1424 + Ok((listener, app)) 1425 } 1426 1427 + fn tlsify( 1428 + app: impl Endpoint + 'static, 1429 + domain: String, 1430 + tls_certs: PathBuf, 1431 + tls_ipv6: bool, 1432 + acme_challenge_redirect: Option<String>, 1433 + ) -> Result<(impl Listener + 'static, impl Endpoint + 'static), ServerError> { 1434 + use poem::listener::{RustlsCertificate, RustlsConfig}; 1435 + use std::path::Path; 1436 + 1437 + fn load_tls_config(f: &Path, domain: &str) -> Result<RustlsConfig, std::io::Error> { 1438 + let cert_contents = std::fs::read(f.join("cert.pem")) 1439 + .inspect_err(|e| log::error!("failed to read cert file in {f:?}: {e}"))?; 1440 + 1441 + let key_contents = std::fs::read(f.join("key.pem")) 1442 + .inspect_err(|e| log::error!("failed to read key file in {f:?}: {e}"))?; 1443 + 1444 + let cert = RustlsCertificate::new() 1445 + .cert(cert_contents) 1446 + .key(key_contents); 1447 + Ok(RustlsConfig::new().certificate(domain, cert)) 1448 + } 1449 + 1450 + let listener = TcpListener::bind(if tls_ipv6 { "[::]:443" } else { "0.0.0.0:443" }) 1451 + .rustls(async_stream::stream! 
{ 1452 + loop { 1453 + if let Ok(tls_config) = load_tls_config(&tls_certs, &domain) { 1454 + // TODO: cert reload healthcheck 1455 + yield tls_config; 1456 + } else { 1457 + log::warn!("failed to load tls config."); 1458 + } 1459 + tokio::time::sleep(std::time::Duration::from_secs(3600 * 12)).await; 1460 + } 1461 + }) 1462 + // TODO: should be allowed to run in tls mode without binding port 80 if not forwarding acme challenges 1463 + .combine(TcpListener::bind(if tls_ipv6 { 1464 + "[::]:80" 1465 + } else { 1466 + "0.0.0.0:80" 1467 + })); 1468 + 1469 + let app = if let Some(redir) = acme_challenge_redirect { 1470 + use poem::web; 1471 + 1472 + let redirect = poem::endpoint::make_sync(move |req| { 1473 + let token = req.path_params::<String>().unwrap(); 1474 + metrics::counter!("http_challenge_redirects").increment(1); 1475 + web::Redirect::temporary(format!("{redir}{token}")) 1476 + }); 1477 + 1478 + RouteScheme::new() 1479 + .https(app) 1480 + .http(Route::new().at("/.well-known/acme-challenge/:token", redirect)) 1481 + } else { 1482 + // just uh... 404 for port 80? should probably reply with something. 1483 + RouteScheme::new().https(app).http(Route::new()) 1484 + }; 1485 + 1486 + Ok((listener, app)) 1487 + } 1488 + 1489 + async fn run<L, A>(listener: L, app: A, shutdown: CancellationToken) -> Result<(), ServerError> 1490 where 1491 L: Listener + 'static, 1492 + A: Endpoint + 'static, 1493 { 1494 let app = app 1495 .with(CatchPanic::new()) 1496 .around(request_counter) 1497 .with(Tracing);
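Taken together, the new hydrateQueryResponse flow works like this from a client's point of view: the caller posts an xrpc NSID, an atproto_proxy service reference (service did plus '#fragment'), optional params, and a list of hydration_sources paths; slingshot resolves the service, proxies the query, extracts at-uris and identifiers from the response, and spends up to roughly 1.6 seconds hydrating the first 100 of them, returning anything unfinished as a pending entry whose followUp URL can be requested directly. Below is a rough client-side sketch of that round trip. It is not part of this change: the endpoint path is assumed to be mounted under /xrpc/ like the other handlers, the "status" values are assumed to be the camel-cased Hydration variant names, and the exact wire casing of the request fields is not visible in this hunk (the payload's #[oai] attributes are elided), so check the served OpenAPI spec before relying on it.

use serde_json::{Value, json};

// Hypothetical client sketch, not part of this PR. It exercises
// com.bad-example.proxy.hydrateQueryResponse with the example payload from
// ProxyQueryPayload above, then inspects the per-uri hydration status.
// Assumes reqwest is built with its `json` feature.
async fn hydrate_feed(base: &str) -> Result<(), reqwest::Error> {
    let body = json!({
        "xrpc": "app.bsky.feed.getFeedSkeleton",
        "atproto_proxy": "did:web:blue.mackuba.eu#bsky_fg",
        "params": {
            "feed": "at://did:plc:oio4hkxaop4ao4wz2pp3f4cr/app.bsky.feed.generator/atproto"
        },
        "hydration_sources": [{ "path": "feed[].post", "shape": "at-uri" }]
    });

    let res: Value = reqwest::Client::new()
        .post(format!("{base}/xrpc/com.bad-example.proxy.hydrateQueryResponse"))
        .json(&body)
        .send()
        .await?
        .json()
        .await?;

    // `records` is keyed by the at-uri found in the proxied skeleton.
    if let Some(records) = res["records"].as_object() {
        for (uri, hydration) in records {
            match hydration["status"].as_str() {
                Some("found") => println!("{uri}: hydrated"),
                // pending reason is "deadline" or "limit"; followUp is a plain
                // getRecord (or resolveMiniDoc) URL to fetch the item directly.
                Some("pending") => println!(
                    "{uri}: pending ({}), follow up at {}",
                    hydration["reason"], hydration["followUp"]
                ),
                _ => println!("{uri}: error: {}", hydration["reason"]),
            }
        }
    }
    Ok(())
}

Per the ProxyHydrationPending docs above, a deadline follow-up is coalesced into the fetch slingshot already started in the background, so requesting the followUp URL shortly afterwards should usually return the finished record rather than triggering a second upstream fetch.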