we (web engine): Experimental web browser project to understand the limits of Claude
1//! TCP socket abstraction wrapping `std::net::TcpStream`.
2
3use std::fmt;
4use std::io::{self, BufRead, BufReader, BufWriter, Read, Write};
5use std::net::{Shutdown, TcpStream, ToSocketAddrs};
6use std::time::Duration;
7
8// ---------------------------------------------------------------------------
9// Error types
10// ---------------------------------------------------------------------------
11
/// Errors produced by the networking layer.
///
/// Values are usually created through the `From<io::Error>` conversion below,
/// which gives refused and timed-out connections their own variants and keeps
/// every other I/O failure intact inside [`NetError::Io`].
#[derive(Debug)]
pub enum NetError {
    /// Connection was refused by the remote host.
    ConnectionRefused,
    /// Connection timed out.
    Timeout,
    /// DNS resolution failed for the given hostname.
    DnsResolutionFailed(String),
    /// An I/O error occurred.
    Io(io::Error),
}

impl fmt::Display for NetError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::ConnectionRefused => write!(f, "connection refused"),
            Self::Timeout => write!(f, "connection timed out"),
            Self::DnsResolutionFailed(host) => write!(f, "DNS resolution failed for '{host}'"),
            Self::Io(e) => write!(f, "I/O error: {e}"),
        }
    }
}

impl std::error::Error for NetError {
    /// Returns the wrapped `io::Error` for [`NetError::Io`]; the other variants
    /// carry no underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            Self::ConnectionRefused | Self::Timeout | Self::DnsResolutionFailed(_) => None,
        }
    }
}

impl From<io::Error> for NetError {
    /// Classifies an `io::Error` by its kind.
    ///
    /// `ConnectionRefused` and `TimedOut` map to their dedicated variants (the
    /// original error value is dropped in those cases); anything else is
    /// preserved in [`NetError::Io`].
    ///
    /// NOTE(review): std documents that socket read/write timeouts may surface
    /// as `WouldBlock` rather than `TimedOut` on some platforms; such errors end
    /// up in `Io`, not `Timeout` — confirm whether callers rely on `Timeout`.
    fn from(err: io::Error) -> Self {
        match err.kind() {
            io::ErrorKind::ConnectionRefused => NetError::ConnectionRefused,
            io::ErrorKind::TimedOut => NetError::Timeout,
            _ => NetError::Io(err),
        }
    }
}

/// Result alias used throughout this module, with [`NetError`] as the error.
pub type Result<T> = std::result::Result<T, NetError>;
47
48// ---------------------------------------------------------------------------
49// TcpConnection
50// ---------------------------------------------------------------------------
51
52/// A TCP connection wrapping `std::net::TcpStream`.
53pub struct TcpConnection {
54 stream: TcpStream,
55}
56
57impl TcpConnection {
58 /// Connect to a TCP server by hostname and port.
59 ///
60 /// Resolves the hostname via the system resolver and connects to the first
61 /// address that succeeds.
62 pub fn connect(host: &str, port: u16) -> Result<Self> {
63 let addr_str = format!("{host}:{port}");
64 let addrs = addr_str
65 .to_socket_addrs()
66 .map_err(|_| NetError::DnsResolutionFailed(host.to_string()))?;
67
68 let mut last_err = None;
69 for addr in addrs {
70 match TcpStream::connect(addr) {
71 Ok(stream) => return Ok(Self { stream }),
72 Err(e) => last_err = Some(e),
73 }
74 }
75
76 match last_err {
77 Some(e) => Err(NetError::from(e)),
78 None => Err(NetError::DnsResolutionFailed(host.to_string())),
79 }
80 }
81
82 /// Connect with a timeout.
83 pub fn connect_timeout(host: &str, port: u16, timeout: Duration) -> Result<Self> {
84 let addr_str = format!("{host}:{port}");
85 let addrs: Vec<_> = addr_str
86 .to_socket_addrs()
87 .map_err(|_| NetError::DnsResolutionFailed(host.to_string()))?
88 .collect();
89
90 let mut last_err = None;
91 for addr in addrs {
92 match TcpStream::connect_timeout(&addr, timeout) {
93 Ok(stream) => return Ok(Self { stream }),
94 Err(e) => last_err = Some(e),
95 }
96 }
97
98 match last_err {
99 Some(e) => Err(NetError::from(e)),
100 None => Err(NetError::DnsResolutionFailed(host.to_string())),
101 }
102 }
103
104 /// Read bytes into the buffer. Returns the number of bytes read.
105 pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
106 self.stream.read(buf).map_err(NetError::from)
107 }
108
109 /// Read exactly `buf.len()` bytes, blocking until complete or error.
110 pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
111 self.stream.read_exact(buf).map_err(NetError::from)
112 }
113
114 /// Write bytes. Returns the number of bytes written.
115 pub fn write(&mut self, data: &[u8]) -> Result<usize> {
116 self.stream.write(data).map_err(NetError::from)
117 }
118
119 /// Write all bytes, blocking until complete or error.
120 pub fn write_all(&mut self, data: &[u8]) -> Result<()> {
121 self.stream.write_all(data).map_err(NetError::from)
122 }
123
124 /// Flush the underlying stream.
125 pub fn flush(&mut self) -> Result<()> {
126 self.stream.flush().map_err(NetError::from)
127 }
128
129 /// Set the read timeout.
130 pub fn set_read_timeout(&self, duration: Option<Duration>) -> Result<()> {
131 self.stream
132 .set_read_timeout(duration)
133 .map_err(NetError::from)
134 }
135
136 /// Set the write timeout.
137 pub fn set_write_timeout(&self, duration: Option<Duration>) -> Result<()> {
138 self.stream
139 .set_write_timeout(duration)
140 .map_err(NetError::from)
141 }
142
143 /// Shut down the connection (both read and write).
144 pub fn shutdown(&self) -> Result<()> {
145 self.stream.shutdown(Shutdown::Both).map_err(NetError::from)
146 }
147
148 /// Create a buffered reader over this connection.
149 ///
150 /// Consumes the connection. Use `into_buffered` if you need both buffered
151 /// read and write.
152 pub fn into_buf_reader(self) -> BufferedReader {
153 BufferedReader {
154 inner: BufReader::new(self.stream),
155 }
156 }
157
158 /// Split into a buffered reader and writer pair sharing the same stream.
159 pub fn into_buffered(self) -> Result<(BufferedReader, BufferedWriter)> {
160 let clone = self.stream.try_clone().map_err(NetError::from)?;
161 Ok((
162 BufferedReader {
163 inner: BufReader::new(self.stream),
164 },
165 BufferedWriter {
166 inner: BufWriter::new(clone),
167 },
168 ))
169 }
170
171 /// Get a reference to the underlying `TcpStream`.
172 pub fn as_raw(&self) -> &TcpStream {
173 &self.stream
174 }
175}
176
177impl fmt::Debug for TcpConnection {
178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179 f.debug_struct("TcpConnection")
180 .field("peer", &self.stream.peer_addr().ok())
181 .field("local", &self.stream.local_addr().ok())
182 .finish()
183 }
184}
185
186// ---------------------------------------------------------------------------
187// Buffered I/O wrappers
188// ---------------------------------------------------------------------------
189
190/// A buffered reader over a TCP stream.
191pub struct BufferedReader {
192 inner: BufReader<TcpStream>,
193}
194
195impl BufferedReader {
196 /// Read a line (including the trailing `\n` or `\r\n`).
197 /// Returns the number of bytes read, or 0 at EOF.
198 pub fn read_line(&mut self, buf: &mut String) -> Result<usize> {
199 self.inner.read_line(buf).map_err(NetError::from)
200 }
201
202 /// Read bytes into the buffer.
203 pub fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
204 self.inner.read(buf).map_err(NetError::from)
205 }
206
207 /// Read exactly `buf.len()` bytes.
208 pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<()> {
209 self.inner.read_exact(buf).map_err(NetError::from)
210 }
211
212 /// Return a reference to the internal buffer contents without consuming.
213 pub fn buffer(&self) -> &[u8] {
214 self.inner.buffer()
215 }
216
217 /// Consume `n` bytes from the internal buffer.
218 pub fn consume(&mut self, n: usize) {
219 self.inner.consume(n);
220 }
221
222 /// Fill the internal buffer, returning a slice of the available data.
223 pub fn fill_buf(&mut self) -> Result<&[u8]> {
224 self.inner.fill_buf().map_err(NetError::from)
225 }
226
227 /// Set the read timeout on the underlying stream.
228 pub fn set_read_timeout(&self, duration: Option<Duration>) -> Result<()> {
229 self.inner
230 .get_ref()
231 .set_read_timeout(duration)
232 .map_err(NetError::from)
233 }
234}
235
236impl fmt::Debug for BufferedReader {
237 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
238 f.debug_struct("BufferedReader")
239 .field("buffered_bytes", &self.inner.buffer().len())
240 .finish()
241 }
242}
243
244/// A buffered writer over a TCP stream.
245pub struct BufferedWriter {
246 inner: BufWriter<TcpStream>,
247}
248
249impl BufferedWriter {
250 /// Write bytes. Returns the number of bytes written.
251 pub fn write(&mut self, data: &[u8]) -> Result<usize> {
252 self.inner.write(data).map_err(NetError::from)
253 }
254
255 /// Write all bytes.
256 pub fn write_all(&mut self, data: &[u8]) -> Result<()> {
257 self.inner.write_all(data).map_err(NetError::from)
258 }
259
260 /// Flush the buffered writer, sending all pending data.
261 pub fn flush(&mut self) -> Result<()> {
262 self.inner.flush().map_err(NetError::from)
263 }
264
265 /// Set the write timeout on the underlying stream.
266 pub fn set_write_timeout(&self, duration: Option<Duration>) -> Result<()> {
267 self.inner
268 .get_ref()
269 .set_write_timeout(duration)
270 .map_err(NetError::from)
271 }
272}
273
274impl fmt::Debug for BufferedWriter {
275 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
276 f.debug_struct("BufferedWriter").finish()
277 }
278}
279
280// ---------------------------------------------------------------------------
281// Tests
282// ---------------------------------------------------------------------------
283
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::TcpListener;
    use std::thread::{self, JoinHandle};

    /// Bind a listener on an ephemeral loopback port and report that port.
    fn bind_loopback() -> (TcpListener, u16) {
        let listener = TcpListener::bind("127.0.0.1:0").unwrap();
        let port = listener.local_addr().unwrap().port();
        (listener, port)
    }

    /// Serve one client: perform a single read and echo back what arrived.
    fn spawn_echo_once(listener: TcpListener) -> JoinHandle<()> {
        thread::spawn(move || {
            let (mut peer, _) = listener.accept().unwrap();
            let mut scratch = [0u8; 64];
            let len = peer.read(&mut scratch).unwrap();
            peer.write_all(&scratch[..len]).unwrap();
        })
    }

    /// Accept a single client so that `connect` succeeds, then do nothing.
    fn spawn_accept_only(listener: TcpListener) -> JoinHandle<()> {
        thread::spawn(move || {
            let _ = listener.accept();
        })
    }

    // --- NetError: Display ---

    #[test]
    fn net_error_display_connection_refused() {
        assert_eq!(
            NetError::ConnectionRefused.to_string(),
            "connection refused"
        );
    }

    #[test]
    fn net_error_display_timeout() {
        assert_eq!(NetError::Timeout.to_string(), "connection timed out");
    }

    #[test]
    fn net_error_display_dns() {
        let rendered = NetError::DnsResolutionFailed("example.invalid".to_string()).to_string();
        assert_eq!(rendered, "DNS resolution failed for 'example.invalid'");
    }

    #[test]
    fn net_error_display_io() {
        let cause = io::Error::new(io::ErrorKind::BrokenPipe, "broken pipe");
        let rendered = NetError::Io(cause).to_string();
        assert!(rendered.contains("broken pipe"));
    }

    // --- NetError: From<io::Error> ---

    #[test]
    fn net_error_from_io_connection_refused() {
        let converted = NetError::from(io::Error::new(
            io::ErrorKind::ConnectionRefused,
            "refused",
        ));
        assert!(matches!(converted, NetError::ConnectionRefused));
    }

    #[test]
    fn net_error_from_io_timed_out() {
        let converted = NetError::from(io::Error::new(io::ErrorKind::TimedOut, "timed out"));
        assert!(matches!(converted, NetError::Timeout));
    }

    #[test]
    fn net_error_from_io_other() {
        let converted = NetError::from(io::Error::new(io::ErrorKind::BrokenPipe, "broken"));
        assert!(matches!(converted, NetError::Io(_)));
    }

    // --- Connection failures ---

    #[test]
    fn connect_to_nonexistent_host_fails() {
        assert!(TcpConnection::connect("host.invalid", 1).is_err());
    }

    #[test]
    fn connect_to_refused_port_fails() {
        // Port 1 on localhost is almost certainly not listening.
        assert!(TcpConnection::connect("127.0.0.1", 1).is_err());
    }

    #[test]
    fn connect_timeout_to_nonexistent_host_fails() {
        let wait = Duration::from_millis(100);
        assert!(TcpConnection::connect_timeout("host.invalid", 1, wait).is_err());
    }

    // --- Loopback round trips ---

    #[test]
    fn loopback_echo() {
        let (listener, port) = bind_loopback();
        let server = spawn_echo_once(listener);

        let mut client = TcpConnection::connect("127.0.0.1", port).unwrap();
        client.write_all(b"hello").unwrap();

        let mut echoed = [0u8; 5];
        client.read_exact(&mut echoed).unwrap();
        assert_eq!(&echoed, b"hello");

        client.shutdown().ok();
        server.join().unwrap();
    }

    #[test]
    fn buffered_read_line() {
        let (listener, port) = bind_loopback();
        let server = thread::spawn(move || {
            let (mut peer, _) = listener.accept().unwrap();
            peer.write_all(b"line one\nline two\n").unwrap();
        });

        let mut reader = TcpConnection::connect("127.0.0.1", port)
            .unwrap()
            .into_buf_reader();

        let mut text = String::new();
        reader.read_line(&mut text).unwrap();
        assert_eq!(text, "line one\n");

        text.clear();
        reader.read_line(&mut text).unwrap();
        assert_eq!(text, "line two\n");

        server.join().unwrap();
    }

    #[test]
    fn buffered_read_write_pair() {
        let (listener, port) = bind_loopback();
        let server = spawn_echo_once(listener);

        let client = TcpConnection::connect("127.0.0.1", port).unwrap();
        let (mut reader, mut writer) = client.into_buffered().unwrap();

        writer.write_all(b"ping").unwrap();
        writer.flush().unwrap();

        let mut echoed = [0u8; 4];
        reader.read_exact(&mut echoed).unwrap();
        assert_eq!(&echoed, b"ping");

        server.join().unwrap();
    }

    // --- Socket options and formatting ---

    #[test]
    fn set_timeouts() {
        let (listener, port) = bind_loopback();
        // Accept in background so connect succeeds.
        let server = spawn_accept_only(listener);

        let client = TcpConnection::connect("127.0.0.1", port).unwrap();

        let short = Some(Duration::from_millis(50));
        client.set_read_timeout(short).unwrap();
        client.set_write_timeout(short).unwrap();

        // Clear the timeouts.
        client.set_read_timeout(None).unwrap();
        client.set_write_timeout(None).unwrap();

        client.shutdown().ok();
        server.join().unwrap();
    }

    #[test]
    fn debug_format() {
        let (listener, port) = bind_loopback();
        let server = spawn_accept_only(listener);

        let client = TcpConnection::connect("127.0.0.1", port).unwrap();
        let rendered = format!("{client:?}");
        assert!(rendered.contains("TcpConnection"));

        client.shutdown().ok();
        server.join().unwrap();
    }
}