PDS software with bells & whistles you didn’t even know you needed. Will move this to its own account when ready.
at main 8.9 kB view raw
1use reqwest::{Client, ClientBuilder, Url}; 2use std::net::{IpAddr, SocketAddr, ToSocketAddrs}; 3use std::sync::OnceLock; 4use std::time::Duration; 5use tracing::warn; 6 7pub const DEFAULT_HEADERS_TIMEOUT: Duration = Duration::from_secs(10); 8pub const DEFAULT_BODY_TIMEOUT: Duration = Duration::from_secs(30); 9pub const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(5); 10pub const MAX_RESPONSE_SIZE: u64 = 10 * 1024 * 1024; 11 12static PROXY_CLIENT: OnceLock<Client> = OnceLock::new(); 13static DID_RESOLUTION_CLIENT: OnceLock<Client> = OnceLock::new(); 14static HANDLE_RESOLUTION_CLIENT: OnceLock<Client> = OnceLock::new(); 15 16pub fn proxy_client() -> &'static Client { 17 PROXY_CLIENT.get_or_init(|| { 18 ClientBuilder::new() 19 .timeout(DEFAULT_BODY_TIMEOUT) 20 .connect_timeout(DEFAULT_CONNECT_TIMEOUT) 21 .pool_max_idle_per_host(10) 22 .pool_idle_timeout(Duration::from_secs(90)) 23 .redirect(reqwest::redirect::Policy::none()) 24 .build() 25 .expect( 26 "Failed to build HTTP client - this indicates a TLS or system configuration issue", 27 ) 28 }) 29} 30 31pub fn did_resolution_client() -> &'static Client { 32 DID_RESOLUTION_CLIENT.get_or_init(|| { 33 ClientBuilder::new() 34 .timeout(Duration::from_secs(5)) 35 .connect_timeout(DEFAULT_CONNECT_TIMEOUT) 36 .pool_max_idle_per_host(10) 37 .pool_idle_timeout(Duration::from_secs(90)) 38 .build() 39 .expect( 40 "Failed to build DID resolution client - this indicates a TLS or system configuration issue", 41 ) 42 }) 43} 44 45pub fn handle_resolution_client() -> &'static Client { 46 HANDLE_RESOLUTION_CLIENT.get_or_init(|| { 47 ClientBuilder::new() 48 .timeout(Duration::from_secs(10)) 49 .connect_timeout(DEFAULT_CONNECT_TIMEOUT) 50 .pool_max_idle_per_host(10) 51 .pool_idle_timeout(Duration::from_secs(90)) 52 .redirect(reqwest::redirect::Policy::limited(5)) 53 .build() 54 .expect( 55 "Failed to build handle resolution client - this indicates a TLS or system configuration issue", 56 ) 57 }) 58} 59 60pub fn 
is_ssrf_safe(url: &str) -> Result<(), SsrfError> { 61 let parsed = Url::parse(url).map_err(|_| SsrfError::InvalidUrl)?; 62 let scheme = parsed.scheme(); 63 if scheme != "https" { 64 let allow_http = std::env::var("ALLOW_HTTP_PROXY").is_ok() 65 || url.starts_with("http://127.0.0.1") 66 || url.starts_with("http://localhost"); 67 if !allow_http { 68 return Err(SsrfError::InsecureProtocol(scheme.to_string())); 69 } 70 } 71 let host = parsed.host_str().ok_or(SsrfError::NoHost)?; 72 if host == "localhost" { 73 return Ok(()); 74 } 75 if let Ok(ip) = host.parse::<IpAddr>() { 76 if ip.is_loopback() { 77 return Ok(()); 78 } 79 if !is_unicast_ip(&ip) { 80 return Err(SsrfError::NonUnicastIp(ip.to_string())); 81 } 82 return Ok(()); 83 } 84 let port = parsed 85 .port() 86 .unwrap_or(if scheme == "https" { 443 } else { 80 }); 87 let socket_addrs: Vec<SocketAddr> = match (host, port).to_socket_addrs() { 88 Ok(addrs) => addrs.collect(), 89 Err(_) => return Err(SsrfError::DnsResolutionFailed(host.to_string())), 90 }; 91 for addr in &socket_addrs { 92 if !is_unicast_ip(&addr.ip()) { 93 warn!( 94 "DNS resolution for {} returned non-unicast IP: {}", 95 host, 96 addr.ip() 97 ); 98 return Err(SsrfError::NonUnicastIp(addr.ip().to_string())); 99 } 100 } 101 Ok(()) 102} 103 104fn is_unicast_ip(ip: &IpAddr) -> bool { 105 match ip { 106 IpAddr::V4(v4) => { 107 !v4.is_loopback() 108 && !v4.is_broadcast() 109 && !v4.is_multicast() 110 && !v4.is_unspecified() 111 && !v4.is_link_local() 112 && !is_private_v4(v4) 113 } 114 IpAddr::V6(v6) => !v6.is_loopback() && !v6.is_multicast() && !v6.is_unspecified(), 115 } 116} 117 118fn is_private_v4(ip: &std::net::Ipv4Addr) -> bool { 119 let octets = ip.octets(); 120 octets[0] == 10 121 || (octets[0] == 172 && (16..=31).contains(&octets[1])) 122 || (octets[0] == 192 && octets[1] == 168) 123 || (octets[0] == 169 && octets[1] == 254) 124} 125 126#[derive(Debug, Clone)] 127pub enum SsrfError { 128 InvalidUrl, 129 InsecureProtocol(String), 130 NoHost, 131 
NonUnicastIp(String), 132 DnsResolutionFailed(String), 133} 134 135impl std::fmt::Display for SsrfError { 136 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 137 match self { 138 SsrfError::InvalidUrl => write!(f, "Invalid URL"), 139 SsrfError::InsecureProtocol(p) => write!(f, "Insecure protocol: {}", p), 140 SsrfError::NoHost => write!(f, "No host in URL"), 141 SsrfError::NonUnicastIp(ip) => write!(f, "Non-unicast IP address: {}", ip), 142 SsrfError::DnsResolutionFailed(host) => { 143 write!(f, "DNS resolution failed for: {}", host) 144 } 145 } 146 } 147} 148 149impl std::error::Error for SsrfError {} 150 151pub const HEADERS_TO_FORWARD: &[&str] = &[ 152 "accept-language", 153 "atproto-accept-labelers", 154 "x-bsky-topics", 155 "content-type", 156]; 157pub const RESPONSE_HEADERS_TO_FORWARD: &[&str] = &[ 158 "atproto-repo-rev", 159 "atproto-content-labelers", 160 "retry-after", 161 "content-type", 162 "cache-control", 163 "etag", 164]; 165 166pub fn validate_at_uri(uri: &str) -> Result<AtUriParts, &'static str> { 167 if !uri.starts_with("at://") { 168 return Err("URI must start with at://"); 169 } 170 let path = uri.trim_start_matches("at://"); 171 let parts: Vec<&str> = path.split('/').collect(); 172 if parts.is_empty() { 173 return Err("URI missing DID"); 174 } 175 let did = parts[0]; 176 if !did.starts_with("did:") { 177 return Err("Invalid DID in URI"); 178 } 179 if parts.len() > 1 { 180 let collection = parts[1]; 181 if collection.is_empty() || !collection.contains('.') { 182 return Err("Invalid collection NSID"); 183 } 184 } 185 Ok(AtUriParts { 186 did: did.to_string(), 187 collection: parts.get(1).map(|s| s.to_string()), 188 rkey: parts.get(2).map(|s| s.to_string()), 189 }) 190} 191 192#[derive(Debug, Clone)] 193pub struct AtUriParts { 194 pub did: String, 195 pub collection: Option<String>, 196 pub rkey: Option<String>, 197} 198 199pub fn validate_limit(limit: Option<u32>, default: u32, max: u32) -> u32 { 200 match limit { 201 Some(0) 
=> default, 202 Some(l) if l > max => max, 203 Some(l) => l, 204 None => default, 205 } 206} 207 208pub fn validate_did(did: &str) -> Result<(), &'static str> { 209 if !did.starts_with("did:") { 210 return Err("Invalid DID format"); 211 } 212 let parts: Vec<&str> = did.split(':').collect(); 213 if parts.len() < 3 { 214 return Err("DID must have at least method and identifier"); 215 } 216 let method = parts[1]; 217 if method != "plc" && method != "web" { 218 return Err("Unsupported DID method"); 219 } 220 Ok(()) 221} 222 223#[cfg(test)] 224mod tests { 225 use super::*; 226 #[test] 227 fn test_ssrf_safe_https() { 228 assert!(is_ssrf_safe("https://api.bsky.app/xrpc/test").is_ok()); 229 } 230 #[test] 231 fn test_ssrf_blocks_http_by_default() { 232 let result = is_ssrf_safe("http://external.example.com/xrpc/test"); 233 assert!(matches!( 234 result, 235 Err(SsrfError::InsecureProtocol(_)) | Err(SsrfError::DnsResolutionFailed(_)) 236 )); 237 } 238 #[test] 239 fn test_ssrf_allows_localhost_http() { 240 assert!(is_ssrf_safe("http://127.0.0.1:8080/test").is_ok()); 241 assert!(is_ssrf_safe("http://localhost:8080/test").is_ok()); 242 } 243 #[test] 244 fn test_validate_at_uri() { 245 let result = validate_at_uri("at://did:plc:test/app.bsky.feed.post/abc123"); 246 assert!(result.is_ok()); 247 let parts = result.unwrap(); 248 assert_eq!(parts.did, "did:plc:test"); 249 assert_eq!(parts.collection, Some("app.bsky.feed.post".to_string())); 250 assert_eq!(parts.rkey, Some("abc123".to_string())); 251 } 252 #[test] 253 fn test_validate_at_uri_invalid() { 254 assert!(validate_at_uri("https://example.com").is_err()); 255 assert!(validate_at_uri("at://notadid/collection/rkey").is_err()); 256 } 257 #[test] 258 fn test_validate_limit() { 259 assert_eq!(validate_limit(None, 50, 100), 50); 260 assert_eq!(validate_limit(Some(0), 50, 100), 50); 261 assert_eq!(validate_limit(Some(200), 50, 100), 100); 262 assert_eq!(validate_limit(Some(75), 50, 100), 75); 263 } 264 #[test] 265 fn 
test_validate_did() { 266 assert!(validate_did("did:plc:abc123").is_ok()); 267 assert!(validate_did("did:web:example.com").is_ok()); 268 assert!(validate_did("notadid").is_err()); 269 assert!(validate_did("did:unknown:test").is_err()); 270 } 271}