we (web engine): an experimental web browser project for exploring the limits of Claude.
At ref `utf-codecs` · 498 lines · 15 kB · view raw
//! TCP socket abstraction wrapping `std::net::TcpStream`.

use std::fmt;
use std::io::{self, BufRead, BufReader, BufWriter, Read, Write};
use std::net::{Shutdown, SocketAddr, TcpStream, ToSocketAddrs};
use std::time::Duration;

// ---------------------------------------------------------------------------
// Error types
// ---------------------------------------------------------------------------

/// Network errors.
#[derive(Debug)]
pub enum NetError {
    /// Connection was refused by the remote host.
    ConnectionRefused,
    /// Connection timed out.
    Timeout,
    /// DNS resolution failed for the given hostname.
    DnsResolutionFailed(String),
    /// An I/O error occurred.
    Io(io::Error),
}

impl fmt::Display for NetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ConnectionRefused => write!(f, "connection refused"),
            Self::Timeout => write!(f, "connection timed out"),
            Self::DnsResolutionFailed(host) => write!(f, "DNS resolution failed for '{host}'"),
            Self::Io(e) => write!(f, "I/O error: {e}"),
        }
    }
}

/// `NetError` is a standard error type, so it can be boxed as
/// `Box<dyn std::error::Error>` and inspected through `source()`.
impl std::error::Error for NetError {
    /// Returns the wrapped [`io::Error`] for [`NetError::Io`]; the other
    /// variants do not carry an underlying error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(err) => Some(err),
            Self::ConnectionRefused | Self::Timeout | Self::DnsResolutionFailed(_) => None,
        }
    }
}

/// Classify an [`io::Error`] into a [`NetError`].
///
/// Only `ErrorKind::ConnectionRefused` and `ErrorKind::TimedOut` get dedicated
/// variants; every other error is kept intact in [`NetError::Io`].
///
/// Note: the standard library documents that an expired read/write timeout
/// (see `set_read_timeout` / `set_write_timeout`) may be reported as
/// `ErrorKind::WouldBlock` (typical on Unix) rather than `TimedOut`. Such
/// errors arrive as `NetError::Io`, not `NetError::Timeout`.
impl From<io::Error> for NetError {
    fn from(err: io::Error) -> Self {
        match err.kind() {
            io::ErrorKind::ConnectionRefused => NetError::ConnectionRefused,
            io::ErrorKind::TimedOut => NetError::Timeout,
            _ => NetError::Io(err),
        }
    }
}

/// Result alias used throughout this module.
pub type Result<T> = std::result::Result<T, NetError>;

// ---------------------------------------------------------------------------
// TcpConnection
// ---------------------------------------------------------------------------

/// A TCP connection wrapping `std::net::TcpStream`.
pub struct TcpConnection {
    stream: TcpStream,
}

impl TcpConnection {
    /// Connect to a TCP server by hostname and port.
    ///
    /// Resolves the hostname via the system resolver and connects to the first
    /// address that succeeds.
    ///
    /// # Errors
    ///
    /// Returns [`NetError::DnsResolutionFailed`] if `host` cannot be resolved or
    /// resolves to no addresses. Otherwise, if every address fails, returns the
    /// error from the last attempt, converted via `From<io::Error>`.
    pub fn connect(host: &str, port: u16) -> Result<Self> {
        Self::connect_with(host, port, |addr| TcpStream::connect(addr))
    }

    /// Connect with a timeout.
    ///
    /// `timeout` bounds each individual connection attempt. When `host`
    /// resolves to several addresses, the total wait can exceed `timeout`.
    ///
    /// # Errors
    ///
    /// Same as [`TcpConnection::connect`]. The standard library rejects a zero
    /// `timeout` with `ErrorKind::InvalidInput`, which surfaces as
    /// [`NetError::Io`].
    pub fn connect_timeout(host: &str, port: u16, timeout: Duration) -> Result<Self> {
        Self::connect_with(host, port, |addr| TcpStream::connect_timeout(addr, timeout))
    }

