feat: WIP MDNS crate #6

merged
opened by sachy.dev targeting main from wip-mdns

Adds a mdns resolver/state-machine crate for providing MDNS-SD functionality for an embedded device.

+21
.tangled/workflows/miri.yml
··· 1 + when: 2 + - event: ["push", "pull_request"] 3 + branch: main 4 + 5 + engine: nixery 6 + 7 + dependencies: 8 + nixpkgs: 9 + - clang 10 + - rustup 11 + 12 + steps: 13 + - name: Install Nightly 14 + command: | 15 + rustup toolchain install nightly --component miri 16 + rustup override set nightly 17 + cargo miri setup 18 + - name: Miri Test 19 + command: cargo miri test --locked -p sachy-mdns 20 + environment: 21 + RUSTFLAGS: -Zrandomize-layout
+19 -3
Cargo.lock
··· 420 420 checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" 421 421 dependencies = [ 422 422 "libc", 423 - "windows-sys 0.52.0", 423 + "windows-sys 0.61.2", 424 424 ] 425 425 426 426 [[package]] ··· 1065 1065 "errno", 1066 1066 "libc", 1067 1067 "linux-raw-sys 0.11.0", 1068 - "windows-sys 0.52.0", 1068 + "windows-sys 0.61.2", 1069 1069 ] 1070 1070 1071 1071 [[package]] ··· 1119 1119 name = "sachy-fnv" 1120 1120 version = "0.1.0" 1121 1121 1122 + [[package]] 1123 + name = "sachy-mdns" 1124 + version = "0.1.0" 1125 + dependencies = [ 1126 + "defmt 1.0.1", 1127 + "embassy-time", 1128 + "sachy-fmt", 1129 + "winnow", 1130 + ] 1131 + 1122 1132 [[package]] 1123 1133 name = "sachy-shtc3" 1124 1134 version = "0.1.0" ··· 1292 1302 "getrandom", 1293 1303 "once_cell", 1294 1304 "rustix 1.1.2", 1295 - "windows-sys 0.52.0", 1305 + "windows-sys 0.61.2", 1296 1306 ] 1297 1307 1298 1308 [[package]] ··· 1483 1493 source = "registry+https://github.com/rust-lang/crates.io-index" 1484 1494 checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" 1485 1495 1496 + [[package]] 1497 + name = "winnow" 1498 + version = "0.7.14" 1499 + source = "registry+https://github.com/rust-lang/crates.io-index" 1500 + checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" 1501 + 1486 1502 [[package]] 1487 1503 name = "wit-bindgen" 1488 1504 version = "0.46.0"
+1
Cargo.toml
··· 6 6 "sachy-esphome", 7 7 "sachy-fmt", 8 8 "sachy-fnv", 9 + "sachy-mdns", 9 10 "sachy-shtc3", 10 11 "sachy-sntp", 11 12 ]
+22
sachy-mdns/Cargo.toml
··· 1 + [package] 2 + name = "sachy-mdns" 3 + authors.workspace = true 4 + edition.workspace = true 5 + repository.workspace = true 6 + license.workspace = true 7 + version.workspace = true 8 + rust-version.workspace = true 9 + 10 + [dependencies] 11 + defmt = { workspace = true, optional = true, features = ["alloc"] } 12 + embassy-time = { workspace = true } 13 + sachy-fmt = { path = "../sachy-fmt" } 14 + winnow = { version = "0.7.12", default-features = false } 15 + 16 + [features] 17 + default = [] 18 + std = [] 19 + defmt = ["dep:defmt"] 20 + 21 + [dev-dependencies] 22 + winnow = { version = "0.7.12", default-features = false, features = ["alloc"] }
+6
sachy-mdns/src/dns.rs
··· 1 + pub(crate) mod flags; 2 + pub(crate) mod label; 3 + pub(crate) mod query; 4 + pub(crate) mod records; 5 + pub(crate) mod reqres; 6 + pub mod traits;
+231
sachy-mdns/src/dns/flags.rs
··· 1 + #![allow(dead_code)] 2 + 3 + use core::{convert::Infallible, fmt}; 4 + use winnow::{ModalResult, Parser, binary::be_u16}; 5 + 6 + use crate::{ 7 + dns::traits::{DnsParse, DnsSerialize}, 8 + encoder::Encoder, 9 + }; 10 + 11 + #[derive(Default, Clone, Copy, PartialEq, Eq)] 12 + pub struct Flags(pub u16); 13 + 14 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 15 + #[repr(u8)] 16 + pub enum Opcode { 17 + Query = 0, 18 + IQuery = 1, 19 + Status = 2, 20 + Reserved = 3, 21 + Notify = 4, 22 + Update = 5, 23 + // Other values are reserved 24 + } 25 + 26 + impl Opcode { 27 + const fn cast(value: u8) -> Self { 28 + match value { 29 + 0 => Opcode::Query, 30 + 1 => Opcode::IQuery, 31 + 2 => Opcode::Status, 32 + 4 => Opcode::Notify, 33 + 5 => Opcode::Update, 34 + _ => Opcode::Reserved, 35 + } 36 + } 37 + } 38 + 39 + impl From<u8> for Opcode { 40 + fn from(value: u8) -> Self { 41 + Self::cast(value) 42 + } 43 + } 44 + 45 + impl From<Opcode> for u8 { 46 + fn from(opcode: Opcode) -> Self { 47 + opcode as u8 48 + } 49 + } 50 + 51 + impl Flags { 52 + const fn new() -> Self { 53 + Flags(0) 54 + } 55 + 56 + pub const fn standard_request() -> Self { 57 + let mut flags = Flags::new(); 58 + flags.set_query(true); 59 + flags.set_opcode(Opcode::Query); 60 + flags.set_recursion_desired(true); 61 + flags 62 + } 63 + 64 + pub const fn standard_response() -> Self { 65 + let mut flags = Flags::new(); 66 + flags.set_query(false); 67 + flags.set_opcode(Opcode::Query); 68 + flags.set_authoritative(true); 69 + flags.set_recursion_available(false); 70 + flags 71 + } 72 + 73 + // QR: Query/Response Flag 74 + pub const fn is_query(&self) -> bool { 75 + (self.0 & 0x8000) == 0 76 + } 77 + 78 + pub const fn set_query(&mut self, is_query: bool) { 79 + if is_query { 80 + self.0 &= !0x8000; 81 + } else { 82 + self.0 |= 0x8000; 83 + } 84 + } 85 + 86 + // Opcode (bits 1-4) 87 + pub const fn get_opcode(&self) -> Opcode { 88 + Opcode::cast(((self.0 >> 11) & 0x0F) as u8) 89 + } 90 + 91 + pub const fn set_opcode(&mut self, opcode: Opcode) { 92 + self.0 = (self.0 & !0x7800) | (((opcode as u8) as u16 & 0x0F) << 11); 93 + } 94 + 95 + // AA: Authoritative Answer 96 + pub const fn is_authoritative(&self) -> bool { 97 + (self.0 & 0x0400) != 0 98 + } 99 + 100 + pub const fn set_authoritative(&mut self, authoritative: bool) { 101 + if authoritative { 102 + self.0 |= 0x0400; 103 + } else { 104 + self.0 &= !0x0400; 105 + } 106 + } 107 + 108 + // TC: Truncated 109 + pub const fn is_truncated(&self) -> bool { 110 + (self.0 & 0x0200) != 0 111 + } 112 + 113 + pub const fn set_truncated(&mut self, truncated: bool) { 114 + if truncated { 115 + self.0 |= 0x0200; 116 + } else { 117 + self.0 &= !0x0200; 118 + } 119 + } 120 + 121 + // RD: Recursion Desired 122 + pub const fn is_recursion_desired(&self) -> bool { 123 + (self.0 & 0x0100) != 0 124 + } 125 + 126 + pub const fn set_recursion_desired(&mut self, recursion_desired: bool) { 127 + if recursion_desired { 128 + self.0 |= 0x0100; 129 + } else { 130 + self.0 &= !0x0100; 131 + } 132 + } 133 + 134 + // RA: Recursion Available 135 + pub const fn is_recursion_available(&self) -> bool { 136 + (self.0 & 0x0080) != 0 137 + } 138 + 139 + pub const fn set_recursion_available(&mut self, recursion_available: bool) { 140 + if recursion_available { 141 + self.0 |= 0x0080; 142 + } else { 143 + self.0 &= !0x0080; 144 + } 145 + } 146 + 147 + // Z: Reserved for future use (bits 9-11) 148 + pub const fn get_reserved(&self) -> u8 { 149 + ((self.0 >> 4) & 0x07) as u8 150 + } 151 + 152 + pub const fn set_reserved(&mut self, reserved: u8) { 153 + self.0 = (self.0 & !0x0070) | ((reserved as u16 & 0x07) << 4); 154 + } 155 + 156 + // RCODE: Response Code (bits 12-15) 157 + pub const fn get_rcode(&self) -> u8 { 158 + (self.0 & 0x000F) as u8 159 + } 160 + 161 + pub const fn set_rcode(&mut self, rcode: u8) { 162 + self.0 = (self.0 & !0x000F) | (rcode as u16 & 0x0F); 163 + } 164 + } 165 + 166 + impl<'a> DnsParse<'a> for Flags { 167 + fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> { 168 + be_u16.map(Flags).parse_next(input) 169 + } 170 + } 171 + 172 + impl<'a> DnsSerialize<'a> for Flags { 173 + type Error = Infallible; 174 + 175 + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 176 + encoder.write(&self.0.to_be_bytes()); 177 + Ok(()) 178 + } 179 + 180 + fn size(&self) -> usize { 181 + core::mem::size_of::<u16>() 182 + } 183 + } 184 + 185 + impl fmt::Debug for Flags { 186 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 187 + f.debug_struct("Flags") 188 + .field("query", &self.is_query()) 189 + .field("opcode", &self.get_opcode()) 190 + .field("authoritative", &self.is_authoritative()) 191 + .field("truncated", &self.is_truncated()) 192 + .field("recursion_desired", &self.is_recursion_desired()) 193 + .field("recursion_available", &self.is_recursion_available()) 194 + .field("reserved", &self.get_reserved()) 195 + .field("rcode", &self.get_rcode()) 196 + .finish() 197 + } 198 + } 199 + 200 + #[cfg(feature = "defmt")] 201 + impl defmt::Format for Flags { 202 + fn format(&self, fmt: defmt::Formatter) { 203 + defmt::write!( 204 + fmt, 205 + "Flags {{ query: {}, opcode: {:?}, authoritative: {}, truncated: {}, recursion_desired: {}, recursion_available: {}, reserved: {}, rcode: {} }}", 206 + self.is_query(), 207 + self.get_opcode(), 208 + self.is_authoritative(), 209 + self.is_truncated(), 210 + self.is_recursion_desired(), 211 + self.is_recursion_available(), 212 + self.get_reserved(), 213 + self.get_rcode() 214 + ); 215 + } 216 + } 217 + 218 + #[cfg(feature = "defmt")] 219 + impl defmt::Format for Opcode { 220 + fn format(&self, fmt: defmt::Formatter) { 221 + let opcode_str = match self { 222 + Opcode::Query => "Query", 223 + Opcode::IQuery => "IQuery", 224 + Opcode::Status => "Status", 225 + Opcode::Reserved => "Reserved", 226 + Opcode::Notify => "Notify", 227 + Opcode::Update => "Update", 228 + }; 229 + defmt::write!(fmt, "Opcode({=str})", opcode_str); 230 + } 231 + }
+511
sachy-mdns/src/dns/label.rs
··· 1 + use core::{fmt, str}; 2 + 3 + use winnow::{ 4 + ModalResult, Parser, 5 + binary::be_u8, 6 + error::{ContextError, ErrMode, FromExternalError}, 7 + stream::Offset, 8 + token::take, 9 + }; 10 + 11 + use crate::{ 12 + dns::traits::{DnsParse, DnsSerialize}, 13 + encoder::{DnsError, Encoder, MAX_STR_LEN, PTR_MASK}, 14 + }; 15 + 16 + #[derive(Clone, Copy)] 17 + pub struct Label<'a> { 18 + repr: LabelRepr<'a>, 19 + } 20 + 21 + impl<'a> From<&'a str> for Label<'a> { 22 + fn from(value: &'a str) -> Self { 23 + Self { 24 + repr: LabelRepr::Str(value), 25 + } 26 + } 27 + } 28 + 29 + impl<'a> DnsSerialize<'a> for Label<'a> { 30 + type Error = DnsError; 31 + 32 + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 33 + match self.repr { 34 + LabelRepr::Bytes { 35 + context, 36 + start, 37 + end, 38 + } => { 39 + encoder.write(&context[start..end]); 40 + 41 + Ok(()) 42 + } 43 + LabelRepr::Str(label) => encoder.write_label(label), 44 + } 45 + } 46 + 47 + fn size(&self) -> usize { 48 + match self.repr { 49 + LabelRepr::Bytes { 50 + context, 51 + start, 52 + end, 53 + } => core::mem::size_of_val(&context[start..end]), 54 + LabelRepr::Str(label) => core::mem::size_of_val(label) + 1, 55 + } 56 + } 57 + } 58 + 59 + impl<'a> DnsParse<'a> for Label<'a> { 60 + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> { 61 + let start = input.offset_from(&context); 62 + let mut end = start; 63 + 64 + loop { 65 + match LabelSegment::parse(input)? { 66 + LabelSegment::Empty => { 67 + end += 1; 68 + break; 69 + } 70 + LabelSegment::String(label) => { 71 + end += 1 + label.len(); 72 + } 73 + LabelSegment::Pointer(_) => { 74 + end += 2; 75 + break; 76 + } 77 + } 78 + } 79 + 80 + Ok(Self { 81 + repr: LabelRepr::Bytes { 82 + context, 83 + start, 84 + end, 85 + }, 86 + }) 87 + } 88 + } 89 + 90 + impl Label<'_> { 91 + pub fn segments(&self) -> impl Iterator<Item = LabelSegment<'_>> { 92 + self.repr.iter() 93 + } 94 + 95 + pub fn names(&self) -> impl Iterator<Item = &'_ str> { 96 + match self.repr { 97 + LabelRepr::Str(view) => Either::A(view.split('.')), 98 + LabelRepr::Bytes { context, start, .. } => Either::B( 99 + LabelSegmentBytesIter::new(context, start).flat_map(|label| label.as_str()), 100 + ), 101 + } 102 + } 103 + 104 + pub fn is_empty(&self) -> bool { 105 + self.repr.iter().next().is_none() 106 + } 107 + } 108 + 109 + #[derive(Clone, Copy, PartialEq, Eq)] 110 + enum LabelRepr<'a> { 111 + Bytes { 112 + context: &'a [u8], 113 + start: usize, 114 + end: usize, 115 + }, 116 + Str(&'a str), 117 + } 118 + 119 + /// A DNS-compatible label segment. 120 + #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] 121 + pub enum LabelSegment<'a> { 122 + /// The empty terminator. 123 + Empty, 124 + 125 + /// A string label. 126 + String(&'a str), 127 + 128 + /// A pointer to a previous name. 129 + Pointer(u16), 130 + } 131 + 132 + impl<'a> LabelSegment<'a> { 133 + fn parse(input: &mut &'a [u8]) -> ModalResult<Self> { 134 + let b1 = be_u8(input)?; 135 + 136 + match b1 { 137 + 0 => Ok(Self::Empty), 138 + b1 if b1 & PTR_MASK == PTR_MASK => { 139 + let b2 = be_u8(input)?; 140 + 141 + let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]); 142 + 143 + Ok(Self::Pointer(ptr)) 144 + } 145 + len => { 146 + if len > MAX_STR_LEN { 147 + return Err(ErrMode::Cut(ContextError::from_external_error( 148 + input, DnsError::LabelTooLong, 149 + ))); 150 + } 151 + 152 + let segment = take(len).try_map(core::str::from_utf8).parse_next(input)?; 153 + 154 + Ok(Self::String(segment)) 155 + } 156 + } 157 + } 158 + 159 + /// ## Safety 160 + /// The caller upholds that this function is not called when parsing from newly received data. Data that 161 + /// has yet to be determined to be a valid [`Label`] should be parsed and validated with [`Label::parse`], 162 + /// and that the entire data/context has been validated, not just a portion of it. 163 + #[inline] 164 + unsafe fn parse_unchecked(input: &'a [u8]) -> Option<Self> { 165 + input.split_first().map(|(b1, input)| match *b1 { 166 + 0 => Self::Empty, 167 + b1 if b1 & PTR_MASK == PTR_MASK => { 168 + // SAFETY: The caller has already validated that a second byte is available for a 169 + // Pointer segment. 170 + let b2 = unsafe { *input.get_unchecked(0) }; 171 + 172 + let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]); 173 + 174 + Self::Pointer(ptr) 175 + } 176 + len => { 177 + // SAFETY: The caller has validated that this length value is correct and will only 178 + // access within the bounds of the provided slice. 179 + let segment = unsafe { input.get_unchecked(0..(len as usize)) }; 180 + // SAFETY: The caller has upheld the validity of the bytes as valid UTF-8 once before. 181 + let segment = unsafe { core::str::from_utf8_unchecked(segment) }; 182 + 183 + Self::String(segment) 184 + } 185 + }) 186 + } 187 + 188 + fn as_str(&self) -> Option<&'a str> { 189 + match self { 190 + Self::String(label) => Some(*label), 191 + _ => None, 192 + } 193 + } 194 + } 195 + 196 + pub struct LabelSegmentBytesIter<'a> { 197 + context: &'a [u8], 198 + start: usize, 199 + } 200 + 201 + impl<'a> LabelSegmentBytesIter<'a> { 202 + pub(crate) fn new(context: &'a [u8], start: usize) -> Self { 203 + Self { context, start } 204 + } 205 + } 206 + 207 + impl<'a> Iterator for LabelSegmentBytesIter<'a> { 208 + type Item = LabelSegment<'a>; 209 + 210 + fn next(&mut self) -> Option<Self::Item> { 211 + loop { 212 + let view = &self.context[self.start..]; 213 + 214 + // SAFETY: The segment has already been validated, so they should be all valid variants and UTF-8 bytes 215 + let segment = unsafe { LabelSegment::parse_unchecked(view)? }; 216 + 217 + match segment { 218 + LabelSegment::String(label) => { 219 + self.start = self.start.saturating_add(label.len() + 1); 220 + return Some(segment); 221 + } 222 + LabelSegment::Pointer(ptr) => { 223 + self.start = ptr as usize; 224 + } 225 + LabelSegment::Empty => { 226 + // Set the index offset to be len() so that the view is empty and terminates the loop 227 + self.start = self.context.len(); 228 + return Some(LabelSegment::Empty); 229 + } 230 + } 231 + } 232 + } 233 + } 234 + 235 + impl<'a> LabelRepr<'a> { 236 + fn iter(&self) -> impl Iterator<Item = LabelSegment<'a>> { 237 + match *self { 238 + LabelRepr::Bytes { context, start, .. } => { 239 + Either::A(LabelSegmentBytesIter::new(context, start)) 240 + } 241 + LabelRepr::Str(view) => Either::B( 242 + view.split('.') 243 + .map(LabelSegment::String) 244 + .chain(Some(LabelSegment::Empty)), 245 + ), 246 + } 247 + } 248 + } 249 + 250 + impl fmt::Debug for Label<'_> { 251 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 252 + struct LabelFmt<'a>(&'a Label<'a>); 253 + 254 + impl fmt::Debug for LabelFmt<'_> { 255 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 256 + fmt::Display::fmt(self.0, f) 257 + } 258 + } 259 + 260 + f.debug_tuple("Label").field(&LabelFmt(self)).finish() 261 + } 262 + } 263 + 264 + impl fmt::Display for Label<'_> { 265 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 266 + let mut names = self.names(); 267 + 268 + if let Some(name) = names.next() { 269 + f.write_str(name)?; 270 + 271 + names.try_for_each(|name| { 272 + f.write_str(".")?; 273 + f.write_str(name) 274 + }) 275 + } else { 276 + Ok(()) 277 + } 278 + } 279 + } 280 + 281 + impl<'a, 'b> PartialEq<Label<'a>> for Label<'b> { 282 + fn eq(&self, other: &Label<'a>) -> bool { 283 + self.segments().eq(other.segments()) 284 + } 285 + } 286 + 287 + impl Eq for Label<'_> {} 288 + 289 + impl PartialEq<&str> for Label<'_> { 290 + fn eq(&self, other: &&str) -> bool { 291 + let mut self_iter = self.names(); 292 + let mut other_iter = other.split('.'); 293 + 294 + loop { 295 + match (self_iter.next(), other_iter.next()) { 296 + (Some(self_part), Some(other_part)) => { 297 + if self_part != other_part { 298 + return false; 299 + } 300 + } 301 + (None, None) => return true, 302 + _ => return false, 303 + } 304 + } 305 + } 306 + } 307 + 308 + #[cfg(feature = "defmt")] 309 + impl defmt::Format for Label<'_> { 310 + fn format(&self, fmt: defmt::Formatter) { 311 + defmt::write!(fmt, "Label("); 312 + let mut iter = self.names(); 313 + if let Some(first) = iter.next() { 314 + defmt::write!(fmt, "{}", first); 315 + 316 + iter.for_each(|part| defmt::write!(fmt, ".{}", part)); 317 + } 318 + defmt::write!(fmt, ")"); 319 + } 320 + } 321 + 322 + /// One iterator or another. 323 + enum Either<A, B> { 324 + A(A), 325 + B(B), 326 + } 327 + 328 + impl<A: Iterator, Other: Iterator<Item = A::Item>> Iterator for Either<A, Other> { 329 + type Item = A::Item; 330 + 331 + fn next(&mut self) -> Option<Self::Item> { 332 + match self { 333 + Either::A(a) => a.next(), 334 + Either::B(b) => b.next(), 335 + } 336 + } 337 + 338 + fn size_hint(&self) -> (usize, Option<usize>) { 339 + match self { 340 + Either::A(a) => a.size_hint(), 341 + Either::B(b) => b.size_hint(), 342 + } 343 + } 344 + 345 + fn fold<B, F>(self, init: B, f: F) -> B 346 + where 347 + Self: Sized, 348 + F: FnMut(B, Self::Item) -> B, 349 + { 350 + match self { 351 + Either::A(a) => a.fold(init, f), 352 + Either::B(b) => b.fold(init, f), 353 + } 354 + } 355 + } 356 + 357 + #[cfg(test)] 358 + mod test { 359 + use super::*; 360 + 361 + #[test] 362 + fn segments_iter_test() { 363 + let label: Label<'static> = Label::from("_service._udp.local"); 364 + let mut segments = label.segments(); 365 + 366 + assert_eq!(segments.next(), Some(LabelSegment::String("_service"))); 367 + assert_eq!(segments.next(), Some(LabelSegment::String("_udp"))); 368 + assert_eq!(segments.next(), Some(LabelSegment::String("local"))); 369 + assert_eq!(segments.next(), Some(LabelSegment::Empty)); 370 + assert_eq!(segments.next(), None); 371 + 372 + // example.com with a pointer to the start 373 + let data = b"\x07example\x03com\x00\xC0\x00"; 374 + let context = &data[..]; 375 + // The data here is entirely valid, even though we parse only a portion of it. 376 + let label = Label::parse(&mut &data[13..], context).unwrap(); 377 + 378 + let mut segments = label.segments(); 379 + assert_eq!(segments.next(), Some(LabelSegment::String("example"))); 380 + assert_eq!(segments.next(), Some(LabelSegment::String("com"))); 381 + assert_eq!(segments.next(), Some(LabelSegment::Empty)); 382 + assert_eq!(segments.next(), None); 383 + } 384 + 385 + #[test] 386 + fn names_iter_test() { 387 + let label: Label<'static> = Label::from("_service._udp.local"); 388 + let mut names = label.names(); 389 + 390 + assert_eq!(names.next(), Some("_service")); 391 + assert_eq!(names.next(), Some("_udp")); 392 + assert_eq!(names.next(), Some("local")); 393 + assert_eq!(names.next(), None); 394 + 395 + let data = b"\x07example\x03com\x00\xC0\x00"; 396 + let context = &data[..]; 397 + // The data here is entirely valid, even though we parse only a portion of it. 398 + let label = Label::parse(&mut &data[13..], context).unwrap(); 399 + 400 + let mut names = label.names(); 401 + assert_eq!(names.next(), Some("example")); 402 + assert_eq!(names.next(), Some("com")); 403 + assert_eq!(names.next(), None); 404 + } 405 + 406 + #[test] 407 + fn serialize_str_label() { 408 + let label: Label<'static> = Label::from("_service._udp.local"); 409 + let mut buffer = [0u8; 256]; 410 + let mut buffer = Encoder::new(&mut buffer); 411 + label.serialize(&mut buffer).unwrap(); 412 + assert_eq!(buffer.finish(), b"\x08_service\x04_udp\x05local\x00"); 413 + } 414 + 415 + #[test] 416 + fn serialize_compressed_str_label() { 417 + let label: Label<'static> = Label::from("_service._udp.local"); 418 + let label2: Label<'static> = Label::from("other._udp.local"); 419 + let mut buffer = [0u8; 256]; 420 + let mut buffer = Encoder::new(&mut buffer); 421 + label.serialize(&mut buffer).unwrap(); 422 + label2.serialize(&mut buffer).unwrap(); 423 + assert_eq!( 424 + buffer.finish(), 425 + b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09" 426 + ); 427 + } 428 + 429 + #[test] 430 + fn round_trip_compressed_str_label() { 431 + let label: Label<'static> = Label::from("_service._udp.local"); 432 + let label2: Label<'static> = Label::from("other._udp.local"); 433 + let mut buffer = [0u8; 256]; 434 + let mut buffer = Encoder::new(&mut buffer); 435 + label.serialize(&mut buffer).unwrap(); 436 + label2.serialize(&mut buffer).unwrap(); 437 + let context = buffer.finish(); 438 + assert_eq!( 439 + context, 440 + b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09" 441 + ); 442 + let view = &mut &context[..]; 443 + 444 + let parsed_label = Label::parse(view, context).unwrap(); 445 + let parsed_label2 = Label::parse(view, context).unwrap(); 446 + 447 + let parsed_label_count = parsed_label.segments().count(); 448 + let parsed_label2_count = parsed_label2.segments().count(); 449 + 450 + // Both have same amount of segments. 451 + assert_eq!(parsed_label_count, parsed_label2_count); 452 + 453 + assert_eq!(parsed_label, label); 454 + assert_eq!(parsed_label2, label2); 455 + } 456 + 457 + #[test] 458 + fn label_byte_repr_serialization_quick_path() { 459 + let data = b"\x07example\x03com\x00\xC0\x00"; 460 + let context = &data[..]; 461 + // The data here is entirely valid, even though we parse only a portion of it. 462 + let label = Label::parse(&mut &data[13..], context).unwrap(); 463 + 464 + let mut buffer = [0u8; 256]; 465 + let mut buffer = Encoder::new(&mut buffer); 466 + label.serialize(&mut buffer).unwrap(); 467 + // If the original Label is just a pointer, the new output will be a pointer, assuming 468 + // the original data is also present in the output 469 + assert_eq!(buffer.finish(), b"\xC0\x00"); 470 + } 471 + 472 + #[test] 473 + fn parse_and_eq_created_label() { 474 + let data = b"\x07example\x03com\x00\xC0\x00"; 475 + let context = &data[..]; 476 + 477 + // The data here is entirely valid, even though we parse only a portion of it. 478 + let parsed_label = Label::parse(&mut &data[13..], context).unwrap(); 479 + 480 + let created_label = Label::from("example.com"); 481 + 482 + assert_eq!(parsed_label, created_label); 483 + } 484 + 485 + #[test] 486 + fn parse_and_eq_label_with_str() { 487 + let data = b"\x07example\x03com\x00"; 488 + let context = &data[..]; 489 + 490 + let parsed_label = Label::parse(&mut &data[..], context).unwrap(); 491 + 492 + assert_eq!(parsed_label, "example.com"); 493 + } 494 + 495 + #[test] 496 + fn parse_ptr_label_and_eq_with_str() { 497 + let data = b"\x07example\x03com\x00\xC0\x00"; 498 + let context = &data[..]; 499 + 500 + // The data here is entirely valid, even though we parse only a portion of it. 501 + let parsed_label = Label::parse(&mut &data[13..], context).unwrap(); 502 + 503 + assert_eq!(parsed_label, "example.com"); 504 + } 505 + 506 + #[test] 507 + fn label_new_without_dot_is_not_empty() { 508 + let label: Label = Label::from("example"); 509 + assert!(!label.is_empty()); 510 + } 511 + }
+228
sachy-mdns/src/dns/query.rs
··· 1 + use winnow::binary::{be_u16, be_u32}; 2 + use winnow::{ModalResult, Parser}; 3 + 4 + use super::label::Label; 5 + use super::records::Record; 6 + use crate::encoder::Encoder; 7 + use crate::{ 8 + dns::{ 9 + records::QType, 10 + traits::{DnsParse, DnsParseKind, DnsSerialize}, 11 + }, 12 + encoder::DnsError, 13 + }; 14 + 15 + #[derive(Debug, PartialEq, Eq)] 16 + pub struct Query<'a> { 17 + pub name: Label<'a>, 18 + pub qtype: QType, 19 + pub qclass: QClass, 20 + } 21 + 22 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 23 + #[repr(u16)] 24 + pub enum QClass { 25 + IN = 1, 26 + Multicast = 32769, // (IN + Cache flush bit) 27 + Unknown(u16), 28 + } 29 + 30 + impl<'a> DnsParse<'a> for Query<'a> { 31 + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> { 32 + let name = Label::parse(input, context)?; 33 + let qtype = QType::parse(input, context)?; 34 + let qclass = be_u16.map(QClass::from_u16).parse_next(input)?; 35 + 36 + Ok(Query { 37 + name, 38 + qtype, 39 + qclass, 40 + }) 41 + } 42 + } 43 + 44 + impl<'a> DnsSerialize<'a> for Query<'a> { 45 + type Error = DnsError; 46 + 47 + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 48 + self.name.serialize(encoder)?; 49 + self.qtype.serialize(encoder).ok(); 50 + encoder.write(&self.qclass.to_u16().to_be_bytes()); 51 + Ok(()) 52 + } 53 + 54 + fn size(&self) -> usize { 55 + self.name.size() + self.qtype.size() + core::mem::size_of::<QClass>() 56 + } 57 + } 58 + 59 + #[derive(Debug, PartialEq, Eq)] 60 + pub struct Answer<'a> { 61 + pub name: Label<'a>, 62 + pub atype: QType, 63 + pub aclass: QClass, 64 + pub ttl: u32, 65 + pub record: Record<'a>, 66 + } 67 + 68 + impl QClass { 69 + fn from_u16(value: u16) -> Self { 70 + match value { 71 + 1 => QClass::IN, 72 + 32769 => QClass::Multicast, 73 + _ => QClass::Unknown(value), 74 + } 75 + } 76 + 77 + fn to_u16(self) -> u16 { 78 + match self { 79 + QClass::IN => 1, 80 + QClass::Multicast => 32769, 81 + QClass::Unknown(value) => value, 82 + } 83 + } 84 + } 85 + 86 + impl<'a> DnsParse<'a> for Answer<'a> { 87 + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> { 88 + let name = Label::parse(input, context)?; 89 + let atype = QType::parse(input, context)?; 90 + let aclass = be_u16.map(QClass::from_u16).parse_next(input)?; 91 + 92 + let ttl = be_u32.parse_next(input)?; 93 + let record = atype.parse_kind(input, context)?; 94 + 95 + Ok(Answer { 96 + name, 97 + atype, 98 + aclass, 99 + ttl, 100 + record, 101 + }) 102 + } 103 + } 104 + 105 + impl<'a> DnsSerialize<'a> for Answer<'a> { 106 + type Error = DnsError; 107 + 108 + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 109 + self.name.serialize(encoder)?; 110 + self.atype.serialize(encoder).ok(); 111 + encoder.write(&self.aclass.to_u16().to_be_bytes()); 112 + encoder.write(&self.ttl.to_be_bytes()); 113 + self.record.serialize(encoder) 114 + } 115 + 116 + fn size(&self) -> usize { 117 + self.name.size() 118 + + self.atype.size() 119 + + core::mem::size_of::<QClass>() 120 + + core::mem::size_of::<u32>() 121 + + self.record.size() 122 + } 123 + } 124 + 125 + #[cfg(feature = "defmt")] 126 + impl<'a> defmt::Format for Query<'a> { 127 + fn format(&self, fmt: defmt::Formatter) { 128 + defmt::write!( 129 + fmt, 130 + "Query {{ name: {:?}, qtype: {:?}, qclass: {:?} }}", 131 + self.name, 132 + self.qtype, 133 + self.qclass 134 + ); 135 + } 136 + } 137 + 138 + #[cfg(feature = "defmt")] 139 + impl defmt::Format for QType { 140 + fn format(&self, fmt: defmt::Formatter) { 141 + let qtype_str = match self { 142 + QType::A => "A", 143 + QType::AAAA => "AAAA", 144 + QType::PTR => "PTR", 145 + QType::TXT => "TXT", 146 + QType::SRV => "SRV", 147 + QType::Any => "Any", 148 + QType::Unknown(_) => "Unknown", 149 + }; 150 + defmt::write!(fmt, "QType({=str})", qtype_str); 151 + } 152 + } 153 + 154 + #[cfg(feature = "defmt")] 155 + impl defmt::Format for QClass { 156 + fn format(&self, fmt: defmt::Formatter) { 157 + let qclass_str = match self { 158 + QClass::IN => "IN", 159 + QClass::Multicast => "Multicast", 160 + QClass::Unknown(_) => "Unknown", 161 + }; 162 + defmt::write!(fmt, "QClass({=str})", qclass_str); 163 + } 164 + } 165 + 166 + #[cfg(feature = "defmt")] 167 + impl<'a> defmt::Format for Answer<'a> { 168 + fn format(&self, fmt: defmt::Formatter) { 169 + defmt::write!( 170 + fmt, 171 + "Answer {{ name: {:?}, atype: {:?}, aclass: {:?}, ttl: {}, record: {:?} }}", 172 + self.name, 173 + self.atype, 174 + self.aclass, 175 + self.ttl, 176 + self.record 177 + ); 178 + } 179 + } 180 + 181 + #[cfg(test)] 182 + mod tests { 183 + use super::*; 184 + use crate::dns::records::A; 185 + use core::net::Ipv4Addr; 186 + 187 + #[test] 188 + fn roundtrip_query() { 189 + let name = Label::from("example.local"); 190 + 191 + let query = Query { 192 + name, 193 + qtype: QType::A, 194 + qclass: QClass::IN, 195 + }; 196 + 197 + let mut buffer = [0u8; 256]; 198 + let mut buffer = Encoder::new(&mut buffer); 199 + query.serialize(&mut buffer).unwrap(); 200 + let buffer = buffer.finish(); 201 + let parsed_query = Query::parse(&mut &buffer[..], buffer).unwrap(); 202 + 203 + assert_eq!(query, parsed_query); 204 + } 205 + 206 + #[test] 207 + fn roundtrip_answer() { 208 + let name = Label::from("example.local"); 209 + 210 + let answer: Answer = Answer { 211 + name, 212 + atype: QType::A, 213 + aclass: QClass::IN, 214 + ttl: 120, 215 + record: Record::A(A { 216 + address: Ipv4Addr::new(192, 168, 1, 1), 217 + }), 218 + }; 219 + 220 + let mut buffer = [0u8; 256]; 221 + let mut buffer = Encoder::new(&mut buffer); 222 + answer.serialize(&mut buffer).unwrap(); 223 + let buffer = buffer.finish(); 224 + let parsed_answer = Answer::parse(&mut &buffer[..], buffer).unwrap(); 225 + 226 + assert_eq!(answer, parsed_answer); 227 + } 228 + }
+413
sachy-mdns/src/dns/records.rs
··· 1 + use core::{ 2 + convert::Infallible, 3 + net::{Ipv4Addr, Ipv6Addr}, 4 + str, 5 + }; 6 + 7 + use alloc::vec::Vec; 8 + use winnow::token::take; 9 + use winnow::{ModalResult, Parser}; 10 + use winnow::{binary::be_u8, error::ContextError}; 11 + use winnow::{binary::be_u16, error::FromExternalError}; 12 + 13 + use super::label::Label; 14 + use crate::{ 15 + dns::traits::{DnsParse, DnsParseKind, DnsSerialize}, 16 + encoder::{DnsError, Encoder}, 17 + }; 18 + 19 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 20 + #[repr(u16)] 21 + #[allow(clippy::upper_case_acronyms)] 22 + pub enum QType { 23 + A = 1, 24 + AAAA = 28, 25 + PTR = 12, 26 + TXT = 16, 27 + SRV = 33, 28 + Any = 255, 29 + Unknown(u16), 30 + } 31 + 32 + impl<'a> DnsParse<'a> for QType { 33 + fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> { 34 + be_u16.map(QType::from_u16).parse_next(input) 35 + } 36 + } 37 + 38 + impl<'a> DnsSerialize<'a> for QType { 39 + type Error = Infallible; 40 + 41 + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 42 + encoder.write(&self.to_u16().to_be_bytes()); 43 + Ok(()) 44 + } 45 + 46 + fn size(&self) -> usize { 47 + core::mem::size_of::<QType>() 48 + } 49 + } 50 + 51 + impl<'a> DnsParseKind<'a> for QType { 52 + type Output = Record<'a>; 53 + 54 + fn parse_kind(&self, input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self::Output> { 55 + match self { 56 + QType::A => { 57 + let record = A::parse(input, context)?; 58 + Ok(Record::A(record)) 59 + } 60 + QType::AAAA => { 61 + let record = AAAA::parse(input, context)?; 62 + Ok(Record::AAAA(record)) 63 + } 64 + QType::PTR => { 65 + let record = PTR::parse(input, context)?; 66 + Ok(Record::PTR(record)) 67 + } 68 + QType::TXT => { 69 + let record = TXT::parse(input, context)?; 70 + Ok(Record::TXT(record)) 71 + } 72 + QType::SRV => { 73 + let record = SRV::parse(input, context)?; 74 + Ok(Record::SRV(record)) 75 + } 76 + QType::Any => Err(winnow::error::ErrMode::Backtrack( 77 + ContextError::from_external_error(input, DnsError::Unsupported), 78 + )), 79 + QType::Unknown(_) => Err(winnow::error::ErrMode::Backtrack( 80 + ContextError::from_external_error(input, DnsError::Unsupported), 81 + )), 82 + } 83 + } 84 + } 85 + 86 + impl QType { 87 + fn from_u16(value: u16) -> Self { 88 + match value { 89 + 1 => QType::A, 90 + 28 => QType::AAAA, 91 + 12 => QType::PTR, 92 + 16 => QType::TXT, 93 + 33 => QType::SRV, 94 + 255 => QType::Any, 95 + _ => QType::Unknown(value), 96 + } 97 + } 98 + 99 + fn to_u16(self) -> u16 { 100 + match self { 101 + QType::A => 1, 102 + QType::AAAA => 28, 103 + QType::PTR => 12, 104 + QType::TXT => 16, 105 + QType::SRV => 33, 106 + QType::Any => 255, 107 + QType::Unknown(value) => value, 108 + } 109 + } 110 + } 111 + 112 + #[derive(Debug, PartialEq, Eq)] 113 + #[allow(clippy::upper_case_acronyms)] 114 + // Enum for DNS-SD records 115 + pub enum Record<'a> { 116 + A(A), 117 + AAAA(AAAA), 118 + PTR(PTR<'a>), 119 + TXT(TXT<'a>), 120 + SRV(SRV<'a>), 121 + } 122 + 123 + impl<'a> DnsSerialize<'a> for Record<'a> { 124 + type Error = DnsError; 125 + 126 + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 127 + match self { 128 + Record::A(record) => { 129 + record.serialize(encoder).ok(); 130 + } 131 + Record::AAAA(record) => { 132 + record.serialize(encoder).ok(); 133 + } 134 + Record::PTR(record) => { 135 + record.serialize(encoder)?; 136 + } 137 + Record::TXT(record) => { 138 + record.serialize(encoder).ok(); 139 + } 140 + Record::SRV(record) => { 141 + record.serialize(encoder)?; 142 + } 143 + }; 144 + 145 + Ok(()) 146 + } 147 + 148 + fn size(&self) -> usize { 149 + match self { 150 + Self::A(a) => a.size(), 151 + Self::AAAA(aaaa) => aaaa.size(), 152 + Self::PTR(ptr) => ptr.size(), 153 + Self::TXT(txt) => txt.size(), 154 + Self::SRV(srv) => srv.size(), 155 + } 156 + } 157 + } 158 + 159 + // Struct for A record 160 + #[derive(Debug, PartialEq, Eq)] 161 + pub struct A { 162 + pub address: Ipv4Addr, 163 + } 164 + 165 + impl<'a> DnsParse<'a> for A { 166 + fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> { 167 + let len = be_u16.parse_next(input)?; 168 + let address = take(len) 169 + .try_map(<[u8; 4]>::try_from) 170 + .map(Ipv4Addr::from) 171 + .parse_next(input)?; 172 + 173 + Ok(A { address }) 174 + } 175 + } 176 + 177 + impl<'a> DnsSerialize<'a> for A { 178 + type Error = Infallible; 179 + 180 + fn serialize(&self, writer: &mut Encoder<'_, '_>) -> Result<(), Self::Error> { 181 + let len = 4u16.to_be_bytes(); 182 + writer.write(&len); 183 + writer.write(&self.address.octets()); 184 + Ok(()) 185 + } 186 + 187 + fn size(&self) -> usize { 188 + core::mem::size_of::<Ipv4Addr>() + core::mem::size_of::<u16>() 189 + } 190 + } 191 + 192 + // Struct for AAAA record 193 + #[derive(Debug, PartialEq, Eq)] 194 + #[allow(clippy::upper_case_acronyms)] 195 + pub struct AAAA { 196 + pub address: Ipv6Addr, 197 + } 198 + 199 + impl<'a> DnsParse<'a> for AAAA { 200 + fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> { 201 + let len = be_u16.parse_next(input)?; 202 + let address = take(len) 203 + .try_map(<[u8; 16]>::try_from) 204 + .map(Ipv6Addr::from) 205 + .parse_next(input)?; 206 + 207 + Ok(AAAA { address }) 208 + } 209 + } 210 + 211 + impl<'a> DnsSerialize<'a> for AAAA { 212 + type Error = Infallible; 213 + 214 + fn serialize(&self, writer: &mut Encoder<'_, '_>) -> Result<(), Self::Error> { 215 + let len = 16u16.to_be_bytes(); 216 + writer.write(&len); 217 + writer.write(&self.address.octets()); 218 + Ok(()) 219 + } 220 + 221 + fn size(&self) -> usize { 222 + core::mem::size_of::<Ipv6Addr>() + core::mem::size_of::<u16>() 223 + } 224 + } 225 + 226 + // Struct for PTR record 227 + #[derive(Debug, PartialEq, Eq)] 228 + #[allow(clippy::upper_case_acronyms)] 229 + pub struct PTR<'a> { 230 + pub name: Label<'a>, 231 + } 232 + 233 + impl<'a> DnsParse<'a> for PTR<'a> { 234 + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> { 235 + let _ = be_u16.parse_next(input)?; 236 + let name = Label::parse(input, context)?; 237 + Ok(PTR { name }) 238 + } 239 + } 240 + 241 + impl<'a> DnsSerialize<'a> for PTR<'a> { 242 + type Error = DnsError; 243 + 244 + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 245 + encoder.with_record_length(|enc| self.name.serialize(enc)) 246 + } 247 + 248 + fn size(&self) -> usize { 249 + self.name.size() + core::mem::size_of::<u16>() 250 + } 251 + } 252 + 253 + // Struct for TXT record 254 + #[derive(Debug, PartialEq, Eq)] 255 + #[allow(clippy::upper_case_acronyms)] 256 + pub struct TXT<'a> { 257 + pub text: Vec<&'a str>, 258 + } 259 + 260 + impl<'a> DnsParse<'a> for TXT<'a> { 261 + fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> { 262 + let text_len = be_u16.parse_next(input)?; 263 + 264 + let mut total = 0u16; 265 + let mut text = Vec::new(); 266 + 267 + while total < text_len { 268 + let len = be_u8(input)?; 269 + 270 + total += 1 + len as u16; 271 + 272 + if len > 0 { 273 + let part = take(len).try_map(core::str::from_utf8).parse_next(input)?; 274 + text.push(part); 275 + } 276 + } 277 + 278 + Ok(TXT { text }) 279 + } 280 + } 281 + 282 + impl<'a> DnsSerialize<'a> for TXT<'a> { 283 + type Error = DnsError; 284 + 285 + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 286 + encoder.with_record_length(|enc| { 287 + self.text.iter().try_for_each(|&part| { 288 + let text_len = u8::try_from(part.len()) 289 + .map_err(|_| DnsError::InvalidTxt) 290 + .map(u8::to_be_bytes)?; 291 + 292 + enc.write(&text_len); 293 + enc.write(part.as_bytes()); 294 + 295 + Ok(()) 296 + }) 297 + }) 298 + } 299 + 300 + fn size(&self) -> usize { 301 + let len_size = core::mem::size_of::<u16>(); 302 + 303 + let text_size = if self.text.is_empty() { 304 + 1 305 + } else { 306 + self.text.iter().map(|part| part.len() + 1).sum() 307 + }; 308 + 309 + len_size + text_size 310 + } 311 + } 312 + 313 + // Struct for SRV record 314 + #[derive(Debug, PartialEq, Eq)] 315 + #[allow(clippy::upper_case_acronyms)] 316 + pub struct SRV<'a> { 317 + pub priority: u16, 318 + pub weight: u16, 319 + pub port: u16, 320 + pub target: Label<'a>, 321 + } 322 + 323 + impl<'a> DnsParse<'a> for SRV<'a> { 324 + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> { 325 + let _ = be_u16.parse_next(input)?; 326 + let priority = be_u16.parse_next(input)?; 327 + let weight = be_u16.parse_next(input)?; 328 + let port = be_u16.parse_next(input)?; 329 + let target = Label::parse(input, context)?; 330 + 331 + Ok(SRV { 332 + priority, 333 + weight, 334 + port, 335 + target, 336 + }) 337 + } 338 + } 339 + 340 + impl<'a> DnsSerialize<'a> for SRV<'a> { 341 + type Error = DnsError; 342 + 343 + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 344 + encoder.with_record_length(|enc| { 345 + enc.write(&self.priority.to_be_bytes()); 346 + enc.write(&self.weight.to_be_bytes()); 347 + enc.write(&self.port.to_be_bytes()); 348 + 349 + self.target.serialize(enc) 350 + }) 351 + } 352 + 353 + fn size(&self) -> usize { 354 + (core::mem::size_of::<u16>() * 4) + self.target.size() 355 + } 356 + } 357 + 358 + #[cfg(feature = "defmt")] 359 + impl defmt::Format for A { 360 + fn format(&self, fmt: defmt::Formatter) { 361 + // use crate::format::FormatIpv4Addr; 362 + defmt::write!(fmt, "A({})", self.address) 363 + } 364 + } 365 + 366 + #[cfg(feature = "defmt")] 367 + impl defmt::Format for AAAA { 368 + fn format(&self, fmt: defmt::Formatter) { 369 + // use crate::format::FormatIpv6Addr; 370 + defmt::write!(fmt, "AAAA({})", self.address) 371 + } 372 + } 373 + 374 + #[cfg(feature = "defmt")] 375 + impl<'a> defmt::Format for Record<'a> { 376 + fn format(&self, fmt: defmt::Formatter) { 377 + match self { 378 + Record::A(record) => defmt::write!(fmt, "Record::A({:?})", record), 379 + Record::AAAA(record) => defmt::write!(fmt, "Record::AAAA({:?})", record), 380 + Record::PTR(record) => defmt::write!(fmt, "Record::PTR({:?})", record), 381 + Record::TXT(record) => defmt::write!(fmt, "Record::TXT({:?})", record), 382 + Record::SRV(record) => defmt::write!(fmt, "Record::SRV({:?})", record), 383 + } 384 + } 385 + } 386 + 387 + #[cfg(feature = "defmt")] 388 + impl<'a> defmt::Format for PTR<'a> { 389 + fn format(&self, fmt: defmt::Formatter) { 390 + defmt::write!(fmt, "PTR {{ name: {:?} }}", self.name); 391 + } 392 + } 393 + 394 + #[cfg(feature = "defmt")] 395 + impl<'a> defmt::Format for TXT<'a> { 396 + fn format(&self, fmt: defmt::Formatter) { 397 + defmt::write!(fmt, "TXT {{ text: {:?} }}", self.text); 398 + } 399 + } 400 + 401 + #[cfg(feature = "defmt")] 402 + impl<'a> defmt::Format for SRV<'a> { 403 + fn format(&self, fmt: defmt::Formatter) { 404 + defmt::write!( 405 + fmt, 406 + "SRV {{ priority: {}, weight: {}, port: {}, target: {:?} }}", 407 + self.priority, 408 + self.weight, 409 + self.port, 410 + self.target 411 + ); 412 + } 413 + }
+481
sachy-mdns/src/dns/reqres.rs
··· 1 + use alloc::vec::Vec; 2 + use winnow::ModalResult; 3 + use winnow::binary::be_u16; 4 + 5 + use super::flags::Flags; 6 + use super::query::{Answer, Query}; 7 + use crate::{ 8 + dns::traits::{DnsParse, DnsSerialize}, 9 + encoder::{DnsError, Encoder}, 10 + }; 11 + 12 + const ZERO_U16: [u8; 2] = 0u16.to_be_bytes(); 13 + 14 + #[derive(Debug, PartialEq, Eq)] 15 + pub struct Request<'a> { 16 + pub id: u16, 17 + pub flags: Flags, 18 + pub(crate) queries: Vec<Query<'a>>, 19 + } 20 + 21 + impl<'a> DnsParse<'a> for Request<'a> { 22 + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> { 23 + let id = be_u16(input)?; 24 + let flags = Flags::parse(input, context)?; 25 + let qdcount = be_u16(input)?; 26 + let _ancount = be_u16(input)?; 27 + let _nscount = be_u16(input)?; 28 + let _arcount = be_u16(input)?; 29 + let queries = (0..qdcount) 30 + .map(|_| Query::parse(input, context)) 31 + .collect::<Result<Vec<_>, _>>()?; 32 + Ok(Request { id, flags, queries }) 33 + } 34 + } 35 + 36 + impl<'a> DnsSerialize<'a> for Request<'a> { 37 + type Error = DnsError; 38 + 39 + fn serialize<'b>(&self, writer: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 40 + writer.write(&self.id.to_be_bytes()); 41 + self.flags.serialize(writer).ok(); 42 + writer.write(&(self.queries.len() as u16).to_be_bytes()); 43 + writer.write(&ZERO_U16); 44 + writer.write(&ZERO_U16); 45 + writer.write(&ZERO_U16); 46 + 47 + self.queries 48 + .iter() 49 + .try_for_each(|query| query.serialize(writer)) 50 + } 51 + 52 + fn size(&self) -> usize { 53 + let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum(); 54 + 55 + core::mem::size_of::<u16>() 56 + + self.flags.size() 57 + + (core::mem::size_of::<u16>() * 4) 58 + + total_query_size 59 + } 60 + } 61 + 62 + #[derive(Debug, PartialEq, Eq)] 63 + pub struct Response<'a> { 64 + pub id: u16, 65 + pub flags: Flags, 66 + pub queries: Vec<Query<'a>>, 67 + pub answers: Vec<Answer<'a>>, 68 + } 69 + 70 + impl<'a> DnsParse<'a> for Response<'a> { 71 + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> { 72 + let id = be_u16(input)?; 73 + let flags = Flags::parse(input, context)?; 74 + let qdcount = be_u16(input)?; 75 + let ancount = be_u16(input)?; 76 + let _nscount = be_u16(input)?; 77 + let _arcount = be_u16(input)?; 78 + 79 + let queries = (0..qdcount) 80 + .map(|_| Query::parse(input, context)) 81 + .collect::<Result<Vec<_>, _>>()?; 82 + 83 + let answers = (0..ancount) 84 + .map(|_| Answer::parse(input, context)) 85 + .collect::<Result<Vec<_>, _>>()?; 86 + 87 + Ok(Response { 88 + id, 89 + flags, 90 + queries, 91 + answers, 92 + }) 93 + } 94 + } 95 + 96 + impl<'a> DnsSerialize<'a> for Response<'a> { 97 + type Error = DnsError; 98 + 99 + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { 100 + encoder.write(&self.id.to_be_bytes()); 101 + self.flags.serialize(encoder).ok(); 102 + encoder.write(&(self.queries.len() as u16).to_be_bytes()); 103 + encoder.write(&(self.answers.len() as u16).to_be_bytes()); 104 + encoder.write(&ZERO_U16); 105 + encoder.write(&ZERO_U16); 106 + 107 + self.queries 108 + .iter() 109 + .try_for_each(|query| query.serialize(encoder))?; 110 + self.answers 111 + .iter() 112 + .try_for_each(|answer| answer.serialize(encoder)) 113 + } 114 + 115 + fn size(&self) -> usize { 116 + let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum(); 117 + let total_answer_size: usize = self.answers.iter().map(DnsSerialize::size).sum(); 118 + 119 + core::mem::size_of::<u16>() 120 + + self.flags.size() 121 + + (core::mem::size_of::<u16>() * 4) 122 + + total_query_size 123 + + total_answer_size 124 + } 125 + } 126 + 127 + #[cfg(feature = "defmt")] 128 + impl<'a> defmt::Format for Request<'a> { 129 + fn format(&self, fmt: defmt::Formatter) { 130 + defmt::write!( 131 + fmt, 132 + "Request {{ id: {}, flags: {:?}, queries: {:?} }}", 133 + self.id, 134 + self.flags, 135 + self.queries 136 + ); 137 + } 138 + } 139 + 140 + #[cfg(feature = "defmt")] 141 + impl<'a> defmt::Format for Response<'a> { 142 + fn format(&self, fmt: defmt::Formatter) { 143 + defmt::write!( 144 + fmt, 145 + "Response {{ id: {}, flags: {:?}, queries: {:?}, answers: {:?} }}", 146 + self.id, 147 + self.flags, 148 + self.queries, 149 + self.answers 150 + ); 151 + } 152 + } 153 + 154 + #[cfg(test)] 155 + mod tests { 156 + use alloc::vec; 157 + 158 + use super::*; 159 + use crate::dns::{ 160 + label::Label, 161 + query::QClass, 162 + records::{A, PTR, QType, Record, SRV, TXT}, 163 + }; 164 + use core::net::Ipv4Addr; 165 + 166 + #[test] 167 + fn parse_query() { 168 + let data = [ 169 + 0xAA, 0xAA, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x65, 170 + // example . com in label format 171 + 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // 172 + // 173 + 0x00, 0x01, 0x00, 0x01, 174 + ]; 175 + 176 + let request = Request::parse(&mut data.as_slice(), data.as_slice()).unwrap(); 177 + 178 + assert_eq!(request.id, 0xAAAA); 179 + assert_eq!(request.flags.0, 0x0100); 180 + assert_eq!(request.queries.len(), 1); 181 + assert_eq!(request.queries[0].name, "example.com"); 182 + assert_eq!(request.queries[0].qtype, QType::A); 183 + assert_eq!(request.queries[0].qclass, QClass::IN); 184 + } 185 + 186 + #[test] 187 + fn parse_response() { 188 + let data = [ 189 + 0xAA, 0xAA, // transaction ID 190 + 0x81, 0x80, // flags 191 + 0x00, 0x01, // 1 question 192 + 0x00, 0x01, // 1 A-answer 193 + 0x00, 0x00, // no authority 194 + 0x00, 0x00, // no additional answers 195 + // example . com in label format 196 + 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // 197 + // 198 + 0x00, 0x01, 0x00, 0x01, // 199 + // 200 + 0xC0, 0x0C, // ptr to question section 201 + // 202 + 0x00, 0x01, 0x00, 0x01, // A and IN 203 + // 204 + 0x00, 0x00, 0x00, 0x3C, // TTL 60 seconds 205 + // 206 + 0x00, 0x04, // length of address 207 + // IP address: 208 + 192, 168, 1, 3, 209 + ]; 210 + 211 + let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); 212 + 213 + assert_eq!(response.id, 0xAAAA); 214 + assert_eq!(response.flags.0, 0x8180); 215 + assert_eq!(response.answers.len(), 1); 216 + assert_eq!(response.answers[0].name, "example.com"); 217 + assert_eq!(response.answers[0].atype, QType::A); 218 + assert_eq!(response.answers[0].aclass, QClass::IN); 219 + assert_eq!(response.answers[0].ttl, 60); 220 + if let Record::A(a) = &response.answers[0].record { 221 + assert_eq!(a.address, Ipv4Addr::new(192, 168, 1, 3)); 222 + } else { 223 + panic!("Expected A record"); 224 + } 225 + } 226 + 227 + #[test] 228 + fn parse_response_two_records() { 229 + let data = [ 230 + 0xAA, 0xAA, // 231 + 0x81, 0x80, // 232 + 0x00, 0x01, // 233 + 0x00, 0x02, // 234 + 0x00, 0x00, // 235 + 0x00, 0x00, // 236 + // example . com in label format 237 + 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // 238 + // 239 + 0x00, 0x01, // query type 240 + 0x00, 0x01, // query class 241 + // 242 + 0xC0, 0x0C, // pointer 243 + 0x00, 0x01, // 244 + 0x00, 0x01, // 245 + 0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds 246 + 0x00, 0x04, // length of A-record 247 + 0x5D, 0xB8, 0xD8, 0x22, // a-record 248 + // 249 + 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // 250 + // 251 + 0x00, 0x10, // TXT 252 + 0x00, 0x01, // IN 253 + // 254 + 0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds 255 + // 256 + 0x00, 0x10, // length of txt record 257 + // (len) "test txt record" 258 + 0x0F, 0x74, 0x65, 0x73, 0x74, 0x20, 0x74, 0x78, 0x74, 0x20, 0x72, 0x65, 0x63, 0x6F, 0x72, 259 + 0x64, 260 + ]; 261 + 262 + let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); 263 + 264 + assert_eq!(response.id, 0xAAAA); 265 + assert_eq!(response.flags.0, 0x8180); 266 + assert_eq!(response.answers.len(), 2); 267 + 268 + // First answer 269 + assert_eq!(response.answers[0].name, "example.com"); 270 + assert_eq!(response.answers[0].atype, QType::A); 271 + assert_eq!(response.answers[0].aclass, QClass::IN); 272 + assert_eq!(response.answers[0].ttl, 60); 273 + if let Record::A(a) = &response.answers[0].record { 274 + assert_eq!(a.address, Ipv4Addr::new(93, 184, 216, 34)); 275 + } else { 276 + panic!("Expected A record"); 277 + } 278 + 279 + // Second answer 280 + assert_eq!(response.answers[1].name, "example.com"); 281 + assert_eq!(response.answers[1].atype, QType::TXT); 282 + assert_eq!(response.answers[1].aclass, QClass::IN); 283 + assert_eq!(response.answers[1].ttl, 60); 284 + if let Record::TXT(txt) = &response.answers[1].record 285 + && let Some(&text) = txt.text.first() 286 + { 287 + assert_eq!(text, "test txt record"); 288 + } else { 289 + panic!("Expected TXT record"); 290 + } 291 + } 292 + 293 + #[test] 294 + fn parse_response_srv() { 295 + let data = [ 296 + // 297 + 0xAA, 0xAA, // id 298 + 0x81, 0x80, // flags 299 + 0x00, 0x01, // one question 300 + 0x00, 0x01, // one answer 301 + 0x00, 0x00, // no authority 302 + 0x00, 0x00, // no extra 303 + // 304 + 0x04, 0x5f, 0x73, 0x69, 0x70, 0x04, 0x5f, 0x74, 0x63, 0x70, 0x07, 0x65, 0x78, 0x61, 305 + 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // 306 + // 307 + 0x00, 0x21, // type SRV 308 + 0x00, 0x01, // IN 309 + // 310 + 0xc0, 0x0c, // 311 + // 312 + 0x00, 0x21, // SRV 313 + 0x00, 0x01, // IN 314 + 0x00, 0x00, 0x00, 0x3C, // ttl 60 315 + // 316 + 0x00, 0x19, // data len 317 + 0x00, 0x0A, // prio 318 + 0x00, 0x05, // weight 319 + 0x13, 0xC4, // PORT 320 + // 321 + 0x09, 0x73, 0x69, 0x70, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x07, 0x65, 0x78, 0x61, 322 + 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // 323 + ]; 324 + 325 + let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); 326 + 327 + assert_eq!(response.id, 0xAAAA); 328 + assert_eq!(response.flags.0, 0x8180); 329 + assert_eq!(response.answers.len(), 1); 330 + 331 + // Answer 332 + assert_eq!(response.answers[0].name, "_sip._tcp.example.com"); 333 + assert_eq!(response.answers[0].atype, QType::SRV); 334 + assert_eq!(response.answers[0].aclass, QClass::IN); 335 + assert_eq!(response.answers[0].ttl, 60); 336 + let Record::SRV(srv) = &response.answers[0].record else { 337 + panic!("Expected SRV record"); 338 + }; 339 + 340 + assert_eq!(srv.priority, 10); 341 + assert_eq!(srv.weight, 5); 342 + assert_eq!(srv.port, 5060); 343 + assert_eq!(srv.target, "sipserver.example.com"); 344 + } 345 + 346 + #[test] 347 + fn parse_response_back_forth() { 348 + let data = [ 349 + 0, 0, // Transaction ID 350 + 132, 0, // Response, Authoritative Answer, No Recursion 351 + 0, 0, // 0 questions 352 + 0, 4, // 4 answers 353 + 0, 0, // 0 authority RRs 354 + 0, 0, // 0 additional RRs 355 + // _midiriff 356 + 9, 95, 109, 105, 100, 105, 114, 105, 102, 102, // 357 + // _udp 358 + 4, 95, 117, 100, 112, // 359 + // local 360 + 5, 108, 111, 99, 97, 108, // 361 + 0, // <end> 362 + // 363 + 0, 12, // PTR 364 + 0, 1, // Class IN 365 + 0, 0, 0, 120, // TTL 120 seconds 366 + 0, 10, // Data Length 10 367 + // pi35291 368 + 7, 112, 105, 51, 53, 50, 57, 49, // 369 + // 370 + 192, 12, // Pointer to _midirif._udp._local. 371 + // 372 + 192, 44, // Pointer to instace name: pi35291._midirif._udp._local. 373 + 0, 33, // SRV 374 + 128, 1, // IN (Cache flush bit set) 375 + 0, 0, 0, 120, // TTL 120 seconds 376 + 0, 11, // Data Length 11 377 + 0, 0, // Priority 0 378 + 0, 0, // Weight 0 379 + 137, 219, // Port 35291 380 + 2, 112, 105, // _pi 381 + 192, 27, // Pointer to: .local. 382 + // TXT (Empty) 383 + 192, 44, 0, 16, 128, 1, 0, 0, 17, 148, 0, 1, 0, 384 + // A (10.1.1.9) 385 + 192, 72, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4, 10, 1, 1, 9, 386 + ]; 387 + 388 + let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); 389 + 390 + assert_eq!(response.answers[0].name, "_midiriff._udp.local"); 391 + assert_eq!(response.answers[0].ttl, 120); 392 + let Record::PTR(ptr) = &response.answers[0].record else { 393 + panic!() 394 + }; 395 + assert_eq!(ptr.name, "pi35291._midiriff._udp.local"); 396 + 397 + let mut buffer = [0u8; 256]; 398 + let mut buffer = Encoder::new(&mut buffer); 399 + response.serialize(&mut buffer).unwrap(); 400 + 401 + let buffer = buffer.finish(); 402 + 403 + let response2 = Response::parse(&mut &buffer[..], buffer).unwrap(); 404 + 405 + assert_eq!(response, response2); 406 + } 407 + 408 + #[test] 409 + fn mdns_service_response() { 410 + let mut response = Response { 411 + id: 0x1234, 412 + flags: Flags::standard_response(), 413 + queries: Vec::new(), 414 + answers: Vec::new(), 415 + }; 416 + 417 + let query = Query { 418 + name: Label::from("_test._udp.local"), 419 + qtype: QType::PTR, 420 + qclass: QClass::IN, 421 + }; 422 + response.queries.push(query); 423 + 424 + let ptr_answer = Answer { 425 + name: Label::from("_test._udp.local"), 426 + atype: QType::PTR, 427 + aclass: QClass::IN, 428 + ttl: 4500, 429 + record: Record::PTR(PTR { 430 + name: Label::from("test-service._test._udp.local"), 431 + }), 432 + }; 433 + response.answers.push(ptr_answer); 434 + 435 + let srv_answer = Answer { 436 + name: Label::from("test-service._test._udp.local"), 437 + atype: QType::SRV, 438 + aclass: QClass::IN, 439 + ttl: 120, 440 + record: Record::SRV(SRV { 441 + priority: 0, 442 + weight: 0, 443 + port: 8080, 444 + target: Label::from("host.local"), 445 + }), 446 + }; 447 + response.answers.push(srv_answer); 448 + 449 + let txt_answer = Answer { 450 + name: Label::from("test-service._test._udp.local"), 451 + atype: QType::TXT, 452 + aclass: QClass::IN, 453 + ttl: 120, 454 + record: Record::TXT(TXT { 455 + text: vec!["path=/test"], 456 + }), 457 + }; 458 + response.answers.push(txt_answer); 459 + 460 + let a_answer = Answer { 461 + name: Label::from("host.local"), 462 + atype: QType::A, 463 + aclass: QClass::IN, 464 + ttl: 120, 465 + record: Record::A(A { 466 + address: Ipv4Addr::new(192, 168, 1, 100), 467 + }), 468 + }; 469 + response.answers.push(a_answer); 470 + 471 + let mut buffer = [0u8; 256]; 472 + let mut buffer = Encoder::new(&mut buffer); 473 + response.serialize(&mut buffer).unwrap(); 474 + 475 + let buffer = buffer.finish(); 476 + 477 + let parsed_response = Response::parse(&mut &buffer[..], buffer).unwrap(); 478 + 479 + assert_eq!(response, parsed_response); 480 + } 481 + }
+21
sachy-mdns/src/dns/traits.rs
··· 1 + use winnow::ModalResult; 2 + 3 + use crate::encoder::Encoder; 4 + 5 + pub trait DnsParse<'a>: Sized { 6 + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self>; 7 + } 8 + 9 + pub trait DnsParseKind<'a> { 10 + type Output; 11 + 12 + fn parse_kind(&self, input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self::Output>; 13 + } 14 + 15 + pub trait DnsSerialize<'a> { 16 + type Error; 17 + 18 + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error>; 19 + #[allow(dead_code)] 20 + fn size(&self) -> usize; 21 + }
+145
sachy-mdns/src/encoder.rs
··· 1 + use core::ops::Range; 2 + 3 + use alloc::collections::BTreeMap; 4 + 5 + use crate::dns::traits::DnsSerialize; 6 + 7 + pub(crate) const MAX_STR_LEN: u8 = !PTR_MASK; 8 + pub(crate) const PTR_MASK: u8 = 0b1100_0000; 9 + 10 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 11 + #[cfg_attr(feature = "defmt", derive(defmt::Format))] 12 + pub enum DnsError { 13 + LabelTooLong, 14 + InvalidTxt, 15 + Unsupported, 16 + } 17 + 18 + impl core::fmt::Display for DnsError { 19 + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { 20 + match self { 21 + Self::LabelTooLong => f.write_str("Encoding Error: Segment too long"), 22 + Self::InvalidTxt => f.write_str("Encoding Error: TXT segment is invalid"), 23 + Self::Unsupported => f.write_str("Encoding Error: Unsupported Record Type"), 24 + } 25 + } 26 + } 27 + 28 + impl core::error::Error for DnsError {} 29 + 30 + #[derive(Debug)] 31 + pub struct Encoder<'a, 'b> { 32 + output: &'b mut [u8], 33 + position: usize, 34 + lookup: BTreeMap<&'a str, u16>, 35 + reservation: Option<usize>, 36 + } 37 + 38 + impl<'a, 'b> Encoder<'a, 'b> { 39 + pub const fn new(buffer: &'b mut [u8]) -> Self { 40 + Self { 41 + output: buffer, 42 + position: 0, 43 + lookup: BTreeMap::new(), 44 + reservation: None, 45 + } 46 + } 47 + 48 + /// Takes a payload and encodes it, consuming the encoder and yielding the resulting 49 + /// slice. 50 + pub fn encode<T, E>(mut self, payload: T) -> Result<&'b [u8], E> 51 + where 52 + E: core::error::Error, 53 + T: DnsSerialize<'a, Error = E>, 54 + { 55 + payload.serialize(&mut self)?; 56 + Ok(self.finish()) 57 + } 58 + 59 + pub(crate) fn finish(self) -> &'b [u8] { 60 + &self.output[..self.position] 61 + } 62 + 63 + fn increment(&mut self, amount: usize) { 64 + self.position += amount; 65 + } 66 + 67 + pub(crate) fn write_label(&mut self, mut label: &'a str) -> Result<(), DnsError> { 68 + loop { 69 + if let Some(pos) = self.get_label_position(label) { 70 + let [b1, b2] = u16::to_be_bytes(pos); 71 + self.write(&[b1 | PTR_MASK, b2]); 72 + return Ok(()); 73 + } 74 + 75 + let dot = label.find('.'); 76 + 77 + let end = dot.unwrap_or(label.len()); 78 + let segment = &label[..end]; 79 + let len = u8::try_from(segment.len()).map_err(|_| DnsError::LabelTooLong)?; 80 + 81 + if len > MAX_STR_LEN { 82 + return Err(DnsError::LabelTooLong); 83 + } 84 + 85 + self.store_label_position(label); 86 + self.write(&len.to_be_bytes()); 87 + self.write(segment.as_bytes()); 88 + 89 + match dot { 90 + Some(end) => { 91 + label = &label[end + 1..]; 92 + } 93 + None => { 94 + self.write(&[0]); 95 + return Ok(()); 96 + } 97 + } 98 + } 99 + } 100 + 101 + pub(crate) fn write(&mut self, bytes: &[u8]) { 102 + let len = bytes.len(); 103 + let end = self.position + len; 104 + self.output[self.position..end].copy_from_slice(bytes); 105 + self.increment(len); 106 + } 107 + 108 + fn get_label_position(&mut self, label: &str) -> Option<u16> { 109 + self.lookup.get(label).copied() 110 + } 111 + 112 + fn store_label_position(&mut self, label: &'a str) { 113 + self.lookup.insert(label, self.position as u16); 114 + } 115 + 116 + fn reserve_record_length(&mut self) { 117 + if self.reservation.is_none() { 118 + self.reservation = Some(self.position); 119 + self.increment(2); 120 + } 121 + } 122 + 123 + fn distance_from_reservation(&mut self) -> Option<(Range<usize>, u16)> { 124 + self.reservation 125 + .take() 126 + .map(|start| (start..(start + 2), (self.position - start - 2) as u16)) 127 + } 128 + 129 + fn write_record_length(&mut self) { 130 + if let Some((reservation, len)) = self.distance_from_reservation() { 131 + self.output[reservation].copy_from_slice(&len.to_be_bytes()); 132 + } 133 + } 134 + 135 + pub(crate) fn with_record_length<E, F>(&mut self, encoding_scope: F) -> Result<(), E> 136 + where 137 + E: core::error::Error, 138 + F: FnOnce(&mut Self) -> Result<(), E>, 139 + { 140 + self.reserve_record_length(); 141 + encoding_scope(self)?; 142 + self.write_record_length(); 143 + Ok(()) 144 + } 145 + }
+60
sachy-mdns/src/lib.rs
··· 1 + #![no_std] 2 + 3 + mod dns; 4 + pub(crate) mod encoder; 5 + pub mod server; 6 + mod service; 7 + mod state; 8 + 9 + extern crate alloc; 10 + 11 + use core::net::{Ipv4Addr, SocketAddrV4}; 12 + 13 + pub use crate::service::Service; 14 + pub use crate::state::MdnsAction; 15 + use crate::{dns::flags::Flags, server::Server, state::MdnsStateMachine}; 16 + 17 + /// Standard port for mDNS (5353). 18 + pub const MDNS_PORT: u16 = 5353; 19 + 20 + /// Standard IPv4 multicast address for mDNS (224.0.0.251). 21 + pub const GROUP_ADDR_V4: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251); 22 + pub const GROUP_SOCK_V4: SocketAddrV4 = SocketAddrV4::new(GROUP_ADDR_V4, MDNS_PORT); 23 + 24 + #[derive(Debug)] 25 + #[cfg_attr(feature = "defmt", derive(defmt::Format))] 26 + pub struct MdnsService { 27 + server: Server, 28 + state: MdnsStateMachine, 29 + } 30 + 31 + impl MdnsService { 32 + pub fn new(service: Service) -> Self { 33 + Self { 34 + server: Server::new(service), 35 + state: Default::default(), 36 + } 37 + } 38 + 39 + pub fn next_action(&mut self) -> MdnsAction { 40 + self.state.drive_next_action() 41 + } 42 + 43 + pub fn send_announcement<'buffer>(&self, outgoing: &'buffer mut [u8]) -> Option<&'buffer [u8]> { 44 + self.server.broadcast( 45 + server::ResponseKind::Announcement, 46 + Flags::standard_response(), 47 + 1, 48 + alloc::vec::Vec::new(), 49 + outgoing, 50 + ) 51 + } 52 + 53 + pub fn listen_for_queries<'buffer>( 54 + &mut self, 55 + incoming: &[u8], 56 + outgoing: &'buffer mut [u8], 57 + ) -> Option<&'buffer [u8]> { 58 + self.server.respond(incoming, outgoing) 59 + } 60 + }
+95
sachy-mdns/src/server.rs
··· 1 + use alloc::vec::Vec; 2 + use sachy_fmt::{error, info}; 3 + 4 + use crate::{ 5 + dns::{ 6 + flags::Flags, 7 + query::{QClass, Query}, 8 + records::QType, 9 + reqres::{Request, Response}, 10 + traits::DnsParse, 11 + }, 12 + encoder::Encoder, 13 + service::Service, 14 + }; 15 + 16 + pub(crate) enum ResponseKind { 17 + Announcement, 18 + QueryResponse(Vec<(QType, QClass)>), 19 + } 20 + 21 + #[derive(Debug)] 22 + #[cfg_attr(feature = "defmt", derive(defmt::Format))] 23 + pub(crate) struct Server { 24 + service: Service, 25 + } 26 + 27 + impl Server { 28 + pub(crate) fn new(service: Service) -> Self { 29 + Self { service } 30 + } 31 + 32 + pub(crate) fn broadcast<'a, 'b>( 33 + &self, 34 + response_kind: ResponseKind, 35 + flags: Flags, 36 + id: u16, 37 + queries: Vec<Query<'b>>, 38 + outgoing: &'a mut [u8], 39 + ) -> Option<&'a [u8]> { 40 + let answers: Vec<_> = match response_kind { 41 + ResponseKind::Announcement => self.service.as_answers(QClass::Multicast).collect(), 42 + ResponseKind::QueryResponse(valid) => valid 43 + .iter() 44 + .flat_map(|&(qtype, qclass)| match qtype { 45 + QType::A | QType::AAAA => self.service.ip_answer(qclass), 46 + QType::PTR => self.service.ptr_answer(qclass), 47 + QType::TXT => self.service.txt_answer(qclass), 48 + QType::SRV => self.service.srv_answer(qclass), 49 + QType::Any | QType::Unknown(_) => None, 50 + }) 51 + .collect(), 52 + }; 53 + 54 + if !answers.is_empty() { 55 + let res = Response { 56 + flags, 57 + id, 58 + queries, 59 + answers, 60 + }; 61 + 62 + info!("MDNS RESPONSE: {}", res); 63 + 64 + return Encoder::new(outgoing) 65 + .encode(res) 66 + .inspect_err(|err| error!("Encoder errored: {}", err)) 67 + .ok(); 68 + } 69 + 70 + None 71 + } 72 + 73 + pub(crate) fn respond<'a>(&self, incoming: &[u8], outgoing: &'a mut [u8]) -> Option<&'a [u8]> { 74 + Request::parse(&mut &incoming[..], incoming) 75 + .ok() 76 + .and_then(|req| { 77 + let valid_queries = req 78 + .queries 79 + .iter() 80 + .filter_map(|q| match q.qtype { 81 + QType::A | QType::AAAA | QType::TXT | QType::SRV => { 82 + (q.name == self.service.hostname()).then_some((q.qtype, q.qclass)) 83 + } 84 + QType::PTR => (q.name == self.service.service_type()).then_some((q.qtype, q.qclass)), 85 + QType::Any | QType::Unknown(_) => None, 86 + }).collect::<Vec<_>>(); 87 + 88 + if !valid_queries.is_empty() { 89 + self.broadcast(ResponseKind::QueryResponse(valid_queries), req.flags, req.id, req.queries, outgoing) 90 + } else { 91 + None 92 + } 93 + }) 94 + } 95 + }
+193
sachy-mdns/src/service.rs
··· 1 + use core::net::IpAddr; 2 + 3 + use alloc::{ 4 + string::{String, ToString}, 5 + vec::Vec, 6 + }; 7 + 8 + use crate::dns::{ 9 + label::Label, 10 + query::{Answer, QClass}, 11 + records::{A, AAAA, PTR, QType, Record, SRV, TXT}, 12 + }; 13 + 14 + #[derive(Debug, Default)] 15 + #[cfg_attr(feature = "defmt", derive(defmt::Format))] 16 + pub struct Service { 17 + service_type: String, 18 + instance: String, 19 + hostname: String, 20 + ip: Option<IpAddr>, 21 + port: u16, 22 + } 23 + 24 + impl Service { 25 + pub fn new( 26 + service_type: impl Into<String>, 27 + instance: impl Into<String>, 28 + hostname: impl Into<String>, 29 + ip: Option<IpAddr>, 30 + port: u16, 31 + ) -> Self { 32 + let service_type = service_type.into(); 33 + let mut instance = instance.into(); 34 + let mut hostname = hostname.into(); 35 + 36 + instance.push('.'); 37 + instance.push_str(&service_type); 38 + hostname.push_str(".local"); 39 + 40 + Self { 41 + service_type, 42 + instance, 43 + hostname, 44 + ip, 45 + port, 46 + } 47 + } 48 + 49 + pub fn service_type(&self) -> Label<'_> { 50 + Label::from(self.service_type.as_ref()) 51 + } 52 + 53 + pub fn instance(&self) -> Label<'_> { 54 + Label::from(self.instance.as_ref()) 55 + } 56 + 57 + pub fn hostname(&self) -> Label<'_> { 58 + Label::from(self.hostname.as_ref()) 59 + } 60 + 61 + pub fn ip(&self) -> Option<IpAddr> { 62 + self.ip 63 + } 64 + 65 + pub fn port(&self) -> u16 { 66 + self.port 67 + } 68 + 69 + pub(crate) fn ptr_answer(&self, aclass: QClass) -> Option<Answer<'_>> { 70 + Some(Answer { 71 + name: self.service_type(), 72 + atype: QType::PTR, 73 + aclass, 74 + ttl: 4500, 75 + record: Record::PTR(PTR { 76 + name: self.instance(), 77 + }), 78 + }) 79 + } 80 + 81 + pub(crate) fn srv_answer(&self, aclass: QClass) -> Option<Answer<'_>> { 82 + Some(Answer { 83 + name: self.instance(), 84 + atype: QType::SRV, 85 + aclass, 86 + ttl: 120, 87 + record: Record::SRV(SRV { 88 + priority: 0, 89 + weight: 0, 90 + port: self.port, 91 + target: self.hostname(), 92 + }), 93 + }) 94 + } 95 + 96 + pub(crate) fn txt_answer(&self, aclass: QClass) -> Option<Answer<'_>> { 97 + Some(Answer { 98 + name: self.instance(), 99 + atype: QType::TXT, 100 + aclass, 101 + ttl: 120, 102 + record: Record::TXT(TXT { text: Vec::new() }), 103 + }) 104 + } 105 + 106 + pub(crate) fn ip_answer(&self, aclass: QClass) -> Option<Answer<'_>> { 107 + self.ip().map(|address| match address { 108 + IpAddr::V4(address) => Answer { 109 + name: self.hostname(), 110 + atype: QType::A, 111 + aclass, 112 + ttl: 120, 113 + record: Record::A(A { address }), 114 + }, 115 + IpAddr::V6(address) => Answer { 116 + name: self.hostname(), 117 + atype: QType::AAAA, 118 + aclass, 119 + ttl: 120, 120 + record: Record::AAAA(AAAA { address }), 121 + }, 122 + }) 123 + } 124 + 125 + pub(crate) fn as_answers(&self, aclass: QClass) -> impl Iterator<Item = Answer<'_>> { 126 + self.ptr_answer(aclass) 127 + .into_iter() 128 + .chain(self.srv_answer(aclass)) 129 + .chain(self.txt_answer(aclass)) 130 + .chain(self.ip_answer(aclass)) 131 + } 132 + 133 + #[allow(dead_code)] 134 + pub(crate) fn from_answers(answers: &[Answer<'_>]) -> Vec<Self> { 135 + let mut output = Vec::new(); 136 + 137 + // Step 1: Process PTR records 138 + for answer in answers { 139 + if let Record::PTR(ptr) = &answer.record { 140 + let instance = ptr.name.to_string(); 141 + let service_type = answer.name.to_string(); 142 + output.push(Self { 143 + service_type, 144 + instance, 145 + ..Default::default() 146 + }); 147 + } 148 + } 149 + 150 + // Step 2: Process SRV records, A and AAAA records and merge data 151 + for answer in answers { 152 + match &answer.record { 153 + Record::SRV(srv) => { 154 + if let Some(stub) = output 155 + .iter_mut() 156 + .find(|stub| answer.name == stub.instance.as_ref()) 157 + { 158 + stub.hostname = srv.target.to_string(); 159 + stub.port = srv.port; 160 + } 161 + } 162 + Record::A(a) => { 163 + if let Some(stub) = output 164 + .iter_mut() 165 + .find(|stub| answer.name == stub.hostname.as_ref()) 166 + { 167 + stub.ip = Some(IpAddr::V4(a.address)); 168 + } 169 + } 170 + Record::AAAA(aaaa) => { 171 + if let Some(stub) = output 172 + .iter_mut() 173 + .find(|stub| answer.name == stub.hostname.as_ref()) 174 + { 175 + stub.ip = Some(IpAddr::V6(aaaa.address)); 176 + } 177 + } 178 + _ => {} 179 + } 180 + } 181 + 182 + // Final step: Retain only complete services 183 + output.retain(|stub| { 184 + !stub.service_type.is_empty() 185 + && !stub.instance.is_empty() 186 + && !stub.hostname.is_empty() 187 + && stub.ip.is_some() 188 + && stub.port != 0 189 + }); 190 + 191 + output 192 + } 193 + }
+100
sachy-mdns/src/state.rs
··· 1 + use embassy_time::{Duration, Instant}; 2 + use sachy_fmt::{debug, unwrap}; 3 + 4 + #[derive(Debug, Default)] 5 + #[cfg_attr(feature = "defmt", derive(defmt::Format))] 6 + pub(crate) enum MdnsStateMachine { 7 + #[default] 8 + Start, 9 + Announce { 10 + last_sent: Instant, 11 + }, 12 + ListenFor { 13 + last_sent: Instant, 14 + timeout: Duration, 15 + }, 16 + WaitFor { 17 + last_sent: Instant, 18 + duration: Duration, 19 + }, 20 + } 21 + 22 + impl MdnsStateMachine { 23 + /// Set the state to announced, if we have timed out the listening period and need 24 + /// to announce, or if we received a query while listening and have sent a response. 25 + pub(crate) fn announced(&mut self) { 26 + *self = Self::Announce { 27 + last_sent: Instant::now(), 28 + }; 29 + } 30 + 31 + fn next_state(&mut self) { 32 + match self { 33 + Self::Start => self.announced(), 34 + &mut Self::Announce { last_sent } => { 35 + let duration_since = last_sent.elapsed(); 36 + let duration = Duration::from_secs(1) - duration_since; 37 + 38 + *self = Self::WaitFor { 39 + last_sent, 40 + duration, 41 + }; 42 + } 43 + &mut Self::ListenFor { last_sent, .. } => { 44 + let duration_since = last_sent.elapsed(); 45 + let time_limit = Duration::from_secs(120); 46 + 47 + if duration_since >= time_limit { 48 + self.announced(); 49 + } else { 50 + let timeout = time_limit - duration_since; 51 + *self = Self::ListenFor { last_sent, timeout }; 52 + } 53 + } 54 + &mut Self::WaitFor { last_sent, .. } => { 55 + let duration_since = last_sent.elapsed(); 56 + let time_limit = Duration::from_secs(120); 57 + let timeout = time_limit - duration_since; 58 + 59 + *self = Self::ListenFor { last_sent, timeout }; 60 + } 61 + } 62 + } 63 + 64 + pub(crate) fn drive_next_action(&mut self) -> MdnsAction { 65 + self.next_state(); 66 + unwrap!(MdnsAction::try_from(self)) 67 + } 68 + } 69 + 70 + #[derive(Debug)] 71 + #[cfg_attr(feature = "defmt", derive(defmt::Format))] 72 + pub enum MdnsAction { 73 + Announce, 74 + ListenFor { timeout: Duration }, 75 + WaitFor { duration: Duration }, 76 + } 77 + 78 + impl TryFrom<&mut MdnsStateMachine> for MdnsAction { 79 + type Error = MdnsStateMachine; 80 + 81 + fn try_from(value: &mut MdnsStateMachine) -> Result<Self, Self::Error> { 82 + match value { 83 + // We should start in this state, but never remain nor return to it when 84 + // executing our state machine event loop. 85 + MdnsStateMachine::Start => Err(MdnsStateMachine::Start), 86 + MdnsStateMachine::Announce { .. } => { 87 + debug!("ANNOUNCE"); 88 + Ok(Self::Announce) 89 + } 90 + &mut MdnsStateMachine::ListenFor { timeout, .. } => { 91 + debug!("LISTEN FOR {}ms", timeout.as_millis()); 92 + Ok(Self::ListenFor { timeout }) 93 + } 94 + &mut MdnsStateMachine::WaitFor { duration, .. } => { 95 + debug!("WAIT FOR {}ms", duration.as_millis()); 96 + Ok(Self::WaitFor { duration }) 97 + } 98 + } 99 + } 100 + }