//! TCP socket abstraction wrapping `std::net::TcpStream`.

use std::fmt;
use std::io::{self, BufRead, BufReader, BufWriter, Read, Write};
use std::net::{Shutdown, SocketAddr, TcpStream, ToSocketAddrs};
use std::time::Duration;

// ---------------------------------------------------------------------------
// Error types
// ---------------------------------------------------------------------------

/// Network errors.
#[derive(Debug)]
pub enum NetError {
    /// Connection was refused by the remote host.
    ConnectionRefused,
    /// An operation timed out (any `io::Error` of kind `TimedOut`).
    ///
    /// Besides connect timeouts, this can also come from reads/writes issued
    /// after `set_read_timeout`/`set_write_timeout`, but only on platforms that
    /// report those as `TimedOut`; per the `TcpStream::set_read_timeout` docs,
    /// Unix typically reports them as `WouldBlock`, which surfaces as
    /// [`NetError::Io`] instead.
    Timeout,
    /// DNS resolution failed for the given hostname, or it resolved to no
    /// addresses.
    DnsResolutionFailed(String),
    /// Any other I/O error.
    Io(io::Error),
}

impl fmt::Display for NetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ConnectionRefused => write!(f, "connection refused"),
            Self::Timeout => write!(f, "connection timed out"),
            Self::DnsResolutionFailed(host) => write!(f, "DNS resolution failed for '{host}'"),
            Self::Io(e) => write!(f, "I/O error: {e}"),
        }
    }
}

impl std::error::Error for NetError {
    /// Exposes the wrapped `io::Error` of the `Io` variant so callers holding a
    /// `Box<dyn Error>` can walk to the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            Self::ConnectionRefused | Self::Timeout | Self::DnsResolutionFailed(_) => None,
        }
    }
}

impl From<io::Error> for NetError {
    /// Maps `ConnectionRefused` and `TimedOut` kinds to dedicated variants;
    /// every other kind is preserved inside [`NetError::Io`].
    fn from(err: io::Error) -> Self {
        match err.kind() {
            io::ErrorKind::ConnectionRefused => NetError::ConnectionRefused,
            io::ErrorKind::TimedOut => NetError::Timeout,
            _ => NetError::Io(err),
        }
    }
}

/// Result alias used throughout this module.
pub type Result<T> = std::result::Result<T, NetError>;

// ---------------------------------------------------------------------------
// TcpConnection
// ---------------------------------------------------------------------------

/// A TCP connection wrapping `std::net::TcpStream`.
pub struct TcpConnection {
    stream: TcpStream,
}

impl TcpConnection {
    /// Connect to a TCP server by hostname and port.
    ///
    /// Resolves `host:port` via the system resolver and tries each resolved
    /// address in order, returning the first connection that succeeds.
    ///
    /// # Errors
    ///
    /// Returns [`NetError::DnsResolutionFailed`] if resolution fails or yields
    /// no addresses. If every address fails to connect, returns the error from
    /// the last attempt, converted via `From<io::Error>`.
    pub fn connect(host: &str, port: u16) -> Result<Self> {
        Self::connect_with(host, port, TcpStream::connect)
    }

    /// Connect with a timeout.
    ///
    /// Behaves like [`TcpConnection::connect`], except each resolved address
    /// is attempted with `TcpStream::connect_timeout`. `timeout` therefore
    /// bounds every individual attempt, not the call as a whole, and does not
    /// cover DNS resolution.
    ///
    /// # Errors
    ///
    /// Same as [`TcpConnection::connect`]. `TcpStream::connect_timeout` rejects
    /// a zero `timeout`; that error surfaces as [`NetError::Io`].
    pub fn connect_timeout(host: &str, port: u16, timeout: Duration) -> Result<Self> {
        Self::connect_with(host, port, |addr| TcpStream::connect_timeout(&addr, timeout))
    }

    /// Shared connect loop: resolve `host:port`, then call `attempt` on each
    /// address until one succeeds.
    fn connect_with<F>(host: &str, port: u16, mut attempt: F) -> Result<Self>
    where
        F: FnMut(SocketAddr) -> io::Result<TcpStream>,
    {
        let dns_failed = || NetError::DnsResolutionFailed(host.to_string());
        // Resolve the "host:port" string (not a (host, port) tuple) so that
        // bracketed IPv6 literals like "[::1]" keep working.
        let addrs = format!("{host}:{port}")
            .to_socket_addrs()
            .map_err(|_| dns_failed())?;

        let mut last_err = None;
        for addr in addrs {
            match attempt(addr) {
                Ok(stream) => return Ok(Self { stream }),
                Err(e) => last_err = Some(e),
            }
        }
        // No recorded error means resolution produced no addresses to try.
        Err(last_err.map_or_else(dns_failed, NetError::from))
    }

