we (web engine): Experimental web browser project to understand the limits of Claude
at utf-codecs 989 lines 33 kB view raw
1//! DNS stub resolver (RFC 1035). 2//! 3//! Resolves hostnames to IP addresses by sending UDP queries to system nameservers 4//! parsed from `/etc/resolv.conf`. 5 6use std::io; 7use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; 8use std::time::Duration; 9 10// --------------------------------------------------------------------------- 11// Constants 12// --------------------------------------------------------------------------- 13 14/// Default DNS port. 15const DNS_PORT: u16 = 53; 16 17/// Default query timeout. 18const DEFAULT_TIMEOUT: Duration = Duration::from_secs(3); 19 20/// Maximum UDP DNS message size. 21const MAX_UDP_SIZE: usize = 512; 22 23/// DNS record types. 24const TYPE_A: u16 = 1; 25const TYPE_AAAA: u16 = 28; 26 27/// DNS class: Internet. 28const CLASS_IN: u16 = 1; 29 30/// DNS header flags. 31const FLAG_RD: u16 = 1 << 8; // Recursion Desired 32const FLAG_QR: u16 = 1 << 15; // Query/Response 33 34/// DNS compression pointer mask. 35const COMPRESSION_MASK: u8 = 0xC0; 36 37// --------------------------------------------------------------------------- 38// Error types 39// --------------------------------------------------------------------------- 40 41/// DNS resolution errors. 42#[derive(Debug)] 43pub enum DnsError { 44 /// Failed to parse `/etc/resolv.conf` or no nameservers found. 45 NoNameservers, 46 /// The query timed out on all nameservers. 47 Timeout, 48 /// The response was malformed. 49 MalformedResponse(String), 50 /// The server returned a non-zero RCODE. 51 ServerError(u8), 52 /// An I/O error occurred. 53 Io(io::Error), 54} 55 56impl std::fmt::Display for DnsError { 57 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 58 match self { 59 Self::NoNameservers => write!(f, "no DNS nameservers configured"), 60 Self::Timeout => write!(f, "DNS query timed out"), 61 Self::MalformedResponse(msg) => write!(f, "malformed DNS response: {msg}"), 62 Self::ServerError(code) => write!(f, "DNS server error (RCODE {code})"), 63 Self::Io(e) => write!(f, "DNS I/O error: {e}"), 64 } 65 } 66} 67 68impl From<io::Error> for DnsError { 69 fn from(err: io::Error) -> Self { 70 if err.kind() == io::ErrorKind::TimedOut || err.kind() == io::ErrorKind::WouldBlock { 71 DnsError::Timeout 72 } else { 73 DnsError::Io(err) 74 } 75 } 76} 77 78pub type Result<T> = std::result::Result<T, DnsError>; 79 80// --------------------------------------------------------------------------- 81// DNS message building 82// --------------------------------------------------------------------------- 83 84/// Build a DNS query message for the given hostname and record type. 85fn build_query(id: u16, hostname: &str, qtype: u16) -> std::result::Result<Vec<u8>, DnsError> { 86 let mut buf = Vec::with_capacity(64); 87 88 // Header (12 bytes) 89 buf.extend_from_slice(&id.to_be_bytes()); 90 buf.extend_from_slice(&FLAG_RD.to_be_bytes()); // flags: RD=1 91 buf.extend_from_slice(&1u16.to_be_bytes()); // QDCOUNT = 1 92 buf.extend_from_slice(&0u16.to_be_bytes()); // ANCOUNT = 0 93 buf.extend_from_slice(&0u16.to_be_bytes()); // NSCOUNT = 0 94 buf.extend_from_slice(&0u16.to_be_bytes()); // ARCOUNT = 0 95 96 // Question section: QNAME 97 encode_name(&mut buf, hostname)?; 98 99 // QTYPE and QCLASS 100 buf.extend_from_slice(&qtype.to_be_bytes()); 101 buf.extend_from_slice(&CLASS_IN.to_be_bytes()); 102 103 Ok(buf) 104} 105 106/// Encode a domain name in DNS wire format (length-prefixed labels). 107fn encode_name(buf: &mut Vec<u8>, name: &str) -> std::result::Result<(), DnsError> { 108 if name.is_empty() { 109 buf.push(0); 110 return Ok(()); 111 } 112 113 for label in name.split('.') { 114 if label.is_empty() { 115 continue; 116 } 117 if label.len() > 63 { 118 return Err(DnsError::MalformedResponse(format!( 119 "label too long: {} bytes", 120 label.len() 121 ))); 122 } 123 buf.push(label.len() as u8); 124 buf.extend_from_slice(label.as_bytes()); 125 } 126 buf.push(0); // root label 127 Ok(()) 128} 129 130// --------------------------------------------------------------------------- 131// DNS message parsing 132// --------------------------------------------------------------------------- 133 134/// A parsed DNS response. 135struct DnsResponse { 136 id: u16, 137 rcode: u8, 138 answers: Vec<DnsRecord>, 139} 140 141/// A parsed DNS resource record. 142struct DnsRecord { 143 rtype: u16, 144 rdata: Vec<u8>, 145} 146 147/// Parser state for reading a DNS message. 148struct DnsParser<'a> { 149 data: &'a [u8], 150 pos: usize, 151} 152 153impl<'a> DnsParser<'a> { 154 fn new(data: &'a [u8]) -> Self { 155 Self { data, pos: 0 } 156 } 157 158 fn remaining(&self) -> usize { 159 self.data.len().saturating_sub(self.pos) 160 } 161 162 fn read_u8(&mut self) -> std::result::Result<u8, DnsError> { 163 if self.pos >= self.data.len() { 164 return Err(DnsError::MalformedResponse("unexpected end of data".into())); 165 } 166 let val = self.data[self.pos]; 167 self.pos += 1; 168 Ok(val) 169 } 170 171 fn read_u16(&mut self) -> std::result::Result<u16, DnsError> { 172 if self.pos + 2 > self.data.len() { 173 return Err(DnsError::MalformedResponse("unexpected end of data".into())); 174 } 175 let val = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]); 176 self.pos += 2; 177 Ok(val) 178 } 179 180 fn read_u32(&mut self) -> std::result::Result<u32, DnsError> { 181 if self.pos + 4 > self.data.len() { 182 return Err(DnsError::MalformedResponse("unexpected end of data".into())); 183 } 184 let val = u32::from_be_bytes([ 185 self.data[self.pos], 186 self.data[self.pos + 1], 187 self.data[self.pos + 2], 188 self.data[self.pos + 3], 189 ]); 190 self.pos += 4; 191 Ok(val) 192 } 193 194 fn read_bytes(&mut self, n: usize) -> std::result::Result<&'a [u8], DnsError> { 195 if self.pos + n > self.data.len() { 196 return Err(DnsError::MalformedResponse("unexpected end of data".into())); 197 } 198 let slice = &self.data[self.pos..self.pos + n]; 199 self.pos += n; 200 Ok(slice) 201 } 202 203 /// Skip a DNS name (handling compression pointers). 204 fn skip_name(&mut self) -> std::result::Result<(), DnsError> { 205 let mut jumps = 0; 206 loop { 207 if jumps > 64 { 208 return Err(DnsError::MalformedResponse( 209 "too many compression jumps".into(), 210 )); 211 } 212 let len = self.read_u8()?; 213 if len == 0 { 214 break; 215 } 216 if len & COMPRESSION_MASK == COMPRESSION_MASK { 217 // Compression pointer: read the second byte and stop 218 let _offset_low = self.read_u8()?; 219 break; 220 } 221 // Regular label: skip the label bytes 222 if self.pos + len as usize > self.data.len() { 223 return Err(DnsError::MalformedResponse( 224 "label extends past data".into(), 225 )); 226 } 227 self.pos += len as usize; 228 jumps += 1; 229 } 230 Ok(()) 231 } 232 233 /// Parse a complete DNS response. 234 fn parse_response(&mut self) -> std::result::Result<DnsResponse, DnsError> { 235 if self.remaining() < 12 { 236 return Err(DnsError::MalformedResponse("response too short".into())); 237 } 238 239 let id = self.read_u16()?; 240 let flags = self.read_u16()?; 241 let qdcount = self.read_u16()?; 242 let ancount = self.read_u16()?; 243 let _nscount = self.read_u16()?; 244 let _arcount = self.read_u16()?; 245 246 // Verify QR bit is set (this is a response) 247 if flags & FLAG_QR == 0 { 248 return Err(DnsError::MalformedResponse("QR bit not set".into())); 249 } 250 251 let rcode = (flags & 0x000F) as u8; 252 253 // Skip question section 254 for _ in 0..qdcount { 255 self.skip_name()?; 256 let _qtype = self.read_u16()?; 257 let _qclass = self.read_u16()?; 258 } 259 260 // Parse answer section 261 let mut answers = Vec::new(); 262 for _ in 0..ancount { 263 self.skip_name()?; // NAME 264 let rtype = self.read_u16()?; // TYPE 265 let _rclass = self.read_u16()?; // CLASS 266 let _ttl = self.read_u32()?; // TTL 267 let rdlength = self.read_u16()? as usize; // RDLENGTH 268 let rdata = self.read_bytes(rdlength)?; // RDATA 269 270 answers.push(DnsRecord { 271 rtype, 272 rdata: rdata.to_vec(), 273 }); 274 } 275 276 Ok(DnsResponse { id, rcode, answers }) 277 } 278} 279 280// --------------------------------------------------------------------------- 281// resolv.conf parsing 282// --------------------------------------------------------------------------- 283 284/// Parse nameserver addresses from `/etc/resolv.conf`. 285pub fn parse_resolv_conf(content: &str) -> Vec<IpAddr> { 286 let mut nameservers = Vec::new(); 287 for line in content.lines() { 288 let line = line.trim(); 289 if line.starts_with('#') || line.is_empty() { 290 continue; 291 } 292 if let Some(rest) = line.strip_prefix("nameserver") { 293 let addr_str = rest.trim(); 294 if let Ok(addr) = addr_str.parse::<IpAddr>() { 295 nameservers.push(addr); 296 } 297 } 298 } 299 nameservers 300} 301 302/// Read nameservers from the system `/etc/resolv.conf`. 303fn system_nameservers() -> Vec<IpAddr> { 304 match std::fs::read_to_string("/etc/resolv.conf") { 305 Ok(content) => { 306 let servers = parse_resolv_conf(&content); 307 if servers.is_empty() { 308 // Fallback to common defaults 309 vec![ 310 IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 311 IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4)), 312 ] 313 } else { 314 servers 315 } 316 } 317 Err(_) => { 318 // Fallback to Google DNS 319 vec![ 320 IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), 321 IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4)), 322 ] 323 } 324 } 325} 326 327// --------------------------------------------------------------------------- 328// Public API 329// --------------------------------------------------------------------------- 330 331/// Configuration for DNS resolution. 332pub struct DnsResolver { 333 nameservers: Vec<IpAddr>, 334 timeout: Duration, 335} 336 337impl DnsResolver { 338 /// Create a resolver using system nameservers from `/etc/resolv.conf`. 339 pub fn system() -> Self { 340 Self { 341 nameservers: system_nameservers(), 342 timeout: DEFAULT_TIMEOUT, 343 } 344 } 345 346 /// Create a resolver with explicit nameserver addresses. 347 pub fn with_nameservers(nameservers: Vec<IpAddr>) -> Self { 348 Self { 349 nameservers, 350 timeout: DEFAULT_TIMEOUT, 351 } 352 } 353 354 /// Set the query timeout. 355 pub fn set_timeout(&mut self, timeout: Duration) { 356 self.timeout = timeout; 357 } 358 359 /// Resolve a hostname to a list of IP addresses. 360 /// 361 /// Sends A and AAAA queries and collects all results. 362 pub fn resolve(&self, hostname: &str) -> Result<Vec<IpAddr>> { 363 if self.nameservers.is_empty() { 364 return Err(DnsError::NoNameservers); 365 } 366 367 // Try to parse as an IP address directly 368 if let Ok(addr) = hostname.parse::<IpAddr>() { 369 return Ok(vec![addr]); 370 } 371 372 let mut addrs = Vec::new(); 373 374 // Send A query 375 if let Ok(mut v4) = self.query(hostname, TYPE_A) { 376 addrs.append(&mut v4); 377 } 378 379 // Send AAAA query 380 if let Ok(mut v6) = self.query(hostname, TYPE_AAAA) { 381 addrs.append(&mut v6); 382 } 383 384 if addrs.is_empty() { 385 // Retry once more, this time propagating the error 386 addrs = self.query(hostname, TYPE_A)?; 387 } 388 389 Ok(addrs) 390 } 391 392 /// Send a single DNS query for a specific record type. 393 fn query(&self, hostname: &str, qtype: u16) -> Result<Vec<IpAddr>> { 394 let id = generate_id(); 395 let query_msg = build_query(id, hostname, qtype)?; 396 397 let socket = UdpSocket::bind("0.0.0.0:0").map_err(DnsError::from)?; 398 socket 399 .set_read_timeout(Some(self.timeout)) 400 .map_err(DnsError::from)?; 401 402 let mut last_err = None; 403 404 for &ns in &self.nameservers { 405 let dest = SocketAddr::new(ns, DNS_PORT); 406 407 if let Err(e) = socket.send_to(&query_msg, dest) { 408 last_err = Some(DnsError::from(e)); 409 continue; 410 } 411 412 let mut resp_buf = [0u8; MAX_UDP_SIZE]; 413 match socket.recv_from(&mut resp_buf) { 414 Ok((len, _)) => { 415 let mut parser = DnsParser::new(&resp_buf[..len]); 416 match parser.parse_response() { 417 Ok(response) => { 418 if response.id != id { 419 last_err = Some(DnsError::MalformedResponse("ID mismatch".into())); 420 continue; 421 } 422 if response.rcode != 0 { 423 // RCODE 3 = NXDOMAIN: name doesn't exist, no point retrying 424 if response.rcode == 3 { 425 return Ok(Vec::new()); 426 } 427 last_err = Some(DnsError::ServerError(response.rcode)); 428 continue; 429 } 430 return Ok(extract_addrs(&response.answers, qtype)); 431 } 432 Err(e) => { 433 last_err = Some(e); 434 continue; 435 } 436 } 437 } 438 Err(e) => { 439 last_err = Some(DnsError::from(e)); 440 continue; 441 } 442 } 443 } 444 445 Err(last_err.unwrap_or(DnsError::Timeout)) 446 } 447} 448 449/// Extract IP addresses from DNS answer records. 450fn extract_addrs(answers: &[DnsRecord], qtype: u16) -> Vec<IpAddr> { 451 let mut addrs = Vec::new(); 452 for record in answers { 453 if record.rtype == TYPE_A && qtype == TYPE_A && record.rdata.len() == 4 { 454 addrs.push(IpAddr::V4(Ipv4Addr::new( 455 record.rdata[0], 456 record.rdata[1], 457 record.rdata[2], 458 record.rdata[3], 459 ))); 460 } else if record.rtype == TYPE_AAAA && qtype == TYPE_AAAA && record.rdata.len() == 16 { 461 let mut octets = [0u8; 16]; 462 octets.copy_from_slice(&record.rdata); 463 addrs.push(IpAddr::V6(Ipv6Addr::from(octets))); 464 } 465 } 466 addrs 467} 468 469/// Generate a pseudo-random 16-bit transaction ID. 470fn generate_id() -> u16 { 471 // Use a simple approach based on time for uniqueness 472 let duration = std::time::SystemTime::now() 473 .duration_since(std::time::UNIX_EPOCH) 474 .unwrap_or_default(); 475 // Mix nanoseconds for some entropy 476 let nanos = duration.subsec_nanos(); 477 let secs = duration.as_secs(); 478 ((nanos ^ (secs as u32)) & 0xFFFF) as u16 479} 480 481/// Convenience function: resolve a hostname using system nameservers. 482pub fn resolve(hostname: &str) -> Result<Vec<IpAddr>> { 483 DnsResolver::system().resolve(hostname) 484} 485 486// --------------------------------------------------------------------------- 487// Tests 488// --------------------------------------------------------------------------- 489 490#[cfg(test)] 491mod tests { 492 use super::*; 493 494 // -- Name encoding tests -- 495 496 #[test] 497 fn encode_simple_name() { 498 let mut buf = Vec::new(); 499 encode_name(&mut buf, "example.com").unwrap(); 500 assert_eq!( 501 buf, 502 vec![7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0] 503 ); 504 } 505 506 #[test] 507 fn encode_single_label() { 508 let mut buf = Vec::new(); 509 encode_name(&mut buf, "localhost").unwrap(); 510 assert_eq!( 511 buf, 512 vec![9, b'l', b'o', b'c', b'a', b'l', b'h', b'o', b's', b't', 0] 513 ); 514 } 515 516 #[test] 517 fn encode_empty_name() { 518 let mut buf = Vec::new(); 519 encode_name(&mut buf, "").unwrap(); 520 assert_eq!(buf, vec![0]); 521 } 522 523 #[test] 524 fn encode_trailing_dot() { 525 let mut buf = Vec::new(); 526 encode_name(&mut buf, "example.com.").unwrap(); 527 // Should handle trailing dot gracefully (skip empty label) 528 assert_eq!( 529 buf, 530 vec![7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0] 531 ); 532 } 533 534 #[test] 535 fn encode_label_too_long() { 536 let long_label = "a".repeat(64); 537 let mut buf = Vec::new(); 538 assert!(encode_name(&mut buf, &long_label).is_err()); 539 } 540 541 // -- Query building tests -- 542 543 #[test] 544 fn build_query_a_record() { 545 let query = build_query(0x1234, "example.com", TYPE_A).unwrap(); 546 // Header: 12 bytes 547 assert_eq!(query[0..2], [0x12, 0x34]); // ID 548 assert_eq!(query[2..4], [0x01, 0x00]); // Flags: RD=1 549 assert_eq!(query[4..6], [0x00, 0x01]); // QDCOUNT=1 550 assert_eq!(query[6..8], [0x00, 0x00]); // ANCOUNT=0 551 // QTYPE at end - 2 (A=1) 552 let qtype_offset = query.len() - 4; 553 assert_eq!(query[qtype_offset..qtype_offset + 2], [0x00, 0x01]); // TYPE A 554 assert_eq!(query[qtype_offset + 2..qtype_offset + 4], [0x00, 0x01]); // CLASS IN 555 } 556 557 #[test] 558 fn build_query_aaaa_record() { 559 let query = build_query(0xABCD, "example.com", TYPE_AAAA).unwrap(); 560 assert_eq!(query[0..2], [0xAB, 0xCD]); // ID 561 let qtype_offset = query.len() - 4; 562 assert_eq!(query[qtype_offset..qtype_offset + 2], [0x00, 0x1C]); // TYPE AAAA (28) 563 } 564 565 // -- Response parsing tests -- 566 567 fn make_response(id: u16, rcode: u8, answers: &[(u16, &[u8])]) -> Vec<u8> { 568 let mut buf = Vec::new(); 569 // Header 570 buf.extend_from_slice(&id.to_be_bytes()); 571 let flags: u16 = FLAG_QR | (rcode as u16); // QR=1 572 buf.extend_from_slice(&flags.to_be_bytes()); 573 buf.extend_from_slice(&1u16.to_be_bytes()); // QDCOUNT=1 574 buf.extend_from_slice(&(answers.len() as u16).to_be_bytes()); // ANCOUNT 575 buf.extend_from_slice(&0u16.to_be_bytes()); // NSCOUNT 576 buf.extend_from_slice(&0u16.to_be_bytes()); // ARCOUNT 577 578 // Question section 579 buf.extend_from_slice(&[ 580 7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0, 581 ]); 582 buf.extend_from_slice(&TYPE_A.to_be_bytes()); // QTYPE 583 buf.extend_from_slice(&CLASS_IN.to_be_bytes()); // QCLASS 584 585 // Answer section 586 for &(rtype, rdata) in answers { 587 // NAME: compression pointer to offset 12 (question name) 588 buf.extend_from_slice(&[0xC0, 0x0C]); 589 buf.extend_from_slice(&rtype.to_be_bytes()); // TYPE 590 buf.extend_from_slice(&CLASS_IN.to_be_bytes()); // CLASS 591 buf.extend_from_slice(&300u32.to_be_bytes()); // TTL 592 buf.extend_from_slice(&(rdata.len() as u16).to_be_bytes()); // RDLENGTH 593 buf.extend_from_slice(rdata); // RDATA 594 } 595 596 buf 597 } 598 599 #[test] 600 fn parse_a_response() { 601 let resp_data = make_response(0x1234, 0, &[(TYPE_A, &[93, 184, 216, 34])]); 602 let mut parser = DnsParser::new(&resp_data); 603 let response = parser.parse_response().unwrap(); 604 assert_eq!(response.id, 0x1234); 605 assert_eq!(response.rcode, 0); 606 assert_eq!(response.answers.len(), 1); 607 assert_eq!(response.answers[0].rtype, TYPE_A); 608 assert_eq!(response.answers[0].rdata, vec![93, 184, 216, 34]); 609 } 610 611 #[test] 612 fn parse_aaaa_response() { 613 let ipv6_data: [u8; 16] = [ 614 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 615 0x00, 0x01, 616 ]; 617 let resp_data = make_response(0x5678, 0, &[(TYPE_AAAA, &ipv6_data)]); 618 let mut parser = DnsParser::new(&resp_data); 619 let response = parser.parse_response().unwrap(); 620 assert_eq!(response.id, 0x5678); 621 assert_eq!(response.answers.len(), 1); 622 assert_eq!(response.answers[0].rtype, TYPE_AAAA); 623 assert_eq!(response.answers[0].rdata, ipv6_data.to_vec()); 624 } 625 626 #[test] 627 fn parse_multiple_a_records() { 628 let resp_data = make_response( 629 0xAAAA, 630 0, 631 &[ 632 (TYPE_A, &[1, 2, 3, 4]), 633 (TYPE_A, &[5, 6, 7, 8]), 634 (TYPE_A, &[9, 10, 11, 12]), 635 ], 636 ); 637 let mut parser = DnsParser::new(&resp_data); 638 let response = parser.parse_response().unwrap(); 639 assert_eq!(response.answers.len(), 3); 640 let addrs = extract_addrs(&response.answers, TYPE_A); 641 assert_eq!(addrs.len(), 3); 642 assert_eq!(addrs[0], IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))); 643 assert_eq!(addrs[1], IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8))); 644 assert_eq!(addrs[2], IpAddr::V4(Ipv4Addr::new(9, 10, 11, 12))); 645 } 646 647 #[test] 648 fn parse_server_error_rcode() { 649 let resp_data = make_response(0x1111, 2, &[]); // RCODE 2 = SERVFAIL 650 let mut parser = DnsParser::new(&resp_data); 651 let response = parser.parse_response().unwrap(); 652 assert_eq!(response.rcode, 2); 653 } 654 655 #[test] 656 fn parse_nxdomain() { 657 let resp_data = make_response(0x2222, 3, &[]); // RCODE 3 = NXDOMAIN 658 let mut parser = DnsParser::new(&resp_data); 659 let response = parser.parse_response().unwrap(); 660 assert_eq!(response.rcode, 3); 661 } 662 663 #[test] 664 fn parse_response_too_short() { 665 let short = vec![0u8; 6]; 666 let mut parser = DnsParser::new(&short); 667 assert!(parser.parse_response().is_err()); 668 } 669 670 #[test] 671 fn parse_response_not_response() { 672 // QR bit not set (flags = 0x0100, which is a query with RD) 673 let mut buf = vec![0x00, 0x01]; // ID 674 buf.extend_from_slice(&[0x01, 0x00]); // flags: RD=1, QR=0 675 buf.extend_from_slice(&[0, 0, 0, 0, 0, 0, 0, 0]); // counts 676 let mut parser = DnsParser::new(&buf); 677 assert!(parser.parse_response().is_err()); 678 } 679 680 // -- Name compression tests -- 681 682 #[test] 683 fn parse_compression_pointer() { 684 // Build a response where answer name uses a compression pointer 685 let resp_data = make_response(0x3333, 0, &[(TYPE_A, &[10, 0, 0, 1])]); 686 let mut parser = DnsParser::new(&resp_data); 687 let response = parser.parse_response().unwrap(); 688 assert_eq!(response.answers.len(), 1); 689 assert_eq!(response.answers[0].rdata, vec![10, 0, 0, 1]); 690 } 691 692 /// Test-only: decode a DNS name from wire format at a given offset, 693 /// following compression pointers. 694 fn decode_name(data: &[u8], offset: usize) -> std::result::Result<String, DnsError> { 695 fn decode_at( 696 data: &[u8], 697 mut offset: usize, 698 name: &mut String, 699 depth: usize, 700 ) -> std::result::Result<(), DnsError> { 701 if depth > 64 { 702 return Err(DnsError::MalformedResponse( 703 "too many compression jumps".into(), 704 )); 705 } 706 loop { 707 if offset >= data.len() { 708 return Err(DnsError::MalformedResponse("unexpected end of data".into())); 709 } 710 let len = data[offset]; 711 offset += 1; 712 if len == 0 { 713 break; 714 } 715 if len & COMPRESSION_MASK == COMPRESSION_MASK { 716 if offset >= data.len() { 717 return Err(DnsError::MalformedResponse("truncated pointer".into())); 718 } 719 let ptr = 720 (((len & !COMPRESSION_MASK) as u16) << 8 | data[offset] as u16) as usize; 721 return decode_at(data, ptr, name, depth + 1); 722 } 723 let label_end = offset + len as usize; 724 if label_end > data.len() { 725 return Err(DnsError::MalformedResponse( 726 "label extends past data".into(), 727 )); 728 } 729 if !name.is_empty() { 730 name.push('.'); 731 } 732 for &b in &data[offset..label_end] { 733 name.push(b as char); 734 } 735 offset = label_end; 736 } 737 Ok(()) 738 } 739 740 let mut name = String::new(); 741 decode_at(data, offset, &mut name, 0)?; 742 Ok(name) 743 } 744 745 #[test] 746 fn decode_name_with_compression() { 747 let mut data = Vec::new(); 748 // Name at offset 0: "test.com" 749 data.extend_from_slice(&[4, b't', b'e', b's', b't', 3, b'c', b'o', b'm', 0]); 750 // Pointer at offset 10: points to offset 0 751 data.extend_from_slice(&[0xC0, 0x00]); 752 753 let name = decode_name(&data, 10).unwrap(); 754 assert_eq!(name, "test.com"); 755 } 756 757 #[test] 758 fn decode_name_no_compression() { 759 let data = vec![ 760 3, b'w', b'w', b'w', 7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 761 0, 762 ]; 763 let name = decode_name(&data, 0).unwrap(); 764 assert_eq!(name, "www.example.com"); 765 } 766 767 #[test] 768 fn decode_name_partial_compression() { 769 // "www" label followed by pointer to "example.com" at offset 0 770 let mut data = Vec::new(); 771 // Offset 0: "example.com" 772 data.extend_from_slice(&[ 773 7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0, 774 ]); 775 // Offset 13: "www" + pointer to offset 0 776 data.extend_from_slice(&[3, b'w', b'w', b'w', 0xC0, 0x00]); 777 778 let name = decode_name(&data, 13).unwrap(); 779 assert_eq!(name, "www.example.com"); 780 } 781 782 // -- resolv.conf parsing tests -- 783 784 #[test] 785 fn parse_resolv_conf_basic() { 786 let content = "nameserver 8.8.8.8\nnameserver 8.8.4.4\n"; 787 let servers = parse_resolv_conf(content); 788 assert_eq!(servers.len(), 2); 789 assert_eq!(servers[0], IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); 790 assert_eq!(servers[1], IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4))); 791 } 792 793 #[test] 794 fn parse_resolv_conf_with_comments() { 795 let content = "\ 796# DNS config 797nameserver 1.1.1.1 798# backup 799nameserver 1.0.0.1 800search example.com 801"; 802 let servers = parse_resolv_conf(content); 803 assert_eq!(servers.len(), 2); 804 assert_eq!(servers[0], IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))); 805 assert_eq!(servers[1], IpAddr::V4(Ipv4Addr::new(1, 0, 0, 1))); 806 } 807 808 #[test] 809 fn parse_resolv_conf_ipv6() { 810 let content = "nameserver 2001:4860:4860::8888\nnameserver 8.8.8.8\n"; 811 let servers = parse_resolv_conf(content); 812 assert_eq!(servers.len(), 2); 813 assert!(servers[0].is_ipv6()); 814 assert!(servers[1].is_ipv4()); 815 } 816 817 #[test] 818 fn parse_resolv_conf_empty() { 819 let servers = parse_resolv_conf(""); 820 assert!(servers.is_empty()); 821 } 822 823 #[test] 824 fn parse_resolv_conf_no_nameservers() { 825 let content = "search example.com\noptions ndots:1\n"; 826 let servers = parse_resolv_conf(content); 827 assert!(servers.is_empty()); 828 } 829 830 #[test] 831 fn parse_resolv_conf_invalid_addr() { 832 let content = "nameserver not.an.ip.addr\nnameserver 8.8.8.8\n"; 833 let servers = parse_resolv_conf(content); 834 assert_eq!(servers.len(), 1); 835 assert_eq!(servers[0], IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); 836 } 837 838 // -- extract_addrs tests -- 839 840 #[test] 841 fn extract_v4_addrs() { 842 let records = vec![ 843 DnsRecord { 844 rtype: TYPE_A, 845 rdata: vec![192, 168, 1, 1], 846 }, 847 DnsRecord { 848 rtype: TYPE_A, 849 rdata: vec![10, 0, 0, 1], 850 }, 851 ]; 852 let addrs = extract_addrs(&records, TYPE_A); 853 assert_eq!(addrs.len(), 2); 854 assert_eq!(addrs[0], IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))); 855 assert_eq!(addrs[1], IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))); 856 } 857 858 #[test] 859 fn extract_v6_addrs() { 860 let records = vec![DnsRecord { 861 rtype: TYPE_AAAA, 862 rdata: vec![0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], 863 }]; 864 let addrs = extract_addrs(&records, TYPE_AAAA); 865 assert_eq!(addrs.len(), 1); 866 assert!(addrs[0].is_ipv6()); 867 } 868 869 #[test] 870 fn extract_ignores_wrong_type() { 871 let records = vec![DnsRecord { 872 rtype: TYPE_AAAA, 873 rdata: vec![0; 16], 874 }]; 875 // Asking for A records should skip AAAA 876 let addrs = extract_addrs(&records, TYPE_A); 877 assert!(addrs.is_empty()); 878 } 879 880 #[test] 881 fn extract_ignores_bad_length() { 882 let records = vec![DnsRecord { 883 rtype: TYPE_A, 884 rdata: vec![1, 2, 3], // Too short for A record (needs 4 bytes) 885 }]; 886 let addrs = extract_addrs(&records, TYPE_A); 887 assert!(addrs.is_empty()); 888 } 889 890 // -- Error type tests -- 891 892 #[test] 893 fn dns_error_display() { 894 assert_eq!( 895 DnsError::NoNameservers.to_string(), 896 "no DNS nameservers configured" 897 ); 898 assert_eq!(DnsError::Timeout.to_string(), "DNS query timed out"); 899 assert_eq!( 900 DnsError::MalformedResponse("bad".into()).to_string(), 901 "malformed DNS response: bad" 902 ); 903 assert_eq!( 904 DnsError::ServerError(2).to_string(), 905 "DNS server error (RCODE 2)" 906 ); 907 } 908 909 #[test] 910 fn dns_error_from_io_timeout() { 911 let err = DnsError::from(io::Error::new(io::ErrorKind::TimedOut, "timed out")); 912 assert!(matches!(err, DnsError::Timeout)); 913 } 914 915 #[test] 916 fn dns_error_from_io_would_block() { 917 let err = DnsError::from(io::Error::new(io::ErrorKind::WouldBlock, "would block")); 918 assert!(matches!(err, DnsError::Timeout)); 919 } 920 921 #[test] 922 fn dns_error_from_io_other() { 923 let err = DnsError::from(io::Error::new(io::ErrorKind::BrokenPipe, "broken")); 924 assert!(matches!(err, DnsError::Io(_))); 925 } 926 927 // -- DnsResolver tests -- 928 929 #[test] 930 fn resolver_with_no_nameservers() { 931 let resolver = DnsResolver::with_nameservers(vec![]); 932 let result = resolver.resolve("example.com"); 933 assert!(matches!(result, Err(DnsError::NoNameservers))); 934 } 935 936 #[test] 937 fn resolver_passthrough_ip_v4() { 938 let resolver = DnsResolver::with_nameservers(vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]); 939 let addrs = resolver.resolve("127.0.0.1").unwrap(); 940 assert_eq!(addrs, vec![IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))]); 941 } 942 943 #[test] 944 fn resolver_passthrough_ip_v6() { 945 let resolver = DnsResolver::with_nameservers(vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]); 946 let addrs = resolver.resolve("::1").unwrap(); 947 assert_eq!(addrs, vec![IpAddr::V6(Ipv6Addr::LOCALHOST)]); 948 } 949 950 #[test] 951 fn resolver_system_has_nameservers() { 952 let resolver = DnsResolver::system(); 953 // Should always have at least the fallback nameservers 954 assert!(!resolver.nameservers.is_empty()); 955 } 956 957 #[test] 958 fn resolver_set_timeout() { 959 let mut resolver = DnsResolver::system(); 960 resolver.set_timeout(Duration::from_secs(1)); 961 assert_eq!(resolver.timeout, Duration::from_secs(1)); 962 } 963 964 // -- Integration test: real DNS resolution -- 965 966 #[test] 967 fn resolve_real_hostname() { 968 // This test requires network access; skip in offline CI 969 let resolver = DnsResolver::system(); 970 match resolver.resolve("dns.google") { 971 Ok(addrs) => { 972 assert!(!addrs.is_empty(), "expected at least one address"); 973 // dns.google should resolve to 8.8.8.8 and/or 8.8.4.4 974 let has_known = addrs.iter().any(|a| { 975 *a == IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)) 976 || *a == IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4)) 977 }); 978 assert!( 979 has_known, 980 "expected dns.google to resolve to 8.8.8.8 or 8.8.4.4, got {addrs:?}" 981 ); 982 } 983 Err(DnsError::Timeout) => { 984 // Network may not be available in CI 985 } 986 Err(e) => panic!("unexpected error: {e}"), 987 } 988 } 989}