    /// Resolve `host:port`, then try each address in order with `dial`,
    /// returning the first connection that succeeds.
    ///
    /// Shared by `connect` and `connect_timeout`, which differ only in how a
    /// single address is dialled.
    fn connect_with<F>(host: &str, port: u16, mut dial: F) -> Result<Self>
    where
        F: FnMut(&SocketAddr) -> io::Result<TcpStream>,
    {
        let addrs = format!("{host}:{port}")
            .to_socket_addrs()
            .map_err(|_| NetError::DnsResolutionFailed(host.to_string()))?;

        let mut last_err = None;
        for addr in addrs {
            match dial(&addr) {
                Ok(stream) => return Ok(Self { stream }),
                Err(err) => last_err = Some(err),
            }
        }

        // Nothing connected: report the last attempt's failure, or a DNS
        // failure when the resolver produced no addresses at all.
        Err(match last_err {
            Some(err) => NetError::from(err),
            None => NetError::DnsResolutionFailed(host.to_string()),
        })
    }

    /// Read bytes into the buffer. Returns the number of bytes read.
    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        self.stream.read(buf).map_err(NetError::from)
    }

    /// Read exactly `buf.len()` bytes, blocking until complete or error.
    pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
        self.stream.read_exact(buf).map_err(NetError::from)
    }

    /// Write bytes. Returns the number of bytes written.
    pub fn write(&mut self, data: &[u8]) -> Result<usize> {
        self.stream.write(data).map_err(NetError::from)
    }

    /// Write all bytes, blocking until complete or error.
    pub fn write_all(&mut self, data: &[u8]) -> Result<()> {
        self.stream.write_all(data).map_err(NetError::from)
    }

    /// Flush the underlying stream.
    pub fn flush(&mut self) -> Result<()> {
        self.stream.flush().map_err(NetError::from)
    }

    /// Set the read timeout (`None` disables it).
    pub fn set_read_timeout(&self, duration: Option<Duration>) -> Result<()> {
        self.stream
            .set_read_timeout(duration)
            .map_err(NetError::from)
    }

    /// Set the write timeout (`None` disables it).
    pub fn set_write_timeout(&self, duration: Option<Duration>) -> Result<()> {
        self.stream
            .set_write_timeout(duration)
            .map_err(NetError::from)
    }

    /// Shut down the connection (both read and write).
    pub fn shutdown(&self) -> Result<()> {
        self.stream.shutdown(Shutdown::Both).map_err(NetError::from)
    }

    /// Create a buffered reader over this connection.
    ///
    /// Consumes the connection. Use `into_buffered` if you need both buffered
    /// read and write.
    pub fn into_buf_reader(self) -> BufferedReader {
        BufferedReader {
            inner: BufReader::new(self.stream),
        }
    }

    /// Split into a buffered reader and writer pair sharing the same stream.
    ///
    /// The writer uses a `try_clone` of the socket, so both halves refer to the
    /// same connection.
    ///
    /// # Errors
    ///
    /// Fails if the socket handle cannot be cloned.
    pub fn into_buffered(self) -> Result<(BufferedReader, BufferedWriter)> {
        let clone = self.stream.try_clone().map_err(NetError::from)?;
        Ok((
            BufferedReader {
                inner: BufReader::new(self.stream),
            },
            BufferedWriter {
                inner: BufWriter::new(clone),
            },
        ))
    }

    /// Get a reference to the underlying `TcpStream`.
    pub fn as_raw(&self) -> &TcpStream {
        &self.stream
    }
}

impl Read for TcpConnection {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.stream.read(buf)
    }
}

impl Write for TcpConnection {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.stream.write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        self.stream.flush()
    }
}

impl fmt::Debug for TcpConnection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TcpConnection")
            .field("peer", &self.stream.peer_addr().ok())
            .field("local", &self.stream.local_addr().ok())
            .finish()
    }
}

// ---------------------------------------------------------------------------
// Buffered I/O wrappers
// ---------------------------------------------------------------------------

/// A buffered reader over a TCP stream.
pub struct BufferedReader {
    inner: BufReader<TcpStream>,
}

impl BufferedReader {
    /// Read a line (including the trailing `\n` or `\r\n`).
    /// Returns the number of bytes read, or 0 at EOF.
    pub fn read_line(&mut self, buf: &mut String) -> Result<usize> {
        self.inner.read_line(buf).map_err(NetError::from)
    }