    /// Read bytes into `buf`, returning how many were read.
    ///
    /// A return of 0 indicates end of stream (or an empty `buf`).
    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        self.stream.read(buf).map_err(NetError::from)
    }

    /// Read exactly `buf.len()` bytes, blocking until complete or error.
    pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
        self.stream.read_exact(buf).map_err(NetError::from)
    }

    /// Write bytes, returning how many were written.
    ///
    /// This may write fewer than `data.len()` bytes; use
    /// [`TcpConnection::write_all`] to send everything.
    pub fn write(&mut self, data: &[u8]) -> Result<usize> {
        self.stream.write(data).map_err(NetError::from)
    }

    /// Write all bytes, blocking until complete or error.
    pub fn write_all(&mut self, data: &[u8]) -> Result<()> {
        self.stream.write_all(data).map_err(NetError::from)
    }

    /// Flush the underlying stream.
    pub fn flush(&mut self) -> Result<()> {
        self.stream.flush().map_err(NetError::from)
    }

    /// Set the read timeout (`None` blocks indefinitely).
    ///
    /// `TcpStream` rejects `Some(Duration::ZERO)` with an error.
    pub fn set_read_timeout(&self, duration: Option<Duration>) -> Result<()> {
        self.stream
            .set_read_timeout(duration)
            .map_err(NetError::from)
    }

    /// Set the write timeout (`None` blocks indefinitely).
    ///
    /// `TcpStream` rejects `Some(Duration::ZERO)` with an error.
    pub fn set_write_timeout(&self, duration: Option<Duration>) -> Result<()> {
        self.stream
            .set_write_timeout(duration)
            .map_err(NetError::from)
    }

    /// Shut down the connection (both read and write).
    pub fn shutdown(&self) -> Result<()> {
        self.stream.shutdown(Shutdown::Both).map_err(NetError::from)
    }

    /// Create a buffered reader over this connection.
    ///
    /// Consumes the connection. Use `into_buffered` if you need both buffered
    /// read and write.
    pub fn into_buf_reader(self) -> BufferedReader {
        BufferedReader {
            inner: BufReader::new(self.stream),
        }
    }

    /// Split into a buffered reader and writer pair sharing the same stream.
    ///
    /// The writer owns a `try_clone` of the socket. Both halves refer to the
    /// same OS socket, so socket options such as timeouts set through either
    /// half apply to both.
    ///
    /// # Errors
    ///
    /// Fails if the socket handle cannot be duplicated.
    pub fn into_buffered(self) -> Result<(BufferedReader, BufferedWriter)> {
        let clone = self.stream.try_clone().map_err(NetError::from)?;
        Ok((
            BufferedReader {
                inner: BufReader::new(self.stream),
            },
            BufferedWriter {
                inner: BufWriter::new(clone),
            },
        ))
    }

    /// Get a reference to the underlying `TcpStream`.
    pub fn as_raw(&self) -> &TcpStream {
        &self.stream
    }
}

impl Read for TcpConnection {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.stream.read(buf)
    }
}

impl Write for TcpConnection {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.stream.write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        self.stream.flush()
    }
}

impl fmt::Debug for TcpConnection {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("TcpConnection")
            .field("peer", &self.stream.peer_addr().ok())
            .field("local", &self.stream.local_addr().ok())
            .finish()
    }
}

// ---------------------------------------------------------------------------
// Buffered I/O wrappers
// ---------------------------------------------------------------------------

/// A buffered reader over a TCP stream.
///
/// Also implements `std::io::Read` and `std::io::BufRead`, so it can be
/// passed to generic I/O code; the inherent methods below return [`NetError`].
pub struct BufferedReader {
    inner: BufReader<TcpStream>,
}

impl BufferedReader {
    /// Read a line (including the trailing `\n` or `\r\n`).
    /// Returns the number of bytes read, or 0 at EOF.
    pub fn read_line(&mut self, buf: &mut String) -> Result<usize> {
        self.inner.read_line(buf).map_err(NetError::from)
    }

