we (web engine): Experimental web browser project to understand the limits of Claude
1//! TCP socket abstraction wrapping `std::net::TcpStream`.
2
3use std::fmt;
4use std::io::{self, BufRead, BufReader, BufWriter, Read, Write};
5use std::net::{Shutdown, TcpStream, ToSocketAddrs};
6use std::time::Duration;
7
8// ---------------------------------------------------------------------------
9// Error types
10// ---------------------------------------------------------------------------
11
/// Network errors.
///
/// Produced by every fallible operation in this module. Implements
/// [`std::error::Error`], so it can be boxed as `Box<dyn Error>` or wrapped by
/// higher-level error types; the wrapped `io::Error` of [`NetError::Io`] is
/// exposed through `source()`.
#[derive(Debug)]
pub enum NetError {
    /// Connection was refused by the remote host.
    ConnectionRefused,
    /// Connection timed out, or a read/write exceeded its configured timeout.
    Timeout,
    /// DNS resolution failed for the given hostname.
    DnsResolutionFailed(String),
    /// An I/O error occurred.
    Io(io::Error),
}

impl fmt::Display for NetError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ConnectionRefused => write!(f, "connection refused"),
            Self::Timeout => write!(f, "connection timed out"),
            Self::DnsResolutionFailed(host) => write!(f, "DNS resolution failed for '{host}'"),
            Self::Io(e) => write!(f, "I/O error: {e}"),
        }
    }
}

impl std::error::Error for NetError {
    /// Only [`NetError::Io`] wraps an underlying error; the other variants are
    /// leaf errors.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            Self::ConnectionRefused | Self::Timeout | Self::DnsResolutionFailed(_) => None,
        }
    }
}

impl From<io::Error> for NetError {
    /// Classify an `io::Error`: refused connections and timeouts get dedicated
    /// variants, everything else is wrapped in [`NetError::Io`].
    fn from(err: io::Error) -> Self {
        match err.kind() {
            io::ErrorKind::ConnectionRefused => NetError::ConnectionRefused,
            // A timeout configured with `set_read_timeout`/`set_write_timeout`
            // is reported as `WouldBlock` on Unix and `TimedOut` on Windows
            // (see `TcpStream::set_read_timeout`), so both mean "timed out".
            // This module never puts sockets into non-blocking mode; if a
            // caller does so through `as_raw()`, `WouldBlock` is also reported
            // as `Timeout`.
            io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock => NetError::Timeout,
            _ => NetError::Io(err),
        }
    }
}

