we (web engine): an experimental web browser project for exploring the limits of Claude
at http-parser · 482 lines · 15 kB · view raw
1//! TCP socket abstraction wrapping `std::net::TcpStream`. 2 3use std::fmt; 4use std::io::{self, BufRead, BufReader, BufWriter, Read, Write}; 5use std::net::{Shutdown, TcpStream, ToSocketAddrs}; 6use std::time::Duration; 7 8// --------------------------------------------------------------------------- 9// Error types 10// --------------------------------------------------------------------------- 11 12/// Network errors. 13#[derive(Debug)] 14pub enum NetError { 15 /// Connection was refused by the remote host. 16 ConnectionRefused, 17 /// Connection timed out. 18 Timeout, 19 /// DNS resolution failed for the given hostname. 20 DnsResolutionFailed(String), 21 /// An I/O error occurred. 22 Io(io::Error), 23} 24 25impl fmt::Display for NetError { 26 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 27 match self { 28 Self::ConnectionRefused => write!(f, "connection refused"), 29 Self::Timeout => write!(f, "connection timed out"), 30 Self::DnsResolutionFailed(host) => write!(f, "DNS resolution failed for '{host}'"), 31 Self::Io(e) => write!(f, "I/O error: {e}"), 32 } 33 } 34} 35 36impl From<io::Error> for NetError { 37 fn from(err: io::Error) -> Self { 38 match err.kind() { 39 io::ErrorKind::ConnectionRefused => NetError::ConnectionRefused, 40 io::ErrorKind::TimedOut => NetError::Timeout, 41 _ => NetError::Io(err), 42 } 43 } 44} 45 46pub type Result<T> = std::result::Result<T, NetError>; 47 48// --------------------------------------------------------------------------- 49// TcpConnection 50// --------------------------------------------------------------------------- 51 52/// A TCP connection wrapping `std::net::TcpStream`. 53pub struct TcpConnection { 54 stream: TcpStream, 55} 56 57impl TcpConnection { 58 /// Connect to a TCP server by hostname and port. 59 /// 60 /// Resolves the hostname via the system resolver and connects to the first 61 /// address that succeeds. 
62 pub fn connect(host: &str, port: u16) -> Result<Self> { 63 let addr_str = format!("{host}:{port}"); 64 let addrs = addr_str 65 .to_socket_addrs() 66 .map_err(|_| NetError::DnsResolutionFailed(host.to_string()))?; 67 68 let mut last_err = None; 69 for addr in addrs { 70 match TcpStream::connect(addr) { 71 Ok(stream) => return Ok(Self { stream }), 72 Err(e) => last_err = Some(e), 73 } 74 } 75 76 match last_err { 77 Some(e) => Err(NetError::from(e)), 78 None => Err(NetError::DnsResolutionFailed(host.to_string())), 79 } 80 } 81 82 /// Connect with a timeout. 83 pub fn connect_timeout(host: &str, port: u16, timeout: Duration) -> Result<Self> { 84 let addr_str = format!("{host}:{port}"); 85 let addrs: Vec<_> = addr_str 86 .to_socket_addrs() 87 .map_err(|_| NetError::DnsResolutionFailed(host.to_string()))? 88 .collect(); 89 90 let mut last_err = None; 91 for addr in addrs { 92 match TcpStream::connect_timeout(&addr, timeout) { 93 Ok(stream) => return Ok(Self { stream }), 94 Err(e) => last_err = Some(e), 95 } 96 } 97 98 match last_err { 99 Some(e) => Err(NetError::from(e)), 100 None => Err(NetError::DnsResolutionFailed(host.to_string())), 101 } 102 } 103 104 /// Read bytes into the buffer. Returns the number of bytes read. 105 pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> { 106 self.stream.read(buf).map_err(NetError::from) 107 } 108 109 /// Read exactly `buf.len()` bytes, blocking until complete or error. 110 pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> { 111 self.stream.read_exact(buf).map_err(NetError::from) 112 } 113 114 /// Write bytes. Returns the number of bytes written. 115 pub fn write(&mut self, data: &[u8]) -> Result<usize> { 116 self.stream.write(data).map_err(NetError::from) 117 } 118 119 /// Write all bytes, blocking until complete or error. 120 pub fn write_all(&mut self, data: &[u8]) -> Result<()> { 121 self.stream.write_all(data).map_err(NetError::from) 122 } 123 124 /// Flush the underlying stream. 
125 pub fn flush(&mut self) -> Result<()> { 126 self.stream.flush().map_err(NetError::from) 127 } 128 129 /// Set the read timeout. 130 pub fn set_read_timeout(&self, duration: Option<Duration>) -> Result<()> { 131 self.stream 132 .set_read_timeout(duration) 133 .map_err(NetError::from) 134 } 135 136 /// Set the write timeout. 137 pub fn set_write_timeout(&self, duration: Option<Duration>) -> Result<()> { 138 self.stream 139 .set_write_timeout(duration) 140 .map_err(NetError::from) 141 } 142 143 /// Shut down the connection (both read and write). 144 pub fn shutdown(&self) -> Result<()> { 145 self.stream.shutdown(Shutdown::Both).map_err(NetError::from) 146 } 147 148 /// Create a buffered reader over this connection. 149 /// 150 /// Consumes the connection. Use `into_buffered` if you need both buffered 151 /// read and write. 152 pub fn into_buf_reader(self) -> BufferedReader { 153 BufferedReader { 154 inner: BufReader::new(self.stream), 155 } 156 } 157 158 /// Split into a buffered reader and writer pair sharing the same stream. 159 pub fn into_buffered(self) -> Result<(BufferedReader, BufferedWriter)> { 160 let clone = self.stream.try_clone().map_err(NetError::from)?; 161 Ok(( 162 BufferedReader { 163 inner: BufReader::new(self.stream), 164 }, 165 BufferedWriter { 166 inner: BufWriter::new(clone), 167 }, 168 )) 169 } 170 171 /// Get a reference to the underlying `TcpStream`. 172 pub fn as_raw(&self) -> &TcpStream { 173 &self.stream 174 } 175} 176 177impl fmt::Debug for TcpConnection { 178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 179 f.debug_struct("TcpConnection") 180 .field("peer", &self.stream.peer_addr().ok()) 181 .field("local", &self.stream.local_addr().ok()) 182 .finish() 183 } 184} 185 186// --------------------------------------------------------------------------- 187// Buffered I/O wrappers 188// --------------------------------------------------------------------------- 189 190/// A buffered reader over a TCP stream. 
191pub struct BufferedReader { 192 inner: BufReader<TcpStream>, 193} 194 195impl BufferedReader { 196 /// Read a line (including the trailing `\n` or `\r\n`). 197 /// Returns the number of bytes read, or 0 at EOF. 198 pub fn read_line(&mut self, buf: &mut String) -> Result<usize> { 199 self.inner.read_line(buf).map_err(NetError::from) 200 } 201 202 /// Read bytes into the buffer. 203 pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> { 204 self.inner.read(buf).map_err(NetError::from) 205 } 206 207 /// Read exactly `buf.len()` bytes. 208 pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> { 209 self.inner.read_exact(buf).map_err(NetError::from) 210 } 211 212 /// Return a reference to the internal buffer contents without consuming. 213 pub fn buffer(&self) -> &[u8] { 214 self.inner.buffer() 215 } 216 217 /// Consume `n` bytes from the internal buffer. 218 pub fn consume(&mut self, n: usize) { 219 self.inner.consume(n); 220 } 221 222 /// Fill the internal buffer, returning a slice of the available data. 223 pub fn fill_buf(&mut self) -> Result<&[u8]> { 224 self.inner.fill_buf().map_err(NetError::from) 225 } 226 227 /// Set the read timeout on the underlying stream. 228 pub fn set_read_timeout(&self, duration: Option<Duration>) -> Result<()> { 229 self.inner 230 .get_ref() 231 .set_read_timeout(duration) 232 .map_err(NetError::from) 233 } 234} 235 236impl fmt::Debug for BufferedReader { 237 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 238 f.debug_struct("BufferedReader") 239 .field("buffered_bytes", &self.inner.buffer().len()) 240 .finish() 241 } 242} 243 244/// A buffered writer over a TCP stream. 245pub struct BufferedWriter { 246 inner: BufWriter<TcpStream>, 247} 248 249impl BufferedWriter { 250 /// Write bytes. Returns the number of bytes written. 251 pub fn write(&mut self, data: &[u8]) -> Result<usize> { 252 self.inner.write(data).map_err(NetError::from) 253 } 254 255 /// Write all bytes. 
256 pub fn write_all(&mut self, data: &[u8]) -> Result<()> { 257 self.inner.write_all(data).map_err(NetError::from) 258 } 259 260 /// Flush the buffered writer, sending all pending data. 261 pub fn flush(&mut self) -> Result<()> { 262 self.inner.flush().map_err(NetError::from) 263 } 264 265 /// Set the write timeout on the underlying stream. 266 pub fn set_write_timeout(&self, duration: Option<Duration>) -> Result<()> { 267 self.inner 268 .get_ref() 269 .set_write_timeout(duration) 270 .map_err(NetError::from) 271 } 272} 273 274impl fmt::Debug for BufferedWriter { 275 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 276 f.debug_struct("BufferedWriter").finish() 277 } 278} 279 280// --------------------------------------------------------------------------- 281// Tests 282// --------------------------------------------------------------------------- 283 284#[cfg(test)] 285mod tests { 286 use super::*; 287 288 #[test] 289 fn net_error_display_connection_refused() { 290 let err = NetError::ConnectionRefused; 291 assert_eq!(err.to_string(), "connection refused"); 292 } 293 294 #[test] 295 fn net_error_display_timeout() { 296 let err = NetError::Timeout; 297 assert_eq!(err.to_string(), "connection timed out"); 298 } 299 300 #[test] 301 fn net_error_display_dns() { 302 let err = NetError::DnsResolutionFailed("example.invalid".to_string()); 303 assert_eq!( 304 err.to_string(), 305 "DNS resolution failed for 'example.invalid'" 306 ); 307 } 308 309 #[test] 310 fn net_error_display_io() { 311 let io_err = io::Error::new(io::ErrorKind::BrokenPipe, "broken pipe"); 312 let err = NetError::Io(io_err); 313 assert!(err.to_string().contains("broken pipe")); 314 } 315 316 #[test] 317 fn net_error_from_io_connection_refused() { 318 let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "refused"); 319 let err = NetError::from(io_err); 320 assert!(matches!(err, NetError::ConnectionRefused)); 321 } 322 323 #[test] 324 fn net_error_from_io_timed_out() { 325 let io_err = 
io::Error::new(io::ErrorKind::TimedOut, "timed out"); 326 let err = NetError::from(io_err); 327 assert!(matches!(err, NetError::Timeout)); 328 } 329 330 #[test] 331 fn net_error_from_io_other() { 332 let io_err = io::Error::new(io::ErrorKind::BrokenPipe, "broken"); 333 let err = NetError::from(io_err); 334 assert!(matches!(err, NetError::Io(_))); 335 } 336 337 #[test] 338 fn connect_to_nonexistent_host_fails() { 339 let result = TcpConnection::connect("host.invalid", 1); 340 assert!(result.is_err()); 341 } 342 343 #[test] 344 fn connect_to_refused_port_fails() { 345 // Port 1 on localhost is almost certainly not listening. 346 let result = TcpConnection::connect("127.0.0.1", 1); 347 assert!(result.is_err()); 348 } 349 350 #[test] 351 fn connect_timeout_to_nonexistent_host_fails() { 352 let result = TcpConnection::connect_timeout("host.invalid", 1, Duration::from_millis(100)); 353 assert!(result.is_err()); 354 } 355 356 #[test] 357 fn loopback_echo() { 358 use std::net::TcpListener; 359 use std::thread; 360 361 let listener = TcpListener::bind("127.0.0.1:0").unwrap(); 362 let port = listener.local_addr().unwrap().port(); 363 364 let handle = thread::spawn(move || { 365 let (mut stream, _) = listener.accept().unwrap(); 366 let mut buf = [0u8; 64]; 367 let n = stream.read(&mut buf).unwrap(); 368 stream.write_all(&buf[..n]).unwrap(); 369 }); 370 371 let mut conn = TcpConnection::connect("127.0.0.1", port).unwrap(); 372 conn.write_all(b"hello").unwrap(); 373 374 let mut response = [0u8; 5]; 375 conn.read_exact(&mut response).unwrap(); 376 assert_eq!(&response, b"hello"); 377 378 conn.shutdown().ok(); 379 handle.join().unwrap(); 380 } 381 382 #[test] 383 fn buffered_read_line() { 384 use std::net::TcpListener; 385 use std::thread; 386 387 let listener = TcpListener::bind("127.0.0.1:0").unwrap(); 388 let port = listener.local_addr().unwrap().port(); 389 390 let handle = thread::spawn(move || { 391 let (mut stream, _) = listener.accept().unwrap(); 392 
stream.write_all(b"line one\nline two\n").unwrap(); 393 }); 394 395 let conn = TcpConnection::connect("127.0.0.1", port).unwrap(); 396 let mut reader = conn.into_buf_reader(); 397 398 let mut line = String::new(); 399 reader.read_line(&mut line).unwrap(); 400 assert_eq!(line, "line one\n"); 401 402 line.clear(); 403 reader.read_line(&mut line).unwrap(); 404 assert_eq!(line, "line two\n"); 405 406 handle.join().unwrap(); 407 } 408 409 #[test] 410 fn buffered_read_write_pair() { 411 use std::net::TcpListener; 412 use std::thread; 413 414 let listener = TcpListener::bind("127.0.0.1:0").unwrap(); 415 let port = listener.local_addr().unwrap().port(); 416 417 let handle = thread::spawn(move || { 418 let (mut stream, _) = listener.accept().unwrap(); 419 let mut buf = [0u8; 64]; 420 let n = stream.read(&mut buf).unwrap(); 421 stream.write_all(&buf[..n]).unwrap(); 422 }); 423 424 let conn = TcpConnection::connect("127.0.0.1", port).unwrap(); 425 let (mut reader, mut writer) = conn.into_buffered().unwrap(); 426 427 writer.write_all(b"ping").unwrap(); 428 writer.flush().unwrap(); 429 430 let mut response = [0u8; 4]; 431 reader.read_exact(&mut response).unwrap(); 432 assert_eq!(&response, b"ping"); 433 434 handle.join().unwrap(); 435 } 436 437 #[test] 438 fn set_timeouts() { 439 use std::net::TcpListener; 440 441 let listener = TcpListener::bind("127.0.0.1:0").unwrap(); 442 let port = listener.local_addr().unwrap().port(); 443 444 // Accept in background so connect succeeds. 445 let handle = std::thread::spawn(move || { 446 let _ = listener.accept(); 447 }); 448 449 let conn = TcpConnection::connect("127.0.0.1", port).unwrap(); 450 451 conn.set_read_timeout(Some(Duration::from_millis(50))) 452 .unwrap(); 453 conn.set_write_timeout(Some(Duration::from_millis(50))) 454 .unwrap(); 455 456 // Clear the timeouts. 
457 conn.set_read_timeout(None).unwrap(); 458 conn.set_write_timeout(None).unwrap(); 459 460 conn.shutdown().ok(); 461 handle.join().unwrap(); 462 } 463 464 #[test] 465 fn debug_format() { 466 use std::net::TcpListener; 467 468 let listener = TcpListener::bind("127.0.0.1:0").unwrap(); 469 let port = listener.local_addr().unwrap().port(); 470 471 let handle = std::thread::spawn(move || { 472 let _ = listener.accept(); 473 }); 474 475 let conn = TcpConnection::connect("127.0.0.1", port).unwrap(); 476 let debug = format!("{conn:?}"); 477 assert!(debug.contains("TcpConnection")); 478 479 conn.shutdown().ok(); 480 handle.join().unwrap(); 481 } 482}