//! DNS stub resolver (RFC 1035). //! //! Resolves hostnames to IP addresses by sending UDP queries to system nameservers //! parsed from `/etc/resolv.conf`. use std::io; use std::net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr, UdpSocket}; use std::time::Duration; // --------------------------------------------------------------------------- // Constants // --------------------------------------------------------------------------- /// Default DNS port. const DNS_PORT: u16 = 53; /// Default query timeout. const DEFAULT_TIMEOUT: Duration = Duration::from_secs(3); /// Maximum UDP DNS message size. const MAX_UDP_SIZE: usize = 512; /// DNS record types. const TYPE_A: u16 = 1; const TYPE_AAAA: u16 = 28; /// DNS class: Internet. const CLASS_IN: u16 = 1; /// DNS header flags. const FLAG_RD: u16 = 1 << 8; // Recursion Desired const FLAG_QR: u16 = 1 << 15; // Query/Response /// DNS compression pointer mask. const COMPRESSION_MASK: u8 = 0xC0; // --------------------------------------------------------------------------- // Error types // --------------------------------------------------------------------------- /// DNS resolution errors. #[derive(Debug)] pub enum DnsError { /// Failed to parse `/etc/resolv.conf` or no nameservers found. NoNameservers, /// The query timed out on all nameservers. Timeout, /// The response was malformed. MalformedResponse(String), /// The server returned a non-zero RCODE. ServerError(u8), /// An I/O error occurred. Io(io::Error), } impl std::fmt::Display for DnsError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::NoNameservers => write!(f, "no DNS nameservers configured"), Self::Timeout => write!(f, "DNS query timed out"), Self::MalformedResponse(msg) => write!(f, "malformed DNS response: {msg}"), Self::ServerError(code) => write!(f, "DNS server error (RCODE {code})"), Self::Io(e) => write!(f, "DNS I/O error: {e}"), } } } impl From for DnsError { fn from(err: io::Error) -> Self { if err.kind() == io::ErrorKind::TimedOut || err.kind() == io::ErrorKind::WouldBlock { DnsError::Timeout } else { DnsError::Io(err) } } } pub type Result = std::result::Result; // --------------------------------------------------------------------------- // DNS message building // --------------------------------------------------------------------------- /// Build a DNS query message for the given hostname and record type. fn build_query(id: u16, hostname: &str, qtype: u16) -> std::result::Result, DnsError> { let mut buf = Vec::with_capacity(64); // Header (12 bytes) buf.extend_from_slice(&id.to_be_bytes()); buf.extend_from_slice(&FLAG_RD.to_be_bytes()); // flags: RD=1 buf.extend_from_slice(&1u16.to_be_bytes()); // QDCOUNT = 1 buf.extend_from_slice(&0u16.to_be_bytes()); // ANCOUNT = 0 buf.extend_from_slice(&0u16.to_be_bytes()); // NSCOUNT = 0 buf.extend_from_slice(&0u16.to_be_bytes()); // ARCOUNT = 0 // Question section: QNAME encode_name(&mut buf, hostname)?; // QTYPE and QCLASS buf.extend_from_slice(&qtype.to_be_bytes()); buf.extend_from_slice(&CLASS_IN.to_be_bytes()); Ok(buf) } /// Encode a domain name in DNS wire format (length-prefixed labels). fn encode_name(buf: &mut Vec, name: &str) -> std::result::Result<(), DnsError> { if name.is_empty() { buf.push(0); return Ok(()); } for label in name.split('.') { if label.is_empty() { continue; } if label.len() > 63 { return Err(DnsError::MalformedResponse(format!( "label too long: {} bytes", label.len() ))); } buf.push(label.len() as u8); buf.extend_from_slice(label.as_bytes()); } buf.push(0); // root label Ok(()) } // --------------------------------------------------------------------------- // DNS message parsing // --------------------------------------------------------------------------- /// A parsed DNS response. struct DnsResponse { id: u16, rcode: u8, answers: Vec, } /// A parsed DNS resource record. struct DnsRecord { rtype: u16, rdata: Vec, } /// Parser state for reading a DNS message. struct DnsParser<'a> { data: &'a [u8], pos: usize, } impl<'a> DnsParser<'a> { fn new(data: &'a [u8]) -> Self { Self { data, pos: 0 } } fn remaining(&self) -> usize { self.data.len().saturating_sub(self.pos) } fn read_u8(&mut self) -> std::result::Result { if self.pos >= self.data.len() { return Err(DnsError::MalformedResponse("unexpected end of data".into())); } let val = self.data[self.pos]; self.pos += 1; Ok(val) } fn read_u16(&mut self) -> std::result::Result { if self.pos + 2 > self.data.len() { return Err(DnsError::MalformedResponse("unexpected end of data".into())); } let val = u16::from_be_bytes([self.data[self.pos], self.data[self.pos + 1]]); self.pos += 2; Ok(val) } fn read_u32(&mut self) -> std::result::Result { if self.pos + 4 > self.data.len() { return Err(DnsError::MalformedResponse("unexpected end of data".into())); } let val = u32::from_be_bytes([ self.data[self.pos], self.data[self.pos + 1], self.data[self.pos + 2], self.data[self.pos + 3], ]); self.pos += 4; Ok(val) } fn read_bytes(&mut self, n: usize) -> std::result::Result<&'a [u8], DnsError> { if self.pos + n > self.data.len() { return Err(DnsError::MalformedResponse("unexpected end of data".into())); } let slice = &self.data[self.pos..self.pos + n]; self.pos += n; Ok(slice) } /// Skip a DNS name (handling compression pointers). fn skip_name(&mut self) -> std::result::Result<(), DnsError> { let mut jumps = 0; loop { if jumps > 64 { return Err(DnsError::MalformedResponse( "too many compression jumps".into(), )); } let len = self.read_u8()?; if len == 0 { break; } if len & COMPRESSION_MASK == COMPRESSION_MASK { // Compression pointer: read the second byte and stop let _offset_low = self.read_u8()?; break; } // Regular label: skip the label bytes if self.pos + len as usize > self.data.len() { return Err(DnsError::MalformedResponse( "label extends past data".into(), )); } self.pos += len as usize; jumps += 1; } Ok(()) } /// Parse a complete DNS response. fn parse_response(&mut self) -> std::result::Result { if self.remaining() < 12 { return Err(DnsError::MalformedResponse("response too short".into())); } let id = self.read_u16()?; let flags = self.read_u16()?; let qdcount = self.read_u16()?; let ancount = self.read_u16()?; let _nscount = self.read_u16()?; let _arcount = self.read_u16()?; // Verify QR bit is set (this is a response) if flags & FLAG_QR == 0 { return Err(DnsError::MalformedResponse("QR bit not set".into())); } let rcode = (flags & 0x000F) as u8; // Skip question section for _ in 0..qdcount { self.skip_name()?; let _qtype = self.read_u16()?; let _qclass = self.read_u16()?; } // Parse answer section let mut answers = Vec::new(); for _ in 0..ancount { self.skip_name()?; // NAME let rtype = self.read_u16()?; // TYPE let _rclass = self.read_u16()?; // CLASS let _ttl = self.read_u32()?; // TTL let rdlength = self.read_u16()? as usize; // RDLENGTH let rdata = self.read_bytes(rdlength)?; // RDATA answers.push(DnsRecord { rtype, rdata: rdata.to_vec(), }); } Ok(DnsResponse { id, rcode, answers }) } } // --------------------------------------------------------------------------- // resolv.conf parsing // --------------------------------------------------------------------------- /// Parse nameserver addresses from `/etc/resolv.conf`. pub fn parse_resolv_conf(content: &str) -> Vec { let mut nameservers = Vec::new(); for line in content.lines() { let line = line.trim(); if line.starts_with('#') || line.is_empty() { continue; } if let Some(rest) = line.strip_prefix("nameserver") { let addr_str = rest.trim(); if let Ok(addr) = addr_str.parse::() { nameservers.push(addr); } } } nameservers } /// Read nameservers from the system `/etc/resolv.conf`. fn system_nameservers() -> Vec { match std::fs::read_to_string("/etc/resolv.conf") { Ok(content) => { let servers = parse_resolv_conf(&content); if servers.is_empty() { // Fallback to common defaults vec![ IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4)), ] } else { servers } } Err(_) => { // Fallback to Google DNS vec![ IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)), IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4)), ] } } } // --------------------------------------------------------------------------- // Public API // --------------------------------------------------------------------------- /// Configuration for DNS resolution. pub struct DnsResolver { nameservers: Vec, timeout: Duration, } impl DnsResolver { /// Create a resolver using system nameservers from `/etc/resolv.conf`. pub fn system() -> Self { Self { nameservers: system_nameservers(), timeout: DEFAULT_TIMEOUT, } } /// Create a resolver with explicit nameserver addresses. pub fn with_nameservers(nameservers: Vec) -> Self { Self { nameservers, timeout: DEFAULT_TIMEOUT, } } /// Set the query timeout. pub fn set_timeout(&mut self, timeout: Duration) { self.timeout = timeout; } /// Resolve a hostname to a list of IP addresses. /// /// Sends A and AAAA queries and collects all results. pub fn resolve(&self, hostname: &str) -> Result> { if self.nameservers.is_empty() { return Err(DnsError::NoNameservers); } // Try to parse as an IP address directly if let Ok(addr) = hostname.parse::() { return Ok(vec![addr]); } let mut addrs = Vec::new(); // Send A query if let Ok(mut v4) = self.query(hostname, TYPE_A) { addrs.append(&mut v4); } // Send AAAA query if let Ok(mut v6) = self.query(hostname, TYPE_AAAA) { addrs.append(&mut v6); } if addrs.is_empty() { // Retry once more, this time propagating the error addrs = self.query(hostname, TYPE_A)?; } Ok(addrs) } /// Send a single DNS query for a specific record type. fn query(&self, hostname: &str, qtype: u16) -> Result> { let id = generate_id(); let query_msg = build_query(id, hostname, qtype)?; let socket = UdpSocket::bind("0.0.0.0:0").map_err(DnsError::from)?; socket .set_read_timeout(Some(self.timeout)) .map_err(DnsError::from)?; let mut last_err = None; for &ns in &self.nameservers { let dest = SocketAddr::new(ns, DNS_PORT); if let Err(e) = socket.send_to(&query_msg, dest) { last_err = Some(DnsError::from(e)); continue; } let mut resp_buf = [0u8; MAX_UDP_SIZE]; match socket.recv_from(&mut resp_buf) { Ok((len, _)) => { let mut parser = DnsParser::new(&resp_buf[..len]); match parser.parse_response() { Ok(response) => { if response.id != id { last_err = Some(DnsError::MalformedResponse("ID mismatch".into())); continue; } if response.rcode != 0 { // RCODE 3 = NXDOMAIN: name doesn't exist, no point retrying if response.rcode == 3 { return Ok(Vec::new()); } last_err = Some(DnsError::ServerError(response.rcode)); continue; } return Ok(extract_addrs(&response.answers, qtype)); } Err(e) => { last_err = Some(e); continue; } } } Err(e) => { last_err = Some(DnsError::from(e)); continue; } } } Err(last_err.unwrap_or(DnsError::Timeout)) } } /// Extract IP addresses from DNS answer records. fn extract_addrs(answers: &[DnsRecord], qtype: u16) -> Vec { let mut addrs = Vec::new(); for record in answers { if record.rtype == TYPE_A && qtype == TYPE_A && record.rdata.len() == 4 { addrs.push(IpAddr::V4(Ipv4Addr::new( record.rdata[0], record.rdata[1], record.rdata[2], record.rdata[3], ))); } else if record.rtype == TYPE_AAAA && qtype == TYPE_AAAA && record.rdata.len() == 16 { let mut octets = [0u8; 16]; octets.copy_from_slice(&record.rdata); addrs.push(IpAddr::V6(Ipv6Addr::from(octets))); } } addrs } /// Generate a pseudo-random 16-bit transaction ID. fn generate_id() -> u16 { // Use a simple approach based on time for uniqueness let duration = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default(); // Mix nanoseconds for some entropy let nanos = duration.subsec_nanos(); let secs = duration.as_secs(); ((nanos ^ (secs as u32)) & 0xFFFF) as u16 } /// Convenience function: resolve a hostname using system nameservers. pub fn resolve(hostname: &str) -> Result> { DnsResolver::system().resolve(hostname) } // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- #[cfg(test)] mod tests { use super::*; // -- Name encoding tests -- #[test] fn encode_simple_name() { let mut buf = Vec::new(); encode_name(&mut buf, "example.com").unwrap(); assert_eq!( buf, vec![7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0] ); } #[test] fn encode_single_label() { let mut buf = Vec::new(); encode_name(&mut buf, "localhost").unwrap(); assert_eq!( buf, vec![9, b'l', b'o', b'c', b'a', b'l', b'h', b'o', b's', b't', 0] ); } #[test] fn encode_empty_name() { let mut buf = Vec::new(); encode_name(&mut buf, "").unwrap(); assert_eq!(buf, vec![0]); } #[test] fn encode_trailing_dot() { let mut buf = Vec::new(); encode_name(&mut buf, "example.com.").unwrap(); // Should handle trailing dot gracefully (skip empty label) assert_eq!( buf, vec![7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0] ); } #[test] fn encode_label_too_long() { let long_label = "a".repeat(64); let mut buf = Vec::new(); assert!(encode_name(&mut buf, &long_label).is_err()); } // -- Query building tests -- #[test] fn build_query_a_record() { let query = build_query(0x1234, "example.com", TYPE_A).unwrap(); // Header: 12 bytes assert_eq!(query[0..2], [0x12, 0x34]); // ID assert_eq!(query[2..4], [0x01, 0x00]); // Flags: RD=1 assert_eq!(query[4..6], [0x00, 0x01]); // QDCOUNT=1 assert_eq!(query[6..8], [0x00, 0x00]); // ANCOUNT=0 // QTYPE at end - 2 (A=1) let qtype_offset = query.len() - 4; assert_eq!(query[qtype_offset..qtype_offset + 2], [0x00, 0x01]); // TYPE A assert_eq!(query[qtype_offset + 2..qtype_offset + 4], [0x00, 0x01]); // CLASS IN } #[test] fn build_query_aaaa_record() { let query = build_query(0xABCD, "example.com", TYPE_AAAA).unwrap(); assert_eq!(query[0..2], [0xAB, 0xCD]); // ID let qtype_offset = query.len() - 4; assert_eq!(query[qtype_offset..qtype_offset + 2], [0x00, 0x1C]); // TYPE AAAA (28) } // -- Response parsing tests -- fn make_response(id: u16, rcode: u8, answers: &[(u16, &[u8])]) -> Vec { let mut buf = Vec::new(); // Header buf.extend_from_slice(&id.to_be_bytes()); let flags: u16 = FLAG_QR | (rcode as u16); // QR=1 buf.extend_from_slice(&flags.to_be_bytes()); buf.extend_from_slice(&1u16.to_be_bytes()); // QDCOUNT=1 buf.extend_from_slice(&(answers.len() as u16).to_be_bytes()); // ANCOUNT buf.extend_from_slice(&0u16.to_be_bytes()); // NSCOUNT buf.extend_from_slice(&0u16.to_be_bytes()); // ARCOUNT // Question section buf.extend_from_slice(&[ 7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0, ]); buf.extend_from_slice(&TYPE_A.to_be_bytes()); // QTYPE buf.extend_from_slice(&CLASS_IN.to_be_bytes()); // QCLASS // Answer section for &(rtype, rdata) in answers { // NAME: compression pointer to offset 12 (question name) buf.extend_from_slice(&[0xC0, 0x0C]); buf.extend_from_slice(&rtype.to_be_bytes()); // TYPE buf.extend_from_slice(&CLASS_IN.to_be_bytes()); // CLASS buf.extend_from_slice(&300u32.to_be_bytes()); // TTL buf.extend_from_slice(&(rdata.len() as u16).to_be_bytes()); // RDLENGTH buf.extend_from_slice(rdata); // RDATA } buf } #[test] fn parse_a_response() { let resp_data = make_response(0x1234, 0, &[(TYPE_A, &[93, 184, 216, 34])]); let mut parser = DnsParser::new(&resp_data); let response = parser.parse_response().unwrap(); assert_eq!(response.id, 0x1234); assert_eq!(response.rcode, 0); assert_eq!(response.answers.len(), 1); assert_eq!(response.answers[0].rtype, TYPE_A); assert_eq!(response.answers[0].rdata, vec![93, 184, 216, 34]); } #[test] fn parse_aaaa_response() { let ipv6_data: [u8; 16] = [ 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, ]; let resp_data = make_response(0x5678, 0, &[(TYPE_AAAA, &ipv6_data)]); let mut parser = DnsParser::new(&resp_data); let response = parser.parse_response().unwrap(); assert_eq!(response.id, 0x5678); assert_eq!(response.answers.len(), 1); assert_eq!(response.answers[0].rtype, TYPE_AAAA); assert_eq!(response.answers[0].rdata, ipv6_data.to_vec()); } #[test] fn parse_multiple_a_records() { let resp_data = make_response( 0xAAAA, 0, &[ (TYPE_A, &[1, 2, 3, 4]), (TYPE_A, &[5, 6, 7, 8]), (TYPE_A, &[9, 10, 11, 12]), ], ); let mut parser = DnsParser::new(&resp_data); let response = parser.parse_response().unwrap(); assert_eq!(response.answers.len(), 3); let addrs = extract_addrs(&response.answers, TYPE_A); assert_eq!(addrs.len(), 3); assert_eq!(addrs[0], IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4))); assert_eq!(addrs[1], IpAddr::V4(Ipv4Addr::new(5, 6, 7, 8))); assert_eq!(addrs[2], IpAddr::V4(Ipv4Addr::new(9, 10, 11, 12))); } #[test] fn parse_server_error_rcode() { let resp_data = make_response(0x1111, 2, &[]); // RCODE 2 = SERVFAIL let mut parser = DnsParser::new(&resp_data); let response = parser.parse_response().unwrap(); assert_eq!(response.rcode, 2); } #[test] fn parse_nxdomain() { let resp_data = make_response(0x2222, 3, &[]); // RCODE 3 = NXDOMAIN let mut parser = DnsParser::new(&resp_data); let response = parser.parse_response().unwrap(); assert_eq!(response.rcode, 3); } #[test] fn parse_response_too_short() { let short = vec![0u8; 6]; let mut parser = DnsParser::new(&short); assert!(parser.parse_response().is_err()); } #[test] fn parse_response_not_response() { // QR bit not set (flags = 0x0100, which is a query with RD) let mut buf = vec![0x00, 0x01]; // ID buf.extend_from_slice(&[0x01, 0x00]); // flags: RD=1, QR=0 buf.extend_from_slice(&[0, 0, 0, 0, 0, 0, 0, 0]); // counts let mut parser = DnsParser::new(&buf); assert!(parser.parse_response().is_err()); } // -- Name compression tests -- #[test] fn parse_compression_pointer() { // Build a response where answer name uses a compression pointer let resp_data = make_response(0x3333, 0, &[(TYPE_A, &[10, 0, 0, 1])]); let mut parser = DnsParser::new(&resp_data); let response = parser.parse_response().unwrap(); assert_eq!(response.answers.len(), 1); assert_eq!(response.answers[0].rdata, vec![10, 0, 0, 1]); } /// Test-only: decode a DNS name from wire format at a given offset, /// following compression pointers. fn decode_name(data: &[u8], offset: usize) -> std::result::Result { fn decode_at( data: &[u8], mut offset: usize, name: &mut String, depth: usize, ) -> std::result::Result<(), DnsError> { if depth > 64 { return Err(DnsError::MalformedResponse( "too many compression jumps".into(), )); } loop { if offset >= data.len() { return Err(DnsError::MalformedResponse("unexpected end of data".into())); } let len = data[offset]; offset += 1; if len == 0 { break; } if len & COMPRESSION_MASK == COMPRESSION_MASK { if offset >= data.len() { return Err(DnsError::MalformedResponse("truncated pointer".into())); } let ptr = (((len & !COMPRESSION_MASK) as u16) << 8 | data[offset] as u16) as usize; return decode_at(data, ptr, name, depth + 1); } let label_end = offset + len as usize; if label_end > data.len() { return Err(DnsError::MalformedResponse( "label extends past data".into(), )); } if !name.is_empty() { name.push('.'); } for &b in &data[offset..label_end] { name.push(b as char); } offset = label_end; } Ok(()) } let mut name = String::new(); decode_at(data, offset, &mut name, 0)?; Ok(name) } #[test] fn decode_name_with_compression() { let mut data = Vec::new(); // Name at offset 0: "test.com" data.extend_from_slice(&[4, b't', b'e', b's', b't', 3, b'c', b'o', b'm', 0]); // Pointer at offset 10: points to offset 0 data.extend_from_slice(&[0xC0, 0x00]); let name = decode_name(&data, 10).unwrap(); assert_eq!(name, "test.com"); } #[test] fn decode_name_no_compression() { let data = vec![ 3, b'w', b'w', b'w', 7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0, ]; let name = decode_name(&data, 0).unwrap(); assert_eq!(name, "www.example.com"); } #[test] fn decode_name_partial_compression() { // "www" label followed by pointer to "example.com" at offset 0 let mut data = Vec::new(); // Offset 0: "example.com" data.extend_from_slice(&[ 7, b'e', b'x', b'a', b'm', b'p', b'l', b'e', 3, b'c', b'o', b'm', 0, ]); // Offset 13: "www" + pointer to offset 0 data.extend_from_slice(&[3, b'w', b'w', b'w', 0xC0, 0x00]); let name = decode_name(&data, 13).unwrap(); assert_eq!(name, "www.example.com"); } // -- resolv.conf parsing tests -- #[test] fn parse_resolv_conf_basic() { let content = "nameserver 8.8.8.8\nnameserver 8.8.4.4\n"; let servers = parse_resolv_conf(content); assert_eq!(servers.len(), 2); assert_eq!(servers[0], IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); assert_eq!(servers[1], IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4))); } #[test] fn parse_resolv_conf_with_comments() { let content = "\ # DNS config nameserver 1.1.1.1 # backup nameserver 1.0.0.1 search example.com "; let servers = parse_resolv_conf(content); assert_eq!(servers.len(), 2); assert_eq!(servers[0], IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))); assert_eq!(servers[1], IpAddr::V4(Ipv4Addr::new(1, 0, 0, 1))); } #[test] fn parse_resolv_conf_ipv6() { let content = "nameserver 2001:4860:4860::8888\nnameserver 8.8.8.8\n"; let servers = parse_resolv_conf(content); assert_eq!(servers.len(), 2); assert!(servers[0].is_ipv6()); assert!(servers[1].is_ipv4()); } #[test] fn parse_resolv_conf_empty() { let servers = parse_resolv_conf(""); assert!(servers.is_empty()); } #[test] fn parse_resolv_conf_no_nameservers() { let content = "search example.com\noptions ndots:1\n"; let servers = parse_resolv_conf(content); assert!(servers.is_empty()); } #[test] fn parse_resolv_conf_invalid_addr() { let content = "nameserver not.an.ip.addr\nnameserver 8.8.8.8\n"; let servers = parse_resolv_conf(content); assert_eq!(servers.len(), 1); assert_eq!(servers[0], IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8))); } // -- extract_addrs tests -- #[test] fn extract_v4_addrs() { let records = vec![ DnsRecord { rtype: TYPE_A, rdata: vec![192, 168, 1, 1], }, DnsRecord { rtype: TYPE_A, rdata: vec![10, 0, 0, 1], }, ]; let addrs = extract_addrs(&records, TYPE_A); assert_eq!(addrs.len(), 2); assert_eq!(addrs[0], IpAddr::V4(Ipv4Addr::new(192, 168, 1, 1))); assert_eq!(addrs[1], IpAddr::V4(Ipv4Addr::new(10, 0, 0, 1))); } #[test] fn extract_v6_addrs() { let records = vec![DnsRecord { rtype: TYPE_AAAA, rdata: vec![0x20, 0x01, 0x0d, 0xb8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], }]; let addrs = extract_addrs(&records, TYPE_AAAA); assert_eq!(addrs.len(), 1); assert!(addrs[0].is_ipv6()); } #[test] fn extract_ignores_wrong_type() { let records = vec![DnsRecord { rtype: TYPE_AAAA, rdata: vec![0; 16], }]; // Asking for A records should skip AAAA let addrs = extract_addrs(&records, TYPE_A); assert!(addrs.is_empty()); } #[test] fn extract_ignores_bad_length() { let records = vec![DnsRecord { rtype: TYPE_A, rdata: vec![1, 2, 3], // Too short for A record (needs 4 bytes) }]; let addrs = extract_addrs(&records, TYPE_A); assert!(addrs.is_empty()); } // -- Error type tests -- #[test] fn dns_error_display() { assert_eq!( DnsError::NoNameservers.to_string(), "no DNS nameservers configured" ); assert_eq!(DnsError::Timeout.to_string(), "DNS query timed out"); assert_eq!( DnsError::MalformedResponse("bad".into()).to_string(), "malformed DNS response: bad" ); assert_eq!( DnsError::ServerError(2).to_string(), "DNS server error (RCODE 2)" ); } #[test] fn dns_error_from_io_timeout() { let err = DnsError::from(io::Error::new(io::ErrorKind::TimedOut, "timed out")); assert!(matches!(err, DnsError::Timeout)); } #[test] fn dns_error_from_io_would_block() { let err = DnsError::from(io::Error::new(io::ErrorKind::WouldBlock, "would block")); assert!(matches!(err, DnsError::Timeout)); } #[test] fn dns_error_from_io_other() { let err = DnsError::from(io::Error::new(io::ErrorKind::BrokenPipe, "broken")); assert!(matches!(err, DnsError::Io(_))); } // -- DnsResolver tests -- #[test] fn resolver_with_no_nameservers() { let resolver = DnsResolver::with_nameservers(vec![]); let result = resolver.resolve("example.com"); assert!(matches!(result, Err(DnsError::NoNameservers))); } #[test] fn resolver_passthrough_ip_v4() { let resolver = DnsResolver::with_nameservers(vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]); let addrs = resolver.resolve("127.0.0.1").unwrap(); assert_eq!(addrs, vec![IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1))]); } #[test] fn resolver_passthrough_ip_v6() { let resolver = DnsResolver::with_nameservers(vec![IpAddr::V4(Ipv4Addr::LOCALHOST)]); let addrs = resolver.resolve("::1").unwrap(); assert_eq!(addrs, vec![IpAddr::V6(Ipv6Addr::LOCALHOST)]); } #[test] fn resolver_system_has_nameservers() { let resolver = DnsResolver::system(); // Should always have at least the fallback nameservers assert!(!resolver.nameservers.is_empty()); } #[test] fn resolver_set_timeout() { let mut resolver = DnsResolver::system(); resolver.set_timeout(Duration::from_secs(1)); assert_eq!(resolver.timeout, Duration::from_secs(1)); } // -- Integration test: real DNS resolution -- #[test] fn resolve_real_hostname() { // This test requires network access; skip in offline CI let resolver = DnsResolver::system(); match resolver.resolve("dns.google") { Ok(addrs) => { assert!(!addrs.is_empty(), "expected at least one address"); // dns.google should resolve to 8.8.8.8 and/or 8.8.4.4 let has_known = addrs.iter().any(|a| { *a == IpAddr::V4(Ipv4Addr::new(8, 8, 8, 8)) || *a == IpAddr::V4(Ipv4Addr::new(8, 8, 4, 4)) }); assert!( has_known, "expected dns.google to resolve to 8.8.8.8 or 8.8.4.4, got {addrs:?}" ); } Err(DnsError::Timeout) => { // Network may not be available in CI } Err(e) => panic!("unexpected error: {e}"), } } }