/// Result alias used throughout this module, with [`NetError`] as the error type.
pub type Result<T> = std::result::Result<T, NetError>;
47
48// ---------------------------------------------------------------------------
49// TcpConnection
50// ---------------------------------------------------------------------------
51
52/// A TCP connection wrapping `std::net::TcpStream`.
53pub struct TcpConnection {
54 stream: TcpStream,
55}
56
57impl TcpConnection {
58 /// Connect to a TCP server by hostname and port.
59 ///
60 /// Resolves the hostname via the system resolver and connects to the first
61 /// address that succeeds.
62 pub fn connect(host: &str, port: u16) -> Result<Self> {
63 let addr_str = format!("{host}:{port}");
64 let addrs = addr_str
65 .to_socket_addrs()
66 .map_err(|_| NetError::DnsResolutionFailed(host.to_string()))?;
67
68 let mut last_err = None;
69 for addr in addrs {
70 match TcpStream::connect(addr) {
71 Ok(stream) => return Ok(Self { stream }),
72 Err(e) => last_err = Some(e),
73 }
74 }
75
76 match last_err {
77 Some(e) => Err(NetError::from(e)),
78 None => Err(NetError::DnsResolutionFailed(host.to_string())),
79 }
80 }
81
82 /// Connect with a timeout.
83 pub fn connect_timeout(host: &str, port: u16, timeout: Duration) -> Result<Self> {
84 let addr_str = format!("{host}:{port}");
85 let addrs: Vec<_> = addr_str
86 .to_socket_addrs()
87 .map_err(|_| NetError::DnsResolutionFailed(host.to_string()))?
88 .collect();
89
90 let mut last_err = None;
91 for addr in addrs {
92 match TcpStream::connect_timeout(&addr, timeout) {
93 Ok(stream) => return Ok(Self { stream }),
94 Err(e) => last_err = Some(e),
95 }
96 }
97
98 match last_err {
99 Some(e) => Err(NetError::from(e)),
100 None => Err(NetError::DnsResolutionFailed(host.to_string())),
101 }
102 }
103
104 /// Read bytes into the buffer. Returns the number of bytes read.
105 pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
106 self.stream.read(buf).map_err(NetError::from)
107 }
108
109 /// Read exactly `buf.len()` bytes, blocking until complete or error.
110 pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
111 self.stream.read_exact(buf).map_err(NetError::from)
112 }
113
114 /// Write bytes. Returns the number of bytes written.
115 pub fn write(&mut self, data: &[u8]) -> Result<usize> {
116 self.stream.write(data).map_err(NetError::from)
117 }
118
119 /// Write all bytes, blocking until complete or error.
120 pub fn write_all(&mut self, data: &[u8]) -> Result<()> {
121 self.stream.write_all(data).map_err(NetError::from)
122 }
123
124 /// Flush the underlying stream.
125 pub fn flush(&mut self) -> Result<()> {
126 self.stream.flush().map_err(NetError::from)
127 }
128
129 /// Set the read timeout.
130 pub fn set_read_timeout(&self, duration: Option<Duration>) -> Result<()> {
131 self.stream
132 .set_read_timeout(duration)
133 .map_err(NetError::from)
134 }
135
136 /// Set the write timeout.
137 pub fn set_write_timeout(&self, duration: Option<Duration>) -> Result<()> {
138 self.stream
139 .set_write_timeout(duration)
140 .map_err(NetError::from)
141 }
142
143 /// Shut down the connection (both read and write).
144 pub fn shutdown(&self) -> Result<()> {
145 self.stream.shutdown(Shutdown::Both).map_err(NetError::from)
146 }
147
148 /// Create a buffered reader over this connection.
149 ///
150 /// Consumes the connection. Use `into_buffered` if you need both buffered
151 /// read and write.
152 pub fn into_buf_reader(self) -> BufferedReader {
153 BufferedReader {
154 inner: BufReader::new(self.stream),
155 }
156 }
157
158 /// Split into a buffered reader and writer pair sharing the same stream.
159 pub fn into_buffered(self) -> Result<(BufferedReader, BufferedWriter)> {
160 let clone = self.stream.try_clone().map_err(NetError::from)?;
161 Ok((
162 BufferedReader {
163 inner: BufReader::new(self.stream),
164 },
165 BufferedWriter {
166 inner: BufWriter::new(clone),
167 },
168 ))
169 }
170
171 /// Get a reference to the underlying `TcpStream`.
172 pub fn as_raw(&self) -> &TcpStream {
173 &self.stream
174 }
175}
176
177impl Read for TcpConnection {
178 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
179 self.stream.read(buf)
180 }
181}
182
183impl Write for TcpConnection {
184 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
185 self.stream.write(buf)
186 }
187
188 fn flush(&mut self) -> io::Result<()> {
189 self.stream.flush()
190 }
191}
192
193impl fmt::Debug for TcpConnection {
194 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
195 f.debug_struct("TcpConnection")
196 .field("peer", &self.stream.peer_addr().ok())
197 .field("local", &self.stream.local_addr().ok())
198 .finish()
199 }
200}
201
202// ---------------------------------------------------------------------------
203// Buffered I/O wrappers
204// ---------------------------------------------------------------------------
205
206/// A buffered reader over a TCP stream.
207pub struct BufferedReader {
208 inner: BufReader<TcpStream>,
209}
210
211impl BufferedReader {
212 /// Read a line (including the trailing `\n` or `\r\n`).
213 /// Returns the number of bytes read, or 0 at EOF.
214 pub fn read_line(&mut self, buf: &mut String) -> Result<usize> {
215 self.inner.read_line(buf).map_err(NetError::from)
216 }
217
218 /// Read bytes into the buffer.
219 pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
220 self.inner.read(buf).map_err(NetError::from)
221 }
222
223 /// Read exactly `buf.len()` bytes.
224 pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
225 self.inner.read_exact(buf).map_err(NetError::from)
226 }
227
228 /// Return a reference to the internal buffer contents without consuming.
229 pub fn buffer(&self) -> &[u8] {
230 self.inner.buffer()
231 }
232
233 /// Consume `n` bytes from the internal buffer.
234 pub fn consume(&mut self, n: usize) {
235 self.inner.consume(n);
236 }
237
238 /// Fill the internal buffer, returning a slice of the available data.
239 pub fn fill_buf(&mut self) -> Result<&[u8]> {
240 self.inner.fill_buf().map_err(NetError::from)
241 }
242
243 /// Set the read timeout on the underlying stream.
244 pub fn set_read_timeout(&self, duration: Option<Duration>) -> Result<()> {
245 self.inner
246 .get_ref()
247 .set_read_timeout(duration)
248 .map_err(NetError::from)
249 }
250}
251
252impl fmt::Debug for BufferedReader {
253 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
254 f.debug_struct("BufferedReader")
255 .field("buffered_bytes", &self.inner.buffer().len())
256 .finish()
257 }
258}
259
260/// A buffered writer over a TCP stream.
261pub struct BufferedWriter {
262 inner: BufWriter<TcpStream>,
263}
264
265impl BufferedWriter {
266 /// Write bytes. Returns the number of bytes written.
267 pub fn write(&mut self, data: &[u8]) -> Result<usize> {
268 self.inner.write(data).map_err(NetError::from)
269 }
270
271 /// Write all bytes.
272 pub fn write_all(&mut self, data: &[u8]) -> Result<()> {
273 self.inner.write_all(data).map_err(NetError::from)
274 }
275
276 /// Flush the buffered writer, sending all pending data.
277 pub fn flush(&mut self) -> Result<()> {
278 self.inner.flush().map_err(NetError::from)
279 }
280
281 /// Set the write timeout on the underlying stream.
282 pub fn set_write_timeout(&self, duration: Option<Duration>) -> Result<()> {
283 self.inner
284 .get_ref()
285 .set_write_timeout(duration)
286 .map_err(NetError::from)
287 }
288}
289
290impl fmt::Debug for BufferedWriter {
291 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
292 f.debug_struct("BufferedWriter").finish()
293 }
294}
295
296// ---------------------------------------------------------------------------
297// Tests
298// ---------------------------------------------------------------------------
299
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::thread::{self, JoinHandle};

    /// Bind an ephemeral loopback port and hand the first accepted connection
    /// to `serve` on a background thread. Returns the port and the thread handle.
    fn spawn_server<F>(serve: F) -> (u16, JoinHandle<()>)
    where
        F: FnOnce(TcpStream) + Send + 'static,
    {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let (stream, _) = listener.accept().unwrap();
            serve(stream);
        });
        (port, handle)
    }

    /// Bind an ephemeral loopback port whose background thread just accepts
    /// one connection (ignoring accept errors) so a client `connect` completes.
    fn spawn_acceptor() -> (u16, JoinHandle<()>) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        let handle = thread::spawn(move || {
            let _ = listener.accept();
        });
        (port, handle)
    }

    /// Server body: read one chunk of up to 64 bytes and send it straight back.
    fn echo_once(mut stream: TcpStream) {
        let mut buf = [0u8; 64];
        let n = stream.read(&mut buf).unwrap();
        stream.write_all(&buf[..n]).unwrap();
    }

    #[test]
    fn net_error_display_connection_refused() {
        assert_eq!(NetError::ConnectionRefused.to_string(), "connection refused");
    }

    #[test]
    fn net_error_display_timeout() {
        assert_eq!(NetError::Timeout.to_string(), "connection timed out");
    }

    #[test]
    fn net_error_display_dns() {
        let err = NetError::DnsResolutionFailed(String::from("example.invalid"));
        assert_eq!(
            err.to_string(),
            "DNS resolution failed for 'example.invalid'"
        );
    }

    #[test]
    fn net_error_display_io() {
        let err = NetError::Io(io::Error::new(io::ErrorKind::BrokenPipe, "broken pipe"));
        assert!(err.to_string().contains("broken pipe"));
    }

    #[test]
    fn net_error_from_io_connection_refused() {
        let err = NetError::from(io::Error::new(io::ErrorKind::ConnectionRefused, "refused"));
        assert!(matches!(err, NetError::ConnectionRefused));
    }

    #[test]
    fn net_error_from_io_timed_out() {
        let err = NetError::from(io::Error::new(io::ErrorKind::TimedOut, "timed out"));
        assert!(matches!(err, NetError::Timeout));
    }

    #[test]
    fn net_error_from_io_other() {
        let err = NetError::from(io::Error::new(io::ErrorKind::BrokenPipe, "broken"));
        assert!(matches!(err, NetError::Io(_)));
    }

    #[test]
    fn connect_to_nonexistent_host_fails() {
        assert!(TcpConnection::connect("host.invalid", 1).is_err());
    }

    #[test]
    fn connect_to_refused_port_fails() {
        // Port 1 on localhost is almost certainly not listening.
        assert!(TcpConnection::connect("127.0.0.1", 1).is_err());
    }

    #[test]
    fn connect_timeout_to_nonexistent_host_fails() {
        let timeout = Duration::from_millis(100);
        assert!(TcpConnection::connect_timeout("host.invalid", 1, timeout).is_err());
    }

    #[test]
    fn loopback_echo() {
        let (port, handle) = spawn_server(echo_once);

        let mut conn = TcpConnection::connect("127.0.0.1", port).unwrap();
        conn.write_all(b"hello").unwrap();

        let mut response = [0u8; 5];
        conn.read_exact(&mut response).unwrap();
        assert_eq!(&response, b"hello");

        conn.shutdown().ok();
        handle.join().unwrap();
    }

    #[test]
    fn buffered_read_line() {
        let (port, handle) = spawn_server(|mut stream| {
            stream.write_all(b"line one\nline two\n").unwrap();
        });

        let mut reader = TcpConnection::connect("127.0.0.1", port)
            .unwrap()
            .into_buf_reader();

        let mut line = String::new();
        reader.read_line(&mut line).unwrap();
        assert_eq!(line, "line one\n");

        line.clear();
        reader.read_line(&mut line).unwrap();
        assert_eq!(line, "line two\n");

        handle.join().unwrap();
    }

    #[test]
    fn buffered_read_write_pair() {
        let (port, handle) = spawn_server(echo_once);

        let conn = TcpConnection::connect("127.0.0.1", port).unwrap();
        let (mut reader, mut writer) = conn.into_buffered().unwrap();

        writer.write_all(b"ping").unwrap();
        writer.flush().unwrap();

        let mut response = [0u8; 4];
        reader.read_exact(&mut response).unwrap();
        assert_eq!(&response, b"ping");

        handle.join().unwrap();
    }

    #[test]
    fn set_timeouts() {
        let (port, handle) = spawn_acceptor();
        let conn = TcpConnection::connect("127.0.0.1", port).unwrap();

        // Set both timeouts, then clear them again.
        for timeout in [Some(Duration::from_millis(50)), None] {
            conn.set_read_timeout(timeout).unwrap();
            conn.set_write_timeout(timeout).unwrap();
        }

        conn.shutdown().ok();
        handle.join().unwrap();
    }

    #[test]
    fn debug_format() {
        let (port, handle) = spawn_acceptor();
        let conn = TcpConnection::connect("127.0.0.1", port).unwrap();

        assert!(format!("{conn:?}").contains("TcpConnection"));

        conn.shutdown().ok();
        handle.join().unwrap();
    }
}