    /// Read bytes into the buffer.
    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        self.inner.read(buf).map_err(NetError::from)
    }

    /// Read exactly `buf.len()` bytes.
    pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
        self.inner.read_exact(buf).map_err(NetError::from)
    }

    /// Return a reference to the internal buffer contents without consuming.
    pub fn buffer(&self) -> &[u8] {
        self.inner.buffer()
    }

    /// Consume `n` bytes from the internal buffer.
    pub fn consume(&mut self, n: usize) {
        self.inner.consume(n);
    }

    /// Fill the internal buffer, returning a slice of the available data.
    pub fn fill_buf(&mut self) -> Result<&[u8]> {
        self.inner.fill_buf().map_err(NetError::from)
    }

    /// Set the read timeout on the underlying stream.
    pub fn set_read_timeout(&self, duration: Option<Duration>) -> Result<()> {
        self.inner
            .get_ref()
            .set_read_timeout(duration)
            .map_err(NetError::from)
    }
}

/// Lets a `BufferedReader` be used wherever a `std::io::Read` is expected,
/// mirroring the impl on [`TcpConnection`]. The inherent methods of the same
/// name (returning [`NetError`]) still take precedence in method-call syntax.
impl Read for BufferedReader {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.inner.read(buf)
    }
}

/// Enables `std::io::BufRead` helpers such as `lines()` and `read_until()`.
impl BufRead for BufferedReader {
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        self.inner.fill_buf()
    }

    fn consume(&mut self, amt: usize) {
        self.inner.consume(amt);
    }
}

impl fmt::Debug for BufferedReader {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BufferedReader")
            .field("buffered_bytes", &self.inner.buffer().len())
            .finish()
    }
}

/// A buffered writer over a TCP stream.
///
/// Dropping the writer attempts to flush pending data, but (as with
/// `std::io::BufWriter`) any error from that flush is discarded. Call
/// [`BufferedWriter::flush`] explicitly to observe write failures.
pub struct BufferedWriter {
    inner: BufWriter<TcpStream>,
}

impl BufferedWriter {
    /// Write bytes. Returns the number of bytes written.
    pub fn write(&mut self, data: &[u8]) -> Result<usize> {
        self.inner.write(data).map_err(NetError::from)
    }

    /// Write all bytes.
    pub fn write_all(&mut self, data: &[u8]) -> Result<()> {
        self.inner.write_all(data).map_err(NetError::from)
    }

    /// Flush the buffered writer, sending all pending data.
    pub fn flush(&mut self) -> Result<()> {
        self.inner.flush().map_err(NetError::from)
    }

    /// Set the write timeout on the underlying stream.
    pub fn set_write_timeout(&self, duration: Option<Duration>) -> Result<()> {
        self.inner
            .get_ref()
            .set_write_timeout(duration)
            .map_err(NetError::from)
    }
}

/// Lets a `BufferedWriter` be used with `write!`/`writeln!` and any
/// `std::io::Write` consumer, mirroring the impl on [`TcpConnection`].
impl Write for BufferedWriter {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.inner.write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}

impl fmt::Debug for BufferedWriter {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Same shape as `BufferedReader`'s Debug: report pending bytes only.
        f.debug_struct("BufferedWriter")
            .field("buffered_bytes", &self.inner.buffer().len())
            .finish()
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn net_error_display_connection_refused() {
        let err = NetError::ConnectionRefused;
        assert_eq!(err.to_string(), "connection refused");
    }

    #[test]
    fn net_error_display_timeout() {
        let err = NetError::Timeout;
        assert_eq!(err.to_string(), "connection timed out");
    }

    #[test]
    fn net_error_display_dns() {
        let err = NetError::DnsResolutionFailed("example.invalid".to_string());
        assert_eq!(
            err.to_string(),
            "DNS resolution failed for 'example.invalid'"
        );
    }

    #[test]
    fn net_error_display_io() {
        let io_err = io::Error::new(io::ErrorKind::BrokenPipe, "broken pipe");
        let err = NetError::Io(io_err);
        assert!(err.to_string().contains("broken pipe"));
    }

    #[test]
    fn net_error_from_io_connection_refused() {
        let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "refused");
        let err = NetError::from(io_err);
        assert!(matches!(err, NetError::ConnectionRefused));
    }

    #[test]
    fn net_error_from_io_timed_out() {
        let io_err = io::Error::new(io::ErrorKind::TimedOut, "timed out");
        let err = NetError::from(io_err);
        assert!(matches!(err, NetError::Timeout));
    }

    #[test]
    fn net_error_from_io_other() {
        let io_err = io::Error::new(io::ErrorKind::BrokenPipe, "broken");
        let err = NetError::from(io_err);
        assert!(matches!(err, NetError::Io(_)));
    }

    #[test]
    fn net_error_source_exposes_io_error() {
        use std::error::Error;

        let err = NetError::Io(io::Error::new(io::ErrorKind::BrokenPipe, "broken pipe"));
        assert!(err.source().is_some());
        assert!(NetError::Timeout.source().is_none());
    }

    #[test]
    fn connect_to_nonexistent_host_fails() {
        let result = TcpConnection::connect("host.invalid", 1);
        assert!(result.is_err());
    }

    #[test]
    fn connect_to_refused_port_fails() {
        // Port 1 on localhost is almost certainly not listening.
        let result = TcpConnection::connect("127.0.0.1", 1);
        assert!(result.is_err());
    }

    #[test]
    fn connect_timeout_to_nonexistent_host_fails() {
        let result = TcpConnection::connect_timeout("host.invalid", 1, Duration::from_millis(100));
        assert!(result.is_err());
    }

    #[test]
    fn connect_timeout_loopback_succeeds() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        let conn = TcpConnection::connect_timeout("127.0.0.1", port, Duration::from_secs(2))
            .unwrap();
        assert_eq!(conn.as_raw().peer_addr().unwrap().port(), port);
    }

    #[test]
    fn loopback_echo() {
        use std::net::TcpListener;
        use std::thread;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let mut buf = [0u8; 64];
            let n = stream.read(&mut buf).unwrap();
            stream.write_all(&buf[..n]).unwrap();
        });

        let mut conn = TcpConnection::connect("127.0.0.1", port).unwrap();
        conn.write_all(b"hello").unwrap();

        let mut response = [0u8; 5];
        conn.read_exact(&mut response).unwrap();
        assert_eq!(&response, b"hello");

        conn.shutdown().ok();
        handle.join().unwrap();
    }

    #[test]
    fn buffered_read_line() {
        use std::net::TcpListener;
        use std::thread;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            stream.write_all(b"line one\nline two\n").unwrap();
        });

        let conn = TcpConnection::connect("127.0.0.1", port).unwrap();
        let mut reader = conn.into_buf_reader();

        let mut line = String::new();
        reader.read_line(&mut line).unwrap();
        assert_eq!(line, "line one\n");

        line.clear();
        reader.read_line(&mut line).unwrap();
        assert_eq!(line, "line two\n");

        handle.join().unwrap();
    }

    #[test]
    fn buffered_read_write_pair() {
        use std::net::TcpListener;
        use std::thread;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let mut buf = [0u8; 64];
            let n = stream.read(&mut buf).unwrap();
            stream.write_all(&buf[..n]).unwrap();
        });

        let conn = TcpConnection::connect("127.0.0.1", port).unwrap();
        let (mut reader, mut writer) = conn.into_buffered().unwrap();

        writer.write_all(b"ping").unwrap();
        writer.flush().unwrap();

        let mut response = [0u8; 4];
        reader.read_exact(&mut response).unwrap();
        assert_eq!(&response, b"ping");

        handle.join().unwrap();
    }

    #[test]
    fn set_timeouts() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        // Accept in background so connect succeeds.
        let handle = std::thread::spawn(move || {
            let _ = listener.accept();
        });

        let conn = TcpConnection::connect("127.0.0.1", port).unwrap();

        conn.set_read_timeout(Some(Duration::from_millis(50))).unwrap();
        conn.set_write_timeout(Some(Duration::from_millis(50))).unwrap();

        // Clear the timeouts.
        conn.set_read_timeout(None).unwrap();
        conn.set_write_timeout(None).unwrap();

        conn.shutdown().ok();
        handle.join().unwrap();
    }

    #[test]
    fn debug_format() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        let handle = std::thread::spawn(move || {
            let _ = listener.accept();
        });

        let conn = TcpConnection::connect("127.0.0.1", port).unwrap();
        let debug = format!("{conn:?}");
        assert!(debug.contains("TcpConnection"));

        conn.shutdown().ok();
        handle.join().unwrap();
    }
}