    /// Read bytes into the buffer, returning how many were read.
    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        self.inner.read(buf).map_err(NetError::from)
    }

    /// Read exactly `buf.len()` bytes.
    pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
        self.inner.read_exact(buf).map_err(NetError::from)
    }

    /// Return a reference to the internal buffer contents without consuming.
    pub fn buffer(&self) -> &[u8] {
        self.inner.buffer()
    }

    /// Consume `n` bytes from the internal buffer.
    pub fn consume(&mut self, n: usize) {
        self.inner.consume(n);
    }

    /// Fill the internal buffer, returning a slice of the available data.
    pub fn fill_buf(&mut self) -> Result<&[u8]> {
        self.inner.fill_buf().map_err(NetError::from)
    }

    /// Set the read timeout on the underlying stream.
    pub fn set_read_timeout(&self, duration: Option<Duration>) -> Result<()> {
        self.inner
            .get_ref()
            .set_read_timeout(duration)
            .map_err(NetError::from)
    }
}

impl Read for BufferedReader {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.inner.read(buf)
    }
}

impl BufRead for BufferedReader {
    fn fill_buf(&mut self) -> io::Result<&[u8]> {
        self.inner.fill_buf()
    }

    fn consume(&mut self, amt: usize) {
        self.inner.consume(amt);
    }
}

impl fmt::Debug for BufferedReader {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BufferedReader")
            .field("buffered_bytes", &self.inner.buffer().len())
            .finish()
    }
}

/// A buffered writer over a TCP stream.
///
/// Call [`BufferedWriter::flush`] before dropping. `BufWriter` attempts a
/// flush on drop but ignores any error from it. Also implements
/// `std::io::Write` for use with generic I/O code.
pub struct BufferedWriter {
    inner: BufWriter<TcpStream>,
}

impl BufferedWriter {
    /// Write bytes, returning how many were accepted into the buffer.
    pub fn write(&mut self, data: &[u8]) -> Result<usize> {
        self.inner.write(data).map_err(NetError::from)
    }

    /// Write all bytes.
    pub fn write_all(&mut self, data: &[u8]) -> Result<()> {
        self.inner.write_all(data).map_err(NetError::from)
    }

    /// Flush the buffered writer, sending all pending data.
    pub fn flush(&mut self) -> Result<()> {
        self.inner.flush().map_err(NetError::from)
    }

    /// Set the write timeout on the underlying stream.
    pub fn set_write_timeout(&self, duration: Option<Duration>) -> Result<()> {
        self.inner
            .get_ref()
            .set_write_timeout(duration)
            .map_err(NetError::from)
    }
}

impl Write for BufferedWriter {
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.inner.write(buf)
    }

    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}

impl fmt::Debug for BufferedWriter {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("BufferedWriter")
            .field("buffered_bytes", &self.inner.buffer().len())
            .finish()
    }
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn net_error_display_connection_refused() {
        let err = NetError::ConnectionRefused;
        assert_eq!(err.to_string(), "connection refused");
    }

    #[test]
    fn net_error_display_timeout() {
        let err = NetError::Timeout;
        assert_eq!(err.to_string(), "connection timed out");
    }

    #[test]
    fn net_error_display_dns() {
        let err = NetError::DnsResolutionFailed("example.invalid".to_string());
        assert_eq!(
            err.to_string(),
            "DNS resolution failed for 'example.invalid'"
        );
    }

    #[test]
    fn net_error_display_io() {
        let io_err = io::Error::new(io::ErrorKind::BrokenPipe, "broken pipe");
        let err = NetError::Io(io_err);
        assert!(err.to_string().contains("broken pipe"));
    }

    #[test]
    fn net_error_from_io_connection_refused() {
        let io_err = io::Error::new(io::ErrorKind::ConnectionRefused, "refused");
        let err = NetError::from(io_err);
        assert!(matches!(err, NetError::ConnectionRefused));
    }

    #[test]
    fn net_error_from_io_timed_out() {
        let io_err = io::Error::new(io::ErrorKind::TimedOut, "timed out");
        let err = NetError::from(io_err);
        assert!(matches!(err, NetError::Timeout));
    }

    #[test]
    fn net_error_from_io_other() {
        let io_err = io::Error::new(io::ErrorKind::BrokenPipe, "broken");
        let err = NetError::from(io_err);
        assert!(matches!(err, NetError::Io(_)));
    }

    #[test]
    fn net_error_source_exposes_io_error() {
        use std::error::Error;
        let err = NetError::Io(io::Error::new(io::ErrorKind::BrokenPipe, "broken"));
        assert!(err.source().is_some());
        assert!(NetError::ConnectionRefused.source().is_none());
    }

    #[test]
    fn connect_to_nonexistent_host_fails() {
        let result = TcpConnection::connect("host.invalid", 1);
        assert!(result.is_err());
    }

    #[test]
    fn connect_to_refused_port_fails() {
        // Port 1 on localhost is almost certainly not listening.
        let result = TcpConnection::connect("127.0.0.1", 1);
        assert!(result.is_err());
    }

    #[test]
    fn connect_timeout_to_nonexistent_host_fails() {
        let result =
            TcpConnection::connect_timeout("host.invalid", 1, Duration::from_millis(100));
        assert!(result.is_err());
    }

    #[test]
    fn connect_timeout_loopback_succeeds() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        let handle = std::thread::spawn(move || {
            let _ = listener.accept();
        });

        let conn =
            TcpConnection::connect_timeout("127.0.0.1", port, Duration::from_secs(5)).unwrap();
        conn.shutdown().ok();
        handle.join().unwrap();
    }

    #[test]
    fn loopback_echo() {
        use std::net::TcpListener;
        use std::thread;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let mut buf = [0u8; 64];
            let n = stream.read(&mut buf).unwrap();
            stream.write_all(&buf[..n]).unwrap();
        });

        let mut conn = TcpConnection::connect("127.0.0.1", port).unwrap();
        conn.write_all(b"hello").unwrap();
        let mut response = [0u8; 5];
        conn.read_exact(&mut response).unwrap();
        assert_eq!(&response, b"hello");
        conn.shutdown().ok();
        handle.join().unwrap();
    }

    #[test]
    fn buffered_read_line() {
        use std::net::TcpListener;
        use std::thread;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            stream.write_all(b"line one\nline two\n").unwrap();
        });

        let conn = TcpConnection::connect("127.0.0.1", port).unwrap();
        let mut reader = conn.into_buf_reader();
        let mut line = String::new();
        reader.read_line(&mut line).unwrap();
        assert_eq!(line, "line one\n");
        line.clear();
        reader.read_line(&mut line).unwrap();
        assert_eq!(line, "line two\n");
        handle.join().unwrap();
    }

    #[test]
    fn buffered_read_write_pair() {
        use std::net::TcpListener;
        use std::thread;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            let mut buf = [0u8; 64];
            let n = stream.read(&mut buf).unwrap();
            stream.write_all(&buf[..n]).unwrap();
        });

        let conn = TcpConnection::connect("127.0.0.1", port).unwrap();
        let (mut reader, mut writer) = conn.into_buffered().unwrap();
        writer.write_all(b"ping").unwrap();
        writer.flush().unwrap();
        let mut response = [0u8; 4];
        reader.read_exact(&mut response).unwrap();
        assert_eq!(&response, b"ping");
        handle.join().unwrap();
    }

    #[test]
    fn buffered_wrappers_implement_std_io_traits() {
        use std::net::TcpListener;
        use std::thread;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        let handle = thread::spawn(move || {
            let (mut stream, _) = listener.accept().unwrap();
            stream.write_all(b"one\ntwo\n").unwrap();
            let mut buf = [0u8; 4];
            stream.read_exact(&mut buf).unwrap();
            assert_eq!(&buf, b"ping");
        });

        let conn = TcpConnection::connect("127.0.0.1", port).unwrap();
        let (mut reader, mut writer) = conn.into_buffered().unwrap();
        // Fully-qualified calls force the std trait impls rather than the
        // inherent methods of the same name.
        let lines: Vec<String> = BufRead::lines(&mut reader)
            .take(2)
            .collect::<io::Result<_>>()
            .unwrap();
        assert_eq!(lines, ["one", "two"]);
        Write::write_all(&mut writer, b"ping").unwrap();
        Write::flush(&mut writer).unwrap();
        handle.join().unwrap();
    }

    #[test]
    fn set_timeouts() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        // Accept in background so connect succeeds.
        let handle = std::thread::spawn(move || {
            let _ = listener.accept();
        });

        let conn = TcpConnection::connect("127.0.0.1", port).unwrap();
        conn.set_read_timeout(Some(Duration::from_millis(50)))
            .unwrap();
        conn.set_write_timeout(Some(Duration::from_millis(50)))
            .unwrap();
        // Clear the timeouts.
        conn.set_read_timeout(None).unwrap();
        conn.set_write_timeout(None).unwrap();
        conn.shutdown().ok();
        handle.join().unwrap();
    }

    #[test]
    fn debug_format() {
        use std::net::TcpListener;

        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();

        let handle = std::thread::spawn(move || {
            let _ = listener.accept();
        });

        let conn = TcpConnection::connect("127.0.0.1", port).unwrap();
        let debug = format!("{conn:?}");
        assert!(debug.contains("TcpConnection"));
        conn.shutdown().ok();
        handle.join().unwrap();
    }
}