//! High-level HTTP/1.1 client with connection pooling. //! //! Brings together TCP, TLS 1.3, DNS, URL parsing, and HTTP message //! parsing into a single `HttpClient` that can fetch HTTP and HTTPS URLs. use std::collections::HashMap; use std::fmt; use std::io; use std::time::{Duration, Instant}; use we_url::Url; use crate::http::{self, Headers, HttpResponse, Method}; use crate::tcp::{self, TcpConnection}; use crate::tls::handshake::{self, HandshakeError, TlsStream}; // --------------------------------------------------------------------------- // Constants // --------------------------------------------------------------------------- const DEFAULT_MAX_REDIRECTS: u32 = 10; const DEFAULT_CONNECT_TIMEOUT: Duration = Duration::from_secs(30); const DEFAULT_READ_TIMEOUT: Duration = Duration::from_secs(30); const DEFAULT_MAX_IDLE_TIME: Duration = Duration::from_secs(60); const DEFAULT_MAX_PER_HOST: usize = 6; const READ_BUF_SIZE: usize = 8192; // --------------------------------------------------------------------------- // Error type // --------------------------------------------------------------------------- /// Errors that can occur during an HTTP client operation. #[derive(Debug)] pub enum ClientError { /// URL is invalid or missing required components. InvalidUrl(String), /// Unsupported URL scheme. UnsupportedScheme(String), /// TCP connection error. Tcp(tcp::NetError), /// TLS handshake error. Tls(HandshakeError), /// HTTP parsing error. Http(http::HttpError), /// Too many redirects. TooManyRedirects, /// Connection was closed unexpectedly. ConnectionClosed, /// I/O error. 
Io(io::Error), } impl fmt::Display for ClientError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::InvalidUrl(s) => write!(f, "invalid URL: {s}"), Self::UnsupportedScheme(s) => write!(f, "unsupported scheme: {s}"), Self::Tcp(e) => write!(f, "TCP error: {e}"), Self::Tls(e) => write!(f, "TLS error: {e}"), Self::Http(e) => write!(f, "HTTP error: {e}"), Self::TooManyRedirects => write!(f, "too many redirects"), Self::ConnectionClosed => write!(f, "connection closed"), Self::Io(e) => write!(f, "I/O error: {e}"), } } } impl From for ClientError { fn from(e: tcp::NetError) -> Self { Self::Tcp(e) } } impl From for ClientError { fn from(e: HandshakeError) -> Self { Self::Tls(e) } } impl From for ClientError { fn from(e: http::HttpError) -> Self { Self::Http(e) } } impl From for ClientError { fn from(e: io::Error) -> Self { Self::Io(e) } } pub type Result = std::result::Result; // --------------------------------------------------------------------------- // Connection abstraction // --------------------------------------------------------------------------- /// A connection that can be either plain TCP or TLS-encrypted. 
enum Connection { Plain(TcpConnection), Tls(TlsStream), } impl Connection { fn read(&mut self, buf: &mut [u8]) -> Result { match self { Self::Plain(tcp) => tcp.read(buf).map_err(ClientError::Tcp), Self::Tls(tls) => tls.read(buf).map_err(ClientError::Tls), } } fn write_all(&mut self, data: &[u8]) -> Result<()> { match self { Self::Plain(tcp) => tcp.write_all(data).map_err(ClientError::Tcp), Self::Tls(tls) => tls.write_all(data).map_err(ClientError::Tls), } } fn flush(&mut self) -> Result<()> { match self { Self::Plain(tcp) => tcp.flush().map_err(ClientError::Tcp), Self::Tls(_) => Ok(()), // TLS writes are flushed per record } } fn set_read_timeout(&self, duration: Option) -> Result<()> { match self { Self::Plain(tcp) => tcp.set_read_timeout(duration).map_err(ClientError::Tcp), Self::Tls(tls) => tls .stream() .set_read_timeout(duration) .map_err(ClientError::Tcp), } } } // --------------------------------------------------------------------------- // Connection pool // --------------------------------------------------------------------------- /// Key for pooling connections by origin. #[derive(Hash, Eq, PartialEq, Clone, Debug)] struct ConnectionKey { host: String, port: u16, is_tls: bool, } /// A pooled connection with its idle timestamp. struct PooledConnection { conn: Connection, idle_since: Instant, } /// Pool of idle HTTP connections for reuse. struct ConnectionPool { connections: HashMap>, max_idle_time: Duration, max_per_host: usize, } impl ConnectionPool { fn new(max_idle_time: Duration, max_per_host: usize) -> Self { Self { connections: HashMap::new(), max_idle_time, max_per_host, } } /// Take an idle connection for the given key, if one is available. 
fn take(&mut self, key: &ConnectionKey) -> Option { let entries = self.connections.get_mut(key)?; let now = Instant::now(); // Remove expired connections entries.retain(|pc| now.duration_since(pc.idle_since) < self.max_idle_time); // Take the most recently idled connection entries.pop().map(|pc| pc.conn) } /// Return a connection to the pool. fn put(&mut self, key: ConnectionKey, conn: Connection) { let entries = self.connections.entry(key).or_default(); // Evict oldest if at capacity if entries.len() >= self.max_per_host { entries.remove(0); } entries.push(PooledConnection { conn, idle_since: Instant::now(), }); } } // --------------------------------------------------------------------------- // HttpClient // --------------------------------------------------------------------------- /// High-level HTTP/1.1 client with connection pooling and redirect following. pub struct HttpClient { pool: ConnectionPool, max_redirects: u32, connect_timeout: Duration, read_timeout: Duration, } impl HttpClient { /// Create a new HTTP client with default settings. pub fn new() -> Self { Self { pool: ConnectionPool::new(DEFAULT_MAX_IDLE_TIME, DEFAULT_MAX_PER_HOST), max_redirects: DEFAULT_MAX_REDIRECTS, connect_timeout: DEFAULT_CONNECT_TIMEOUT, read_timeout: DEFAULT_READ_TIMEOUT, } } /// Set the maximum number of redirects to follow. pub fn set_max_redirects(&mut self, max: u32) { self.max_redirects = max; } /// Set the connection timeout. pub fn set_connect_timeout(&mut self, timeout: Duration) { self.connect_timeout = timeout; } /// Set the read timeout. pub fn set_read_timeout(&mut self, timeout: Duration) { self.read_timeout = timeout; } /// Perform an HTTP GET request. pub fn get(&mut self, url: &Url) -> Result { self.request(Method::Get, url, &Headers::new(), None) } /// Perform an HTTP POST request. 
pub fn post(&mut self, url: &Url, body: &[u8], content_type: &str) -> Result { let mut headers = Headers::new(); headers.add("Content-Type", content_type); self.request(Method::Post, url, &headers, Some(body)) } /// Perform an HTTP request with full control over method, headers, and body. /// /// Follows redirects (301, 302, 307, 308) up to `max_redirects`. pub fn request( &mut self, method: Method, url: &Url, headers: &Headers, body: Option<&[u8]>, ) -> Result { let mut current_url = url.clone(); let mut redirects = 0; loop { let resp = self.execute_request(method, ¤t_url, headers, body)?; // Check for redirects if matches!(resp.status_code, 301 | 302 | 307 | 308) { redirects += 1; if redirects > self.max_redirects { return Err(ClientError::TooManyRedirects); } if let Some(location) = resp.headers.get("Location") { // Resolve relative URLs against current URL current_url = Url::parse_with_base(location, ¤t_url) .or_else(|_| Url::parse(location)) .map_err(|_| { ClientError::InvalidUrl(format!( "invalid redirect location: {location}" )) })?; continue; } } return Ok(resp); } } /// Execute a single HTTP request (no redirect following). 
fn execute_request( &mut self, method: Method, url: &Url, headers: &Headers, body: Option<&[u8]>, ) -> Result { let scheme = url.scheme(); let is_tls = match scheme { "https" => true, "http" => false, other => return Err(ClientError::UnsupportedScheme(other.to_string())), }; let host = url .host_str() .ok_or_else(|| ClientError::InvalidUrl("missing host".to_string()))?; let port = url .port_or_default() .ok_or_else(|| ClientError::InvalidUrl("cannot determine port".to_string()))?; let path = request_path(url); let key = ConnectionKey { host: host.clone(), port, is_tls, }; // Try to reuse a pooled connection, fall back to new connection let mut conn = match self.pool.take(&key) { Some(conn) => conn, None => self.connect(&host, port, is_tls)?, }; conn.set_read_timeout(Some(self.read_timeout))?; // Serialize and send request let request_bytes = http::serialize_request(method, &path, &host, headers, body); conn.write_all(&request_bytes)?; conn.flush()?; // Read and parse response let response = read_response(&mut conn)?; // Return connection to pool if keep-alive if !response.connection_close() { self.pool.put(key, conn); } Ok(response) } /// Establish a new connection (plain TCP or TLS). fn connect(&self, host: &str, port: u16, is_tls: bool) -> Result { let tcp = TcpConnection::connect_timeout(host, port, self.connect_timeout)?; if is_tls { let tls = handshake::connect(tcp, host)?; Ok(Connection::Tls(tls)) } else { Ok(Connection::Plain(tcp)) } } } impl Default for HttpClient { fn default() -> Self { Self::new() } } // --------------------------------------------------------------------------- // Helpers // --------------------------------------------------------------------------- /// Build the request path from a URL (path + query). 
fn request_path(url: &Url) -> String { let path = url.path(); let path = if path.is_empty() { "/" } else { &path }; match url.query() { Some(q) => format!("{path}?{q}"), None => path.to_string(), } } /// Read a complete HTTP response from a connection. /// /// Reads the header section first, then determines the body length from /// headers and reads the appropriate amount of body data. fn read_response(conn: &mut Connection) -> Result { let mut buf = Vec::with_capacity(READ_BUF_SIZE); let mut temp = [0u8; READ_BUF_SIZE]; // Phase 1: Read until we have the complete header section (\r\n\r\n) let header_end = loop { let n = conn.read(&mut temp)?; if n == 0 { if buf.is_empty() { return Err(ClientError::ConnectionClosed); } break find_header_end(&buf); } buf.extend_from_slice(&temp[..n]); if let Some(pos) = find_header_end(&buf) { break Some(pos); } }; let header_end = header_end.ok_or(ClientError::Http(http::HttpError::Incomplete))?; let body_start = header_end + 4; // skip \r\n\r\n // Quick-parse headers to determine body strategy let header_str = std::str::from_utf8(&buf[..header_end]).map_err(|_| { ClientError::Http(http::HttpError::Parse( "invalid UTF-8 in headers".to_string(), )) })?; let status_code = parse_status_code(header_str)?; let body_strategy = determine_body_strategy(header_str, status_code); // Phase 2: Read body according to strategy match body_strategy { BodyStrategy::NoBody => { // Truncate buffer to just headers + \r\n\r\n buf.truncate(body_start); } BodyStrategy::ContentLength(len) => { let total_needed = body_start + len; while buf.len() < total_needed { let n = conn.read(&mut temp)?; if n == 0 { break; } buf.extend_from_slice(&temp[..n]); } } BodyStrategy::Chunked => { // Read until we find the terminating 0-length chunk while !has_chunked_terminator(&buf[body_start..]) { let n = conn.read(&mut temp)?; if n == 0 { break; } buf.extend_from_slice(&temp[..n]); } } BodyStrategy::ReadUntilClose => { // Read until EOF loop { let n = conn.read(&mut 
temp)?; if n == 0 { break; } buf.extend_from_slice(&temp[..n]); } } } // Parse the complete response http::parse_response(&buf).map_err(ClientError::Http) } /// Find the end of the HTTP header section (\r\n\r\n). /// Returns the position of the first \r in the \r\n\r\n sequence. fn find_header_end(data: &[u8]) -> Option { data.windows(4).position(|w| w == b"\r\n\r\n") } /// Extract status code from the first line of headers. fn parse_status_code(headers: &str) -> Result { let first_line = headers.lines().next().unwrap_or(""); let mut parts = first_line.splitn(3, ' '); let _version = parts.next(); let code_str = parts.next().unwrap_or(""); code_str.parse().map_err(|_| { ClientError::Http(http::HttpError::MalformedStatusLine(first_line.to_string())) }) } /// Strategy for reading the response body. enum BodyStrategy { NoBody, ContentLength(usize), Chunked, ReadUntilClose, } /// Extract the value for a header name (case-insensitive match). fn header_value<'a>(line: &'a str, name: &str) -> Option<&'a str> { let colon = line.find(':')?; if line[..colon].eq_ignore_ascii_case(name) { Some(line[colon + 1..].trim()) } else { None } } /// Determine how to read the body from headers. fn determine_body_strategy(headers: &str, status_code: u16) -> BodyStrategy { // 1xx, 204, 304 have no body if status_code < 200 || status_code == 204 || status_code == 304 { return BodyStrategy::NoBody; } // Check for Transfer-Encoding: chunked for line in headers.split("\r\n").skip(1) { if let Some(val) = header_value(line, "transfer-encoding") { if val.eq_ignore_ascii_case("chunked") { return BodyStrategy::Chunked; } } } // Check for Content-Length for line in headers.split("\r\n").skip(1) { if let Some(val) = header_value(line, "content-length") { if let Ok(len) = val.parse::() { return BodyStrategy::ContentLength(len); } } } BodyStrategy::ReadUntilClose } /// Check if chunked body data contains the terminating `0\r\n\r\n`. 
///
/// `data` is the body portion of the response received so far. The chunk
/// framing is walked from the start (size line, chunk data, CRLF) instead of
/// searching for a byte pattern, so chunk payload that happens to contain
/// `0\r\n\r\n` is not mistaken for the end, and a last chunk followed by
/// trailer fields is still recognised. Returns `false` while more data is
/// needed.
fn has_chunked_terminator(data: &[u8]) -> bool {
    let mut pos = 0usize;

    // `get` yields `None` once `pos` runs past the received data, which
    // means the current chunk has not fully arrived yet.
    while let Some(rest) = data.get(pos..) {
        let line_len = match find_crlf(rest) {
            Some(len) => len,
            None => return false, // size line still incomplete
        };

        let size = match parse_chunk_size(&rest[..line_len]) {
            Some(size) => size,
            // Framing cannot be followed; keep the previous best-effort scan
            // so a malformed stream is handed to the parser no later than
            // before.
            None => return contains_last_chunk_pattern(rest),
        };

        let after_line = pos + line_len + 2;
        if size == 0 {
            return trailer_section_complete(&data[after_line..]);
        }

        // Skip the chunk data and the CRLF that follows it.
        let next = match after_line
            .checked_add(size)
            .and_then(|end| end.checked_add(2))
        {
            Some(next) => next,
            None => return false,
        };
        if next > data.len() {
            return false;
        }
        if &data[next - 2..next] != b"\r\n" {
            return contains_last_chunk_pattern(rest);
        }
        pos = next;
    }

    false
}

/// Position of the first `\r\n` in `data`, if any.
fn find_crlf(data: &[u8]) -> Option<usize> {
    data.windows(2).position(|w| w == b"\r\n")
}

/// Parse the hexadecimal size from a chunk-size line, ignoring any
/// `;extension` part. Returns `None` if the size is missing, is not
/// hexadecimal, or does not fit in `usize`.
fn parse_chunk_size(line: &[u8]) -> Option<usize> {
    let size_field = line.split(|&b| b == b';').next()?;
    let text = std::str::from_utf8(size_field).ok()?.trim();
    if text.is_empty() || !text.bytes().all(|b| b.is_ascii_hexdigit()) {
        return None;
    }
    usize::from_str_radix(text, 16).ok()
}

/// Whether the data following the last-chunk line contains the empty line
/// that ends the (possibly empty) trailer section.
fn trailer_section_complete(mut rest: &[u8]) -> bool {
    loop {
        match find_crlf(rest) {
            Some(0) => return true,
            Some(len) => rest = &rest[len + 2..],
            None => return false,
        }
    }
}

/// Byte-pattern search for `0\r\n\r\n`, used only when the framing is malformed.
fn contains_last_chunk_pattern(data: &[u8]) -> bool {
    data.windows(5).any(|w| w == b"0\r\n\r\n")
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    // -- ClientError Display tests --

    #[test]
    fn error_display_invalid_url() {
        let e = ClientError::InvalidUrl("bad".to_string());
        assert_eq!(e.to_string(), "invalid URL: bad");
    }

    #[test]
    fn error_display_unsupported_scheme() {
        let e = ClientError::UnsupportedScheme("ftp".to_string());
        assert_eq!(e.to_string(), "unsupported scheme: ftp");
    }

    #[test]
    fn error_display_too_many_redirects() {
        let e = ClientError::TooManyRedirects;
        assert_eq!(e.to_string(), "too many redirects");
    }

    #[test]
    fn error_display_connection_closed() {
        let e = ClientError::ConnectionClosed;
        assert_eq!(e.to_string(), "connection closed");
    }

    // -- HttpClient configuration tests --

    #[test]
    fn client_default() {
        let client = HttpClient::default();
        assert_eq!(client.max_redirects, DEFAULT_MAX_REDIRECTS);
        assert_eq!(client.connect_timeout, DEFAULT_CONNECT_TIMEOUT);
        assert_eq!(client.read_timeout, DEFAULT_READ_TIMEOUT);
    }

    #[test]
    fn client_set_max_redirects() {
        let mut client = HttpClient::new();
        client.set_max_redirects(5);
        assert_eq!(client.max_redirects, 5);
    }

    #[test]
    fn client_set_connect_timeout() {
        let mut client = HttpClient::new();
        client.set_connect_timeout(Duration::from_secs(10));
        assert_eq!(client.connect_timeout, Duration::from_secs(10));
    }

    #[test]
    fn client_set_read_timeout() {
        let mut client = HttpClient::new();
        client.set_read_timeout(Duration::from_secs(5));
        assert_eq!(client.read_timeout, Duration::from_secs(5));
    }

    // -- ConnectionPool tests --

    #[test]
    fn pool_take_empty() {
        let mut pool = ConnectionPool::new(Duration::from_secs(60), 6);
        let key = ConnectionKey {
            host: "example.com".to_string(),
            port: 80,
            is_tls: false,
        };
        assert!(pool.take(&key).is_none());
    }

    #[test]
    fn pool_connections_map_starts_empty() {
        let pool = ConnectionPool::new(Duration::from_secs(60), 6);
        assert!(pool.connections.is_empty());
    }

    // -- request_path tests --

    #[test]
    fn request_path_simple() {
        let url = Url::parse("http://example.com/path").unwrap();
        assert_eq!(request_path(&url), "/path");
    }

    #[test]
    fn request_path_with_query() {
        let url = Url::parse("http://example.com/path?key=value").unwrap();
        assert_eq!(request_path(&url), "/path?key=value");
    }

    #[test]
    fn request_path_root() {
        let url = Url::parse("http://example.com").unwrap();
        assert_eq!(request_path(&url), "/");
    }

    #[test]
    fn request_path_deep() {
        let url = Url::parse("http://example.com/a/b/c").unwrap();
        assert_eq!(request_path(&url), "/a/b/c");
    }

    // -- find_header_end tests --

    #[test]
    fn find_header_end_found() {
        let data = b"HTTP/1.1 200 OK\r\nHost: x\r\n\r\nbody";
        assert_eq!(find_header_end(data), Some(24));
    }

    #[test]
    fn find_header_end_not_found() {
        let data = b"HTTP/1.1 200 OK\r\nHost: x\r\n";
        assert_eq!(find_header_end(data), None);
    }

    #[test]
    fn find_header_end_empty() {
        assert_eq!(find_header_end(b""), None);
    }

    #[test]
    fn find_header_end_minimal() {
        let data = b"\r\n\r\n";
        assert_eq!(find_header_end(data), Some(0));
    }

    // -- parse_status_code tests --

    #[test]
    fn parse_status_code_200() {
        assert_eq!(parse_status_code("HTTP/1.1 200 OK").unwrap(), 200);
    }

    #[test]
    fn parse_status_code_404() {
        assert_eq!(parse_status_code("HTTP/1.1 404 Not Found").unwrap(), 404);
    }

    #[test]
    fn parse_status_code_301() {
        assert_eq!(
            parse_status_code("HTTP/1.1 301 Moved Permanently").unwrap(),
            301
        );
    }

    #[test]
    fn parse_status_code_invalid() {
        assert!(parse_status_code("INVALID").is_err());
    }

    // -- determine_body_strategy tests --

    #[test]
    fn strategy_no_body_204() {
        let headers = "HTTP/1.1 204 No Content\r\nConnection: keep-alive";
        assert!(matches!(
            determine_body_strategy(headers, 204),
            BodyStrategy::NoBody
        ));
    }

    #[test]
    fn strategy_no_body_304() {
        let headers = "HTTP/1.1 304 Not Modified\r\nETag: \"abc\"";
        assert!(matches!(
            determine_body_strategy(headers, 304),
            BodyStrategy::NoBody
        ));
    }

    #[test]
    fn strategy_no_body_1xx() {
        let headers = "HTTP/1.1 100 Continue";
        assert!(matches!(
            determine_body_strategy(headers, 100),
            BodyStrategy::NoBody
        ));
    }

    #[test]
    fn strategy_content_length() {
        let headers = "HTTP/1.1 200 OK\r\nContent-Length: 42";
        match determine_body_strategy(headers, 200) {
            BodyStrategy::ContentLength(42) => {}
            _ => panic!("expected ContentLength(42)"),
        }
    }

    #[test]
    fn strategy_chunked() {
        let headers = "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked";
        assert!(matches!(
            determine_body_strategy(headers, 200),
            BodyStrategy::Chunked
        ));
    }

    #[test]
    fn strategy_read_until_close() {
        let headers = "HTTP/1.1 200 OK\r\nConnection: close";
        assert!(matches!(
            determine_body_strategy(headers, 200),
            BodyStrategy::ReadUntilClose
        ));
    }

    // -- has_chunked_terminator tests --

    #[test]
    fn chunked_terminator_present() {
        assert!(has_chunked_terminator(b"5\r\nHello\r\n0\r\n\r\n"));
    }

    #[test]
    fn chunked_terminator_at_start() {
        assert!(has_chunked_terminator(b"0\r\n\r\n"));
    }

    #[test]
    fn chunked_terminator_missing() {
        assert!(!has_chunked_terminator(b"5\r\nHello\r\n"));
    }

    #[test]
    fn chunked_terminator_empty() {
        assert!(!has_chunked_terminator(b""));
    }

    #[test]
    fn chunked_terminator_with_trailers() {
        assert!(has_chunked_terminator(
            b"5\r\nHello\r\n0\r\nX-Checksum: abc\r\n\r\n"
        ));
    }

    #[test]
    fn chunked_terminator_incomplete_trailers() {
        assert!(!has_chunked_terminator(b"0\r\nX-Checksum: abc\r\n"));
    }

    #[test]
    fn chunked_terminator_not_fooled_by_chunk_data() {
        // A 16-byte chunk whose data starts with CRLF: the bytes "0\r\n\r\n"
        // appear in the stream, but the last chunk has not arrived yet.
        assert!(!has_chunked_terminator(b"10\r\n\r\n0123456789abcd\r\n"));
    }

    #[test]
    fn chunked_terminator_partial_chunk() {
        assert!(!has_chunked_terminator(b"5\r\nHel"));
    }

    #[test]
    fn chunked_terminator_chunk_extension() {
        assert!(has_chunked_terminator(b"5;name=value\r\nHello\r\n0\r\n\r\n"));
    }

    // -- ConnectionKey equality tests --

    #[test]
    fn connection_key_equal() {
        let a = ConnectionKey {
            host: "example.com".to_string(),
            port: 443,
            is_tls: true,
        };
        let b = ConnectionKey {
            host: "example.com".to_string(),
            port: 443,
            is_tls: true,
        };
        assert_eq!(a, b);
    }

    #[test]
    fn connection_key_different_host() {
        let a = ConnectionKey {
            host: "a.com".to_string(),
            port: 443,
            is_tls: true,
        };
        let b = ConnectionKey {
            host: "b.com".to_string(),
            port: 443,
            is_tls: true,
        };
        assert_ne!(a, b);
    }

    #[test]
    fn connection_key_different_port() {
        let a = ConnectionKey {
            host: "example.com".to_string(),
            port: 80,
            is_tls: false,
        };
        let b = ConnectionKey {
            host: "example.com".to_string(),
            port: 8080,
            is_tls: false,
        };
        assert_ne!(a, b);
    }

    #[test]
    fn connection_key_different_tls() {
        let a = ConnectionKey {
            host: "example.com".to_string(),
            port: 443,
            is_tls: true,
        };
        let b = ConnectionKey {
            host: "example.com".to_string(),
            port: 443,
            is_tls: false,
        };
        assert_ne!(a, b);
    }

    // -- Header parsing strategy with case variations --

    #[test]
    fn strategy_content_length_lowercase() {
        let headers = "HTTP/1.1 200 OK\r\ncontent-length: 10";
        match determine_body_strategy(headers, 200) {
            BodyStrategy::ContentLength(10) => {}
            _ => panic!("expected ContentLength(10)"),
        }
    }

    #[test]
    fn strategy_chunked_lowercase() {
        let headers = "HTTP/1.1 200 OK\r\ntransfer-encoding: chunked";
        assert!(matches!(
            determine_body_strategy(headers, 200),
            BodyStrategy::Chunked
        ));
    }

    #[test]
    fn strategy_chunked_uppercase_value() {
        let headers = "HTTP/1.1 200 OK\r\nTransfer-Encoding: CHUNKED";
        assert!(matches!(
            determine_body_strategy(headers, 200),
            BodyStrategy::Chunked
        ));
    }

    #[test]
    fn strategy_content_length_mixed_case() {
        let headers = "HTTP/1.1 200 OK\r\nCONTENT-LENGTH: 99";
        match determine_body_strategy(headers, 200) {
            BodyStrategy::ContentLength(99) => {}
            _ => panic!("expected ContentLength(99)"),
        }
    }

    #[test]
    fn strategy_chunked_mixed_case_name() {
        let headers = "HTTP/1.1 200 OK\r\nTRANSFER-ENCODING: chunked";
        assert!(matches!(
            determine_body_strategy(headers, 200),
            BodyStrategy::Chunked
        ));
    }

    // -- URL scheme handling --

    #[test]
    fn unsupported_scheme_error() {
        let mut client = HttpClient::new();
        let url = Url::parse("ftp://example.com/file").unwrap();
        let result = client.get(&url);
        assert!(matches!(result, Err(ClientError::UnsupportedScheme(_))));
    }
}