From 311122715d2fad4b2a89549cde0cdcc7fe62ae1f Mon Sep 17 00:00:00 2001 From: Sachymetsu Date: Mon, 15 Dec 2025 17:46:29 +0100 Subject: [PATCH] feat: WIP MDNS crate Change-Id: kqmrywxprxtnrwrpspzvrysrvvvwvpzp Adds a mdns resolver/state-machine crate for providing MDNS-SD functionality for an embedded device. --- .tangled/workflows/miri.yml | 21 ++ Cargo.lock | 22 +- Cargo.toml | 1 + sachy-mdns/Cargo.toml | 22 ++ sachy-mdns/src/dns.rs | 6 + sachy-mdns/src/dns/flags.rs | 231 +++++++++++++++ sachy-mdns/src/dns/label.rs | 511 ++++++++++++++++++++++++++++++++++ sachy-mdns/src/dns/query.rs | 228 +++++++++++++++ sachy-mdns/src/dns/records.rs | 413 +++++++++++++++++++++++++++ sachy-mdns/src/dns/reqres.rs | 481 ++++++++++++++++++++++++++++++++ sachy-mdns/src/dns/traits.rs | 21 ++ sachy-mdns/src/encoder.rs | 145 ++++++++++ sachy-mdns/src/lib.rs | 60 ++++ sachy-mdns/src/server.rs | 95 +++++++ sachy-mdns/src/service.rs | 193 +++++++++++++ sachy-mdns/src/state.rs | 100 +++++++ 16 files changed, 2547 insertions(+), 3 deletions(-) create mode 100644 .tangled/workflows/miri.yml create mode 100644 sachy-mdns/Cargo.toml create mode 100644 sachy-mdns/src/dns.rs create mode 100644 sachy-mdns/src/dns/flags.rs create mode 100644 sachy-mdns/src/dns/label.rs create mode 100644 sachy-mdns/src/dns/query.rs create mode 100644 sachy-mdns/src/dns/records.rs create mode 100644 sachy-mdns/src/dns/reqres.rs create mode 100644 sachy-mdns/src/dns/traits.rs create mode 100644 sachy-mdns/src/encoder.rs create mode 100644 sachy-mdns/src/lib.rs create mode 100644 sachy-mdns/src/server.rs create mode 100644 sachy-mdns/src/service.rs create mode 100644 sachy-mdns/src/state.rs diff --git a/.tangled/workflows/miri.yml b/.tangled/workflows/miri.yml new file mode 100644 index 0000000..2f18cd6 --- /dev/null +++ b/.tangled/workflows/miri.yml @@ -0,0 +1,21 @@ +when: + - event: ["push", "pull_request"] + branch: main + +engine: nixery + +dependencies: + nixpkgs: + - clang + - rustup + +steps: + - name: Install Nightly + command: | + rustup toolchain install nightly --component miri + rustup override set nightly + cargo miri setup + - name: Miri Test + command: cargo miri test --locked -p sachy-mdns + environment: + RUSTFLAGS: -Zrandomize-layout diff --git a/Cargo.lock b/Cargo.lock index 7dee9c0..2dad946 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -420,7 +420,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb" dependencies = [ "libc", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -1065,7 +1065,7 @@ dependencies = [ "errno", "libc", "linux-raw-sys 0.11.0", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -1119,6 +1119,16 @@ dependencies = [ name = "sachy-fnv" version = "0.1.0" +[[package]] +name = "sachy-mdns" +version = "0.1.0" +dependencies = [ + "defmt 1.0.1", + "embassy-time", + "sachy-fmt", + "winnow", +] + [[package]] name = "sachy-shtc3" version = "0.1.0" @@ -1292,7 +1302,7 @@ dependencies = [ "getrandom", "once_cell", "rustix 1.1.2", - "windows-sys 0.52.0", + "windows-sys 0.61.2", ] [[package]] @@ -1483,6 +1493,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "winnow" +version = "0.7.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829" + [[package]] name = "wit-bindgen" version = "0.46.0" diff --git a/Cargo.toml b/Cargo.toml index 2ab4e4b..7f114a4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,6 +6,7 @@ members = [ "sachy-esphome", "sachy-fmt", "sachy-fnv", + "sachy-mdns", "sachy-shtc3", "sachy-sntp", ] diff --git a/sachy-mdns/Cargo.toml b/sachy-mdns/Cargo.toml new file mode 100644 index 0000000..f0ee517 --- /dev/null +++ b/sachy-mdns/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "sachy-mdns" +authors.workspace = true +edition.workspace = true +repository.workspace = true +license.workspace = true +version.workspace = true +rust-version.workspace = true + +[dependencies] +defmt = { workspace = true, optional = true, features = ["alloc"] } +embassy-time = { workspace = true } +sachy-fmt = { path = "../sachy-fmt" } +winnow = { version = "0.7.12", default-features = false } + +[features] +default = [] +std = [] +defmt = ["dep:defmt"] + +[dev-dependencies] +winnow = { version = "0.7.12", default-features = false, features = ["alloc"] } diff --git a/sachy-mdns/src/dns.rs b/sachy-mdns/src/dns.rs new file mode 100644 index 0000000..089dc8a --- /dev/null +++ b/sachy-mdns/src/dns.rs @@ -0,0 +1,6 @@ +pub(crate) mod flags; +pub(crate) mod label; +pub(crate) mod query; +pub(crate) mod records; +pub(crate) mod reqres; +pub mod traits; diff --git a/sachy-mdns/src/dns/flags.rs b/sachy-mdns/src/dns/flags.rs new file mode 100644 index 0000000..c233b99 --- /dev/null +++ b/sachy-mdns/src/dns/flags.rs @@ -0,0 +1,231 @@ +#![allow(dead_code)] + +use core::{convert::Infallible, fmt}; +use winnow::{ModalResult, Parser, binary::be_u16}; + +use crate::{ + dns::traits::{DnsParse, DnsSerialize}, + encoder::Encoder, +}; + +#[derive(Default, Clone, Copy, PartialEq, Eq)] +pub struct Flags(pub u16); + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u8)] +pub enum Opcode { + Query = 0, + IQuery = 1, + Status = 2, + Reserved = 3, + Notify = 4, + Update = 5, + // Other values are reserved +} + +impl Opcode { + const fn cast(value: u8) -> Self { + match value { + 0 => Opcode::Query, + 1 => Opcode::IQuery, + 2 => Opcode::Status, + 4 => Opcode::Notify, + 5 => Opcode::Update, + _ => Opcode::Reserved, + } + } +} + +impl From for Opcode { + fn from(value: u8) -> Self { + Self::cast(value) + } +} + +impl From for u8 { + fn from(opcode: Opcode) -> Self { + opcode as u8 + } +} + +impl Flags { + const fn new() -> Self { + Flags(0) + } + + pub const fn standard_request() -> Self { + let mut flags = Flags::new(); + flags.set_query(true); + flags.set_opcode(Opcode::Query); + flags.set_recursion_desired(true); + flags + } + + pub const fn standard_response() -> Self { + let mut flags = Flags::new(); + flags.set_query(false); + flags.set_opcode(Opcode::Query); + flags.set_authoritative(true); + flags.set_recursion_available(false); + flags + } + + // QR: Query/Response Flag + pub const fn is_query(&self) -> bool { + (self.0 & 0x8000) == 0 + } + + pub const fn set_query(&mut self, is_query: bool) { + if is_query { + self.0 &= !0x8000; + } else { + self.0 |= 0x8000; + } + } + + // Opcode (bits 1-4) + pub const fn get_opcode(&self) -> Opcode { + Opcode::cast(((self.0 >> 11) & 0x0F) as u8) + } + + pub const fn set_opcode(&mut self, opcode: Opcode) { + self.0 = (self.0 & !0x7800) | (((opcode as u8) as u16 & 0x0F) << 11); + } + + // AA: Authoritative Answer + pub const fn is_authoritative(&self) -> bool { + (self.0 & 0x0400) != 0 + } + + pub const fn set_authoritative(&mut self, authoritative: bool) { + if authoritative { + self.0 |= 0x0400; + } else { + self.0 &= !0x0400; + } + } + + // TC: Truncated + pub const fn is_truncated(&self) -> bool { + (self.0 & 0x0200) != 0 + } + + pub const fn set_truncated(&mut self, truncated: bool) { + if truncated { + self.0 |= 0x0200; + } else { + self.0 &= !0x0200; + } + } + + // RD: Recursion Desired + pub const fn is_recursion_desired(&self) -> bool { + (self.0 & 0x0100) != 0 + } + + pub const fn set_recursion_desired(&mut self, recursion_desired: bool) { + if recursion_desired { + self.0 |= 0x0100; + } else { + self.0 &= !0x0100; + } + } + + // RA: Recursion Available + pub const fn is_recursion_available(&self) -> bool { + (self.0 & 0x0080) != 0 + } + + pub const fn set_recursion_available(&mut self, recursion_available: bool) { + if recursion_available { + self.0 |= 0x0080; + } else { + self.0 &= !0x0080; + } + } + + // Z: Reserved for future use (bits 9-11) + pub const fn get_reserved(&self) -> u8 { + ((self.0 >> 4) & 0x07) as u8 + } + + pub const fn set_reserved(&mut self, reserved: u8) { + self.0 = (self.0 & !0x0070) | ((reserved as u16 & 0x07) << 4); + } + + // RCODE: Response Code (bits 12-15) + pub const fn get_rcode(&self) -> u8 { + (self.0 & 0x000F) as u8 + } + + pub const fn set_rcode(&mut self, rcode: u8) { + self.0 = (self.0 & !0x000F) | (rcode as u16 & 0x0F); + } +} + +impl<'a> DnsParse<'a> for Flags { + fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult { + be_u16.map(Flags).parse_next(input) + } +} + +impl<'a> DnsSerialize<'a> for Flags { + type Error = Infallible; + + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { + encoder.write(&self.0.to_be_bytes()); + Ok(()) + } + + fn size(&self) -> usize { + core::mem::size_of::() + } +} + +impl fmt::Debug for Flags { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("Flags") + .field("query", &self.is_query()) + .field("opcode", &self.get_opcode()) + .field("authoritative", &self.is_authoritative()) + .field("truncated", &self.is_truncated()) + .field("recursion_desired", &self.is_recursion_desired()) + .field("recursion_available", &self.is_recursion_available()) + .field("reserved", &self.get_reserved()) + .field("rcode", &self.get_rcode()) + .finish() + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Flags { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!( + fmt, + "Flags {{ query: {}, opcode: {:?}, authoritative: {}, truncated: {}, recursion_desired: {}, recursion_available: {}, reserved: {}, rcode: {} }}", + self.is_query(), + self.get_opcode(), + self.is_authoritative(), + self.is_truncated(), + self.is_recursion_desired(), + self.is_recursion_available(), + self.get_reserved(), + self.get_rcode() + ); + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Opcode { + fn format(&self, fmt: defmt::Formatter) { + let opcode_str = match self { + Opcode::Query => "Query", + Opcode::IQuery => "IQuery", + Opcode::Status => "Status", + Opcode::Reserved => "Reserved", + Opcode::Notify => "Notify", + Opcode::Update => "Update", + }; + defmt::write!(fmt, "Opcode({=str})", opcode_str); + } +} diff --git a/sachy-mdns/src/dns/label.rs b/sachy-mdns/src/dns/label.rs new file mode 100644 index 0000000..a772cb1 --- /dev/null +++ b/sachy-mdns/src/dns/label.rs @@ -0,0 +1,511 @@ +use core::{fmt, str}; + +use winnow::{ + ModalResult, Parser, + binary::be_u8, + error::{ContextError, ErrMode, FromExternalError}, + stream::Offset, + token::take, +}; + +use crate::{ + dns::traits::{DnsParse, DnsSerialize}, + encoder::{DnsError, Encoder, MAX_STR_LEN, PTR_MASK}, +}; + +#[derive(Clone, Copy)] +pub struct Label<'a> { + repr: LabelRepr<'a>, +} + +impl<'a> From<&'a str> for Label<'a> { + fn from(value: &'a str) -> Self { + Self { + repr: LabelRepr::Str(value), + } + } +} + +impl<'a> DnsSerialize<'a> for Label<'a> { + type Error = DnsError; + + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { + match self.repr { + LabelRepr::Bytes { + context, + start, + end, + } => { + encoder.write(&context[start..end]); + + Ok(()) + } + LabelRepr::Str(label) => encoder.write_label(label), + } + } + + fn size(&self) -> usize { + match self.repr { + LabelRepr::Bytes { + context, + start, + end, + } => core::mem::size_of_val(&context[start..end]), + LabelRepr::Str(label) => core::mem::size_of_val(label) + 1, + } + } +} + +impl<'a> DnsParse<'a> for Label<'a> { + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult { + let start = input.offset_from(&context); + let mut end = start; + + loop { + match LabelSegment::parse(input)? { + LabelSegment::Empty => { + end += 1; + break; + } + LabelSegment::String(label) => { + end += 1 + label.len(); + } + LabelSegment::Pointer(_) => { + end += 2; + break; + } + } + } + + Ok(Self { + repr: LabelRepr::Bytes { + context, + start, + end, + }, + }) + } +} + +impl Label<'_> { + pub fn segments(&self) -> impl Iterator> { + self.repr.iter() + } + + pub fn names(&self) -> impl Iterator { + match self.repr { + LabelRepr::Str(view) => Either::A(view.split('.')), + LabelRepr::Bytes { context, start, .. } => Either::B( + LabelSegmentBytesIter::new(context, start).flat_map(|label| label.as_str()), + ), + } + } + + pub fn is_empty(&self) -> bool { + self.repr.iter().next().is_none() + } +} + +#[derive(Clone, Copy, PartialEq, Eq)] +enum LabelRepr<'a> { + Bytes { + context: &'a [u8], + start: usize, + end: usize, + }, + Str(&'a str), +} + +/// A DNS-compatible label segment. +#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)] +pub enum LabelSegment<'a> { + /// The empty terminator. + Empty, + + /// A string label. + String(&'a str), + + /// A pointer to a previous name. + Pointer(u16), +} + +impl<'a> LabelSegment<'a> { + fn parse(input: &mut &'a [u8]) -> ModalResult { + let b1 = be_u8(input)?; + + match b1 { + 0 => Ok(Self::Empty), + b1 if b1 & PTR_MASK == PTR_MASK => { + let b2 = be_u8(input)?; + + let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]); + + Ok(Self::Pointer(ptr)) + } + len => { + if len > MAX_STR_LEN { + return Err(ErrMode::Cut(ContextError::from_external_error( + input, DnsError::LabelTooLong, + ))); + } + + let segment = take(len).try_map(core::str::from_utf8).parse_next(input)?; + + Ok(Self::String(segment)) + } + } + } + + /// ## Safety + /// The caller upholds that this function is not called when parsing from newly received data. Data that + /// has yet to be determined to be a valid [`Label`] should be parsed and validated with [`Label::parse`], + /// and that the entire data/context has been validated, not just a portion of it. + #[inline] + unsafe fn parse_unchecked(input: &'a [u8]) -> Option { + input.split_first().map(|(b1, input)| match *b1 { + 0 => Self::Empty, + b1 if b1 & PTR_MASK == PTR_MASK => { + // SAFETY: The caller has already validated that a second byte is available for a + // Pointer segment. + let b2 = unsafe { *input.get_unchecked(0) }; + + let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]); + + Self::Pointer(ptr) + } + len => { + // SAFETY: The caller has validated that this length value is correct and will only + // access within the bounds of the provided slice. + let segment = unsafe { input.get_unchecked(0..(len as usize)) }; + // SAFETY: The caller has upheld the validity of the bytes as valid UTF-8 once before. + let segment = unsafe { core::str::from_utf8_unchecked(segment) }; + + Self::String(segment) + } + }) + } + + fn as_str(&self) -> Option<&'a str> { + match self { + Self::String(label) => Some(*label), + _ => None, + } + } +} + +pub struct LabelSegmentBytesIter<'a> { + context: &'a [u8], + start: usize, +} + +impl<'a> LabelSegmentBytesIter<'a> { + pub(crate) fn new(context: &'a [u8], start: usize) -> Self { + Self { context, start } + } +} + +impl<'a> Iterator for LabelSegmentBytesIter<'a> { + type Item = LabelSegment<'a>; + + fn next(&mut self) -> Option { + loop { + let view = &self.context[self.start..]; + + // SAFETY: The segment has already been validated, so they should be all valid variants and UTF-8 bytes + let segment = unsafe { LabelSegment::parse_unchecked(view)? }; + + match segment { + LabelSegment::String(label) => { + self.start = self.start.saturating_add(label.len() + 1); + return Some(segment); + } + LabelSegment::Pointer(ptr) => { + self.start = ptr as usize; + } + LabelSegment::Empty => { + // Set the index offset to be len() so that the view is empty and terminates the loop + self.start = self.context.len(); + return Some(LabelSegment::Empty); + } + } + } + } +} + +impl<'a> LabelRepr<'a> { + fn iter(&self) -> impl Iterator> { + match *self { + LabelRepr::Bytes { context, start, .. } => { + Either::A(LabelSegmentBytesIter::new(context, start)) + } + LabelRepr::Str(view) => Either::B( + view.split('.') + .map(LabelSegment::String) + .chain(Some(LabelSegment::Empty)), + ), + } + } +} + +impl fmt::Debug for Label<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + struct LabelFmt<'a>(&'a Label<'a>); + + impl fmt::Debug for LabelFmt<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + fmt::Display::fmt(self.0, f) + } + } + + f.debug_tuple("Label").field(&LabelFmt(self)).finish() + } +} + +impl fmt::Display for Label<'_> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let mut names = self.names(); + + if let Some(name) = names.next() { + f.write_str(name)?; + + names.try_for_each(|name| { + f.write_str(".")?; + f.write_str(name) + }) + } else { + Ok(()) + } + } +} + +impl<'a, 'b> PartialEq> for Label<'b> { + fn eq(&self, other: &Label<'a>) -> bool { + self.segments().eq(other.segments()) + } +} + +impl Eq for Label<'_> {} + +impl PartialEq<&str> for Label<'_> { + fn eq(&self, other: &&str) -> bool { + let mut self_iter = self.names(); + let mut other_iter = other.split('.'); + + loop { + match (self_iter.next(), other_iter.next()) { + (Some(self_part), Some(other_part)) => { + if self_part != other_part { + return false; + } + } + (None, None) => return true, + _ => return false, + } + } + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Label<'_> { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "Label("); + let mut iter = self.names(); + if let Some(first) = iter.next() { + defmt::write!(fmt, "{}", first); + + iter.for_each(|part| defmt::write!(fmt, ".{}", part)); + } + defmt::write!(fmt, ")"); + } +} + +/// One iterator or another. +enum Either { + A(A), + B(B), +} + +impl> Iterator for Either { + type Item = A::Item; + + fn next(&mut self) -> Option { + match self { + Either::A(a) => a.next(), + Either::B(b) => b.next(), + } + } + + fn size_hint(&self) -> (usize, Option) { + match self { + Either::A(a) => a.size_hint(), + Either::B(b) => b.size_hint(), + } + } + + fn fold(self, init: B, f: F) -> B + where + Self: Sized, + F: FnMut(B, Self::Item) -> B, + { + match self { + Either::A(a) => a.fold(init, f), + Either::B(b) => b.fold(init, f), + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn segments_iter_test() { + let label: Label<'static> = Label::from("_service._udp.local"); + let mut segments = label.segments(); + + assert_eq!(segments.next(), Some(LabelSegment::String("_service"))); + assert_eq!(segments.next(), Some(LabelSegment::String("_udp"))); + assert_eq!(segments.next(), Some(LabelSegment::String("local"))); + assert_eq!(segments.next(), Some(LabelSegment::Empty)); + assert_eq!(segments.next(), None); + + // example.com with a pointer to the start + let data = b"\x07example\x03com\x00\xC0\x00"; + let context = &data[..]; + // The data here is entirely valid, even though we parse only a portion of it. + let label = Label::parse(&mut &data[13..], context).unwrap(); + + let mut segments = label.segments(); + assert_eq!(segments.next(), Some(LabelSegment::String("example"))); + assert_eq!(segments.next(), Some(LabelSegment::String("com"))); + assert_eq!(segments.next(), Some(LabelSegment::Empty)); + assert_eq!(segments.next(), None); + } + + #[test] + fn names_iter_test() { + let label: Label<'static> = Label::from("_service._udp.local"); + let mut names = label.names(); + + assert_eq!(names.next(), Some("_service")); + assert_eq!(names.next(), Some("_udp")); + assert_eq!(names.next(), Some("local")); + assert_eq!(names.next(), None); + + let data = b"\x07example\x03com\x00\xC0\x00"; + let context = &data[..]; + // The data here is entirely valid, even though we parse only a portion of it. + let label = Label::parse(&mut &data[13..], context).unwrap(); + + let mut names = label.names(); + assert_eq!(names.next(), Some("example")); + assert_eq!(names.next(), Some("com")); + assert_eq!(names.next(), None); + } + + #[test] + fn serialize_str_label() { + let label: Label<'static> = Label::from("_service._udp.local"); + let mut buffer = [0u8; 256]; + let mut buffer = Encoder::new(&mut buffer); + label.serialize(&mut buffer).unwrap(); + assert_eq!(buffer.finish(), b"\x08_service\x04_udp\x05local\x00"); + } + + #[test] + fn serialize_compressed_str_label() { + let label: Label<'static> = Label::from("_service._udp.local"); + let label2: Label<'static> = Label::from("other._udp.local"); + let mut buffer = [0u8; 256]; + let mut buffer = Encoder::new(&mut buffer); + label.serialize(&mut buffer).unwrap(); + label2.serialize(&mut buffer).unwrap(); + assert_eq!( + buffer.finish(), + b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09" + ); + } + + #[test] + fn round_trip_compressed_str_label() { + let label: Label<'static> = Label::from("_service._udp.local"); + let label2: Label<'static> = Label::from("other._udp.local"); + let mut buffer = [0u8; 256]; + let mut buffer = Encoder::new(&mut buffer); + label.serialize(&mut buffer).unwrap(); + label2.serialize(&mut buffer).unwrap(); + let context = buffer.finish(); + assert_eq!( + context, + b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09" + ); + let view = &mut &context[..]; + + let parsed_label = Label::parse(view, context).unwrap(); + let parsed_label2 = Label::parse(view, context).unwrap(); + + let parsed_label_count = parsed_label.segments().count(); + let parsed_label2_count = parsed_label2.segments().count(); + + // Both have same amount of segments. + assert_eq!(parsed_label_count, parsed_label2_count); + + assert_eq!(parsed_label, label); + assert_eq!(parsed_label2, label2); + } + + #[test] + fn label_byte_repr_serialization_quick_path() { + let data = b"\x07example\x03com\x00\xC0\x00"; + let context = &data[..]; + // The data here is entirely valid, even though we parse only a portion of it. + let label = Label::parse(&mut &data[13..], context).unwrap(); + + let mut buffer = [0u8; 256]; + let mut buffer = Encoder::new(&mut buffer); + label.serialize(&mut buffer).unwrap(); + // If the original Label is just a pointer, the new output will be a pointer, assuming + // the original data is also present in the output + assert_eq!(buffer.finish(), b"\xC0\x00"); + } + + #[test] + fn parse_and_eq_created_label() { + let data = b"\x07example\x03com\x00\xC0\x00"; + let context = &data[..]; + + // The data here is entirely valid, even though we parse only a portion of it. + let parsed_label = Label::parse(&mut &data[13..], context).unwrap(); + + let created_label = Label::from("example.com"); + + assert_eq!(parsed_label, created_label); + } + + #[test] + fn parse_and_eq_label_with_str() { + let data = b"\x07example\x03com\x00"; + let context = &data[..]; + + let parsed_label = Label::parse(&mut &data[..], context).unwrap(); + + assert_eq!(parsed_label, "example.com"); + } + + #[test] + fn parse_ptr_label_and_eq_with_str() { + let data = b"\x07example\x03com\x00\xC0\x00"; + let context = &data[..]; + + // The data here is entirely valid, even though we parse only a portion of it. + let parsed_label = Label::parse(&mut &data[13..], context).unwrap(); + + assert_eq!(parsed_label, "example.com"); + } + + #[test] + fn label_new_without_dot_is_not_empty() { + let label: Label = Label::from("example"); + assert!(!label.is_empty()); + } +} diff --git a/sachy-mdns/src/dns/query.rs b/sachy-mdns/src/dns/query.rs new file mode 100644 index 0000000..1190586 --- /dev/null +++ b/sachy-mdns/src/dns/query.rs @@ -0,0 +1,228 @@ +use winnow::binary::{be_u16, be_u32}; +use winnow::{ModalResult, Parser}; + +use super::label::Label; +use super::records::Record; +use crate::encoder::Encoder; +use crate::{ + dns::{ + records::QType, + traits::{DnsParse, DnsParseKind, DnsSerialize}, + }, + encoder::DnsError, +}; + +#[derive(Debug, PartialEq, Eq)] +pub struct Query<'a> { + pub name: Label<'a>, + pub qtype: QType, + pub qclass: QClass, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u16)] +pub enum QClass { + IN = 1, + Multicast = 32769, // (IN + Cache flush bit) + Unknown(u16), +} + +impl<'a> DnsParse<'a> for Query<'a> { + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult { + let name = Label::parse(input, context)?; + let qtype = QType::parse(input, context)?; + let qclass = be_u16.map(QClass::from_u16).parse_next(input)?; + + Ok(Query { + name, + qtype, + qclass, + }) + } +} + +impl<'a> DnsSerialize<'a> for Query<'a> { + type Error = DnsError; + + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { + self.name.serialize(encoder)?; + self.qtype.serialize(encoder).ok(); + encoder.write(&self.qclass.to_u16().to_be_bytes()); + Ok(()) + } + + fn size(&self) -> usize { + self.name.size() + self.qtype.size() + core::mem::size_of::() + } +} + +#[derive(Debug, PartialEq, Eq)] +pub struct Answer<'a> { + pub name: Label<'a>, + pub atype: QType, + pub aclass: QClass, + pub ttl: u32, + pub record: Record<'a>, +} + +impl QClass { + fn from_u16(value: u16) -> Self { + match value { + 1 => QClass::IN, + 32769 => QClass::Multicast, + _ => QClass::Unknown(value), + } + } + + fn to_u16(self) -> u16 { + match self { + QClass::IN => 1, + QClass::Multicast => 32769, + QClass::Unknown(value) => value, + } + } +} + +impl<'a> DnsParse<'a> for Answer<'a> { + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult { + let name = Label::parse(input, context)?; + let atype = QType::parse(input, context)?; + let aclass = be_u16.map(QClass::from_u16).parse_next(input)?; + + let ttl = be_u32.parse_next(input)?; + let record = atype.parse_kind(input, context)?; + + Ok(Answer { + name, + atype, + aclass, + ttl, + record, + }) + } +} + +impl<'a> DnsSerialize<'a> for Answer<'a> { + type Error = DnsError; + + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { + self.name.serialize(encoder)?; + self.atype.serialize(encoder).ok(); + encoder.write(&self.aclass.to_u16().to_be_bytes()); + encoder.write(&self.ttl.to_be_bytes()); + self.record.serialize(encoder) + } + + fn size(&self) -> usize { + self.name.size() + + self.atype.size() + + core::mem::size_of::() + + core::mem::size_of::() + + self.record.size() + } +} + +#[cfg(feature = "defmt")] +impl<'a> defmt::Format for Query<'a> { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!( + fmt, + "Query {{ name: {:?}, qtype: {:?}, qclass: {:?} }}", + self.name, + self.qtype, + self.qclass + ); + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for QType { + fn format(&self, fmt: defmt::Formatter) { + let qtype_str = match self { + QType::A => "A", + QType::AAAA => "AAAA", + QType::PTR => "PTR", + QType::TXT => "TXT", + QType::SRV => "SRV", + QType::Any => "Any", + QType::Unknown(_) => "Unknown", + }; + defmt::write!(fmt, "QType({=str})", qtype_str); + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for QClass { + fn format(&self, fmt: defmt::Formatter) { + let qclass_str = match self { + QClass::IN => "IN", + QClass::Multicast => "Multicast", + QClass::Unknown(_) => "Unknown", + }; + defmt::write!(fmt, "QClass({=str})", qclass_str); + } +} + +#[cfg(feature = "defmt")] +impl<'a> defmt::Format for Answer<'a> { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!( + fmt, + "Answer {{ name: {:?}, atype: {:?}, aclass: {:?}, ttl: {}, record: {:?} }}", + self.name, + self.atype, + self.aclass, + self.ttl, + self.record + ); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::dns::records::A; + use core::net::Ipv4Addr; + + #[test] + fn roundtrip_query() { + let name = Label::from("example.local"); + + let query = Query { + name, + qtype: QType::A, + qclass: QClass::IN, + }; + + let mut buffer = [0u8; 256]; + let mut buffer = Encoder::new(&mut buffer); + query.serialize(&mut buffer).unwrap(); + let buffer = buffer.finish(); + let parsed_query = Query::parse(&mut &buffer[..], buffer).unwrap(); + + assert_eq!(query, parsed_query); + } + + #[test] + fn roundtrip_answer() { + let name = Label::from("example.local"); + + let answer: Answer = Answer { + name, + atype: QType::A, + aclass: QClass::IN, + ttl: 120, + record: Record::A(A { + address: Ipv4Addr::new(192, 168, 1, 1), + }), + }; + + let mut buffer = [0u8; 256]; + let mut buffer = Encoder::new(&mut buffer); + answer.serialize(&mut buffer).unwrap(); + let buffer = buffer.finish(); + let parsed_answer = Answer::parse(&mut &buffer[..], buffer).unwrap(); + + assert_eq!(answer, parsed_answer); + } +} diff --git a/sachy-mdns/src/dns/records.rs b/sachy-mdns/src/dns/records.rs new file mode 100644 index 0000000..c133c6d --- /dev/null +++ b/sachy-mdns/src/dns/records.rs @@ -0,0 +1,413 @@ +use core::{ + convert::Infallible, + net::{Ipv4Addr, Ipv6Addr}, + str, +}; + +use alloc::vec::Vec; +use winnow::token::take; +use winnow::{ModalResult, Parser}; +use winnow::{binary::be_u8, error::ContextError}; +use winnow::{binary::be_u16, error::FromExternalError}; + +use super::label::Label; +use crate::{ + dns::traits::{DnsParse, DnsParseKind, DnsSerialize}, + encoder::{DnsError, Encoder}, +}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[repr(u16)] +#[allow(clippy::upper_case_acronyms)] +pub enum QType { + A = 1, + AAAA = 28, + PTR = 12, + TXT = 16, + SRV = 33, + Any = 255, + Unknown(u16), +} + +impl<'a> DnsParse<'a> for QType { + fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult { + be_u16.map(QType::from_u16).parse_next(input) + } +} + +impl<'a> DnsSerialize<'a> for QType { + type Error = Infallible; + + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { + encoder.write(&self.to_u16().to_be_bytes()); + Ok(()) + } + + fn size(&self) -> usize { + core::mem::size_of::() + } +} + +impl<'a> DnsParseKind<'a> for QType { + type Output = Record<'a>; + + fn parse_kind(&self, input: &mut &'a [u8], context: &'a [u8]) -> ModalResult { + match self { + QType::A => { + let record = A::parse(input, context)?; + Ok(Record::A(record)) + } + QType::AAAA => { + let record = AAAA::parse(input, context)?; + Ok(Record::AAAA(record)) + } + QType::PTR => { + let record = PTR::parse(input, context)?; + Ok(Record::PTR(record)) + } + QType::TXT => { + let record = TXT::parse(input, context)?; + Ok(Record::TXT(record)) + } + QType::SRV => { + let record = SRV::parse(input, context)?; + Ok(Record::SRV(record)) + } + QType::Any => Err(winnow::error::ErrMode::Backtrack( + ContextError::from_external_error(input, DnsError::Unsupported), + )), + QType::Unknown(_) => Err(winnow::error::ErrMode::Backtrack( + ContextError::from_external_error(input, DnsError::Unsupported), + )), + } + } +} + +impl QType { + fn from_u16(value: u16) -> Self { + match value { + 1 => QType::A, + 28 => QType::AAAA, + 12 => QType::PTR, + 16 => QType::TXT, + 33 => QType::SRV, + 255 => QType::Any, + _ => QType::Unknown(value), + } + } + + fn to_u16(self) -> u16 { + match self { + QType::A => 1, + QType::AAAA => 28, + QType::PTR => 12, + QType::TXT => 16, + QType::SRV => 33, + QType::Any => 255, + QType::Unknown(value) => value, + } + } +} + +#[derive(Debug, PartialEq, Eq)] +#[allow(clippy::upper_case_acronyms)] +// Enum for DNS-SD records +pub enum Record<'a> { + A(A), + AAAA(AAAA), + PTR(PTR<'a>), + TXT(TXT<'a>), + SRV(SRV<'a>), +} + +impl<'a> DnsSerialize<'a> for Record<'a> { + type Error = DnsError; + + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { + match self { + Record::A(record) => { + record.serialize(encoder).ok(); + } + Record::AAAA(record) => { + record.serialize(encoder).ok(); + } + Record::PTR(record) => { + record.serialize(encoder)?; + } + Record::TXT(record) => { + record.serialize(encoder).ok(); + } + Record::SRV(record) => { + record.serialize(encoder)?; + } + }; + + Ok(()) + } + + fn size(&self) -> usize { + match self { + Self::A(a) => a.size(), + Self::AAAA(aaaa) => aaaa.size(), + Self::PTR(ptr) => ptr.size(), + Self::TXT(txt) => txt.size(), + Self::SRV(srv) => srv.size(), + } + } +} + +// Struct for A record +#[derive(Debug, PartialEq, Eq)] +pub struct A { + pub address: Ipv4Addr, +} + +impl<'a> DnsParse<'a> for A { + fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult { + let len = be_u16.parse_next(input)?; + let address = take(len) + .try_map(<[u8; 4]>::try_from) + .map(Ipv4Addr::from) + .parse_next(input)?; + + Ok(A { address }) + } +} + +impl<'a> DnsSerialize<'a> for A { + type Error = Infallible; + + fn serialize(&self, writer: &mut Encoder<'_, '_>) -> Result<(), Self::Error> { + let len = 4u16.to_be_bytes(); + writer.write(&len); + writer.write(&self.address.octets()); + Ok(()) + } + + fn size(&self) -> usize { + core::mem::size_of::() + core::mem::size_of::() + } +} + +// Struct for AAAA record +#[derive(Debug, PartialEq, Eq)] +#[allow(clippy::upper_case_acronyms)] +pub struct AAAA { + pub address: Ipv6Addr, +} + +impl<'a> DnsParse<'a> for AAAA { + fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult { + let len = be_u16.parse_next(input)?; + let address = take(len) + .try_map(<[u8; 16]>::try_from) + .map(Ipv6Addr::from) + .parse_next(input)?; + + Ok(AAAA { address }) + } +} + +impl<'a> DnsSerialize<'a> for AAAA { + type Error = Infallible; + + fn serialize(&self, writer: &mut Encoder<'_, '_>) -> Result<(), Self::Error> { + let len = 16u16.to_be_bytes(); + writer.write(&len); + writer.write(&self.address.octets()); + Ok(()) + } + + fn size(&self) -> usize { + core::mem::size_of::() + core::mem::size_of::() + } +} + +// Struct for PTR record +#[derive(Debug, PartialEq, Eq)] +#[allow(clippy::upper_case_acronyms)] +pub struct PTR<'a> { + pub name: Label<'a>, +} + +impl<'a> DnsParse<'a> for PTR<'a> { + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult { + let _ = be_u16.parse_next(input)?; + let name = Label::parse(input, context)?; + Ok(PTR { name }) + } +} + +impl<'a> DnsSerialize<'a> for PTR<'a> { + type Error = DnsError; + + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { + encoder.with_record_length(|enc| self.name.serialize(enc)) + } + + fn size(&self) -> usize { + self.name.size() + core::mem::size_of::() + } +} + +// Struct for TXT record +#[derive(Debug, PartialEq, Eq)] +#[allow(clippy::upper_case_acronyms)] +pub struct TXT<'a> { + pub text: Vec<&'a str>, +} + +impl<'a> DnsParse<'a> for TXT<'a> { + fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult { + let text_len = be_u16.parse_next(input)?; + + let mut total = 0u16; + let mut text = Vec::new(); + + while total < text_len { + let len = be_u8(input)?; + + total += 1 + len as u16; + + if len > 0 { + let part = take(len).try_map(core::str::from_utf8).parse_next(input)?; + text.push(part); + } + } + + Ok(TXT { text }) + } +} + +impl<'a> DnsSerialize<'a> for TXT<'a> { + type Error = DnsError; + + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { + encoder.with_record_length(|enc| { + self.text.iter().try_for_each(|&part| { + let text_len = u8::try_from(part.len()) + .map_err(|_| DnsError::InvalidTxt) + .map(u8::to_be_bytes)?; + + enc.write(&text_len); + enc.write(part.as_bytes()); + + Ok(()) + }) + }) + } + + fn size(&self) -> usize { + let len_size = core::mem::size_of::(); + + let text_size = if self.text.is_empty() { + 1 + } else { + self.text.iter().map(|part| part.len() + 1).sum() + }; + + len_size + text_size + } +} + +// Struct for SRV record +#[derive(Debug, PartialEq, Eq)] +#[allow(clippy::upper_case_acronyms)] +pub struct SRV<'a> { + pub priority: u16, + pub weight: u16, + pub port: u16, + pub target: Label<'a>, +} + +impl<'a> DnsParse<'a> for SRV<'a> { + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult { + let _ = be_u16.parse_next(input)?; + let priority = be_u16.parse_next(input)?; + let weight = be_u16.parse_next(input)?; + let port = be_u16.parse_next(input)?; + let target = Label::parse(input, context)?; + + Ok(SRV { + priority, + weight, + port, + target, + }) + } +} + +impl<'a> DnsSerialize<'a> for SRV<'a> { + type Error = DnsError; + + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { + encoder.with_record_length(|enc| { + enc.write(&self.priority.to_be_bytes()); + enc.write(&self.weight.to_be_bytes()); + enc.write(&self.port.to_be_bytes()); + + self.target.serialize(enc) + }) + } + + fn size(&self) -> usize { + (core::mem::size_of::() * 4) + self.target.size() + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for A { + fn format(&self, fmt: defmt::Formatter) { + // use crate::format::FormatIpv4Addr; + defmt::write!(fmt, "A({})", self.address) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for AAAA { + fn format(&self, fmt: defmt::Formatter) { + // use crate::format::FormatIpv6Addr; + defmt::write!(fmt, "AAAA({})", self.address) + } +} + +#[cfg(feature = "defmt")] +impl<'a> defmt::Format for Record<'a> { + fn format(&self, fmt: defmt::Formatter) { + match self { + Record::A(record) => defmt::write!(fmt, "Record::A({:?})", record), + Record::AAAA(record) => defmt::write!(fmt, "Record::AAAA({:?})", record), + Record::PTR(record) => defmt::write!(fmt, "Record::PTR({:?})", record), + Record::TXT(record) => defmt::write!(fmt, "Record::TXT({:?})", record), + Record::SRV(record) => defmt::write!(fmt, "Record::SRV({:?})", record), + } + } +} + +#[cfg(feature = "defmt")] +impl<'a> defmt::Format for PTR<'a> { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "PTR {{ name: {:?} }}", self.name); + } +} + +#[cfg(feature = "defmt")] +impl<'a> defmt::Format for TXT<'a> { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "TXT {{ text: {:?} }}", self.text); + } +} + +#[cfg(feature = "defmt")] +impl<'a> defmt::Format for SRV<'a> { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!( + fmt, + "SRV {{ priority: {}, weight: {}, port: {}, target: {:?} }}", + self.priority, + self.weight, + self.port, + self.target + ); + } +} diff --git a/sachy-mdns/src/dns/reqres.rs b/sachy-mdns/src/dns/reqres.rs new file mode 100644 index 0000000..987c350 --- /dev/null +++ b/sachy-mdns/src/dns/reqres.rs @@ -0,0 +1,481 @@ +use alloc::vec::Vec; +use winnow::ModalResult; +use winnow::binary::be_u16; + +use super::flags::Flags; +use super::query::{Answer, Query}; +use crate::{ + dns::traits::{DnsParse, DnsSerialize}, + encoder::{DnsError, Encoder}, +}; + +const ZERO_U16: [u8; 2] = 0u16.to_be_bytes(); + +#[derive(Debug, PartialEq, Eq)] +pub struct Request<'a> { + pub id: u16, + pub flags: Flags, + pub(crate) queries: Vec>, +} + +impl<'a> DnsParse<'a> for Request<'a> { + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult { + let id = be_u16(input)?; + let flags = Flags::parse(input, context)?; + let qdcount = be_u16(input)?; + let _ancount = be_u16(input)?; + let _nscount = be_u16(input)?; + let _arcount = be_u16(input)?; + let queries = (0..qdcount) + .map(|_| Query::parse(input, context)) + .collect::, _>>()?; + Ok(Request { id, flags, queries }) + } +} + +impl<'a> DnsSerialize<'a> for Request<'a> { + type Error = DnsError; + + fn serialize<'b>(&self, writer: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { + writer.write(&self.id.to_be_bytes()); + self.flags.serialize(writer).ok(); + writer.write(&(self.queries.len() as u16).to_be_bytes()); + writer.write(&ZERO_U16); + writer.write(&ZERO_U16); + writer.write(&ZERO_U16); + + self.queries + .iter() + .try_for_each(|query| query.serialize(writer)) + } + + fn size(&self) -> usize { + let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum(); + + core::mem::size_of::() + + self.flags.size() + + (core::mem::size_of::() * 4) + + total_query_size + } +} + +#[derive(Debug, PartialEq, Eq)] +pub struct Response<'a> { + pub id: u16, + pub flags: Flags, + pub queries: Vec>, + pub answers: Vec>, +} + +impl<'a> DnsParse<'a> for Response<'a> { + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult { + let id = be_u16(input)?; + let flags = Flags::parse(input, context)?; + let qdcount = be_u16(input)?; + let ancount = be_u16(input)?; + let _nscount = be_u16(input)?; + let _arcount = be_u16(input)?; + + let queries = (0..qdcount) + .map(|_| Query::parse(input, context)) + .collect::, _>>()?; + + let answers = (0..ancount) + .map(|_| Answer::parse(input, context)) + .collect::, _>>()?; + + Ok(Response { + id, + flags, + queries, + answers, + }) + } +} + +impl<'a> DnsSerialize<'a> for Response<'a> { + type Error = DnsError; + + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> { + encoder.write(&self.id.to_be_bytes()); + self.flags.serialize(encoder).ok(); + encoder.write(&(self.queries.len() as u16).to_be_bytes()); + encoder.write(&(self.answers.len() as u16).to_be_bytes()); + encoder.write(&ZERO_U16); + encoder.write(&ZERO_U16); + + self.queries + .iter() + .try_for_each(|query| query.serialize(encoder))?; + self.answers + .iter() + .try_for_each(|answer| answer.serialize(encoder)) + } + + fn size(&self) -> usize { + let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum(); + let total_answer_size: usize = self.answers.iter().map(DnsSerialize::size).sum(); + + core::mem::size_of::() + + self.flags.size() + + (core::mem::size_of::() * 4) + + total_query_size + + total_answer_size + } +} + +#[cfg(feature = "defmt")] +impl<'a> defmt::Format for Request<'a> { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!( + fmt, + "Request {{ id: {}, flags: {:?}, queries: {:?} }}", + self.id, + self.flags, + self.queries + ); + } +} + +#[cfg(feature = "defmt")] +impl<'a> defmt::Format for Response<'a> { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!( + fmt, + "Response {{ id: {}, flags: {:?}, queries: {:?}, answers: {:?} }}", + self.id, + self.flags, + self.queries, + self.answers + ); + } +} + +#[cfg(test)] +mod tests { + use alloc::vec; + + use super::*; + use crate::dns::{ + label::Label, + query::QClass, + records::{A, PTR, QType, Record, SRV, TXT}, + }; + use core::net::Ipv4Addr; + + #[test] + fn parse_query() { + let data = [ + 0xAA, 0xAA, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x65, + // example . com in label format + 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // + // + 0x00, 0x01, 0x00, 0x01, + ]; + + let request = Request::parse(&mut data.as_slice(), data.as_slice()).unwrap(); + + assert_eq!(request.id, 0xAAAA); + assert_eq!(request.flags.0, 0x0100); + assert_eq!(request.queries.len(), 1); + assert_eq!(request.queries[0].name, "example.com"); + assert_eq!(request.queries[0].qtype, QType::A); + assert_eq!(request.queries[0].qclass, QClass::IN); + } + + #[test] + fn parse_response() { + let data = [ + 0xAA, 0xAA, // transaction ID + 0x81, 0x80, // flags + 0x00, 0x01, // 1 question + 0x00, 0x01, // 1 A-answer + 0x00, 0x00, // no authority + 0x00, 0x00, // no additional answers + // example . com in label format + 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // + // + 0x00, 0x01, 0x00, 0x01, // + // + 0xC0, 0x0C, // ptr to question section + // + 0x00, 0x01, 0x00, 0x01, // A and IN + // + 0x00, 0x00, 0x00, 0x3C, // TTL 60 seconds + // + 0x00, 0x04, // length of address + // IP address: + 192, 168, 1, 3, + ]; + + let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); + + assert_eq!(response.id, 0xAAAA); + assert_eq!(response.flags.0, 0x8180); + assert_eq!(response.answers.len(), 1); + assert_eq!(response.answers[0].name, "example.com"); + assert_eq!(response.answers[0].atype, QType::A); + assert_eq!(response.answers[0].aclass, QClass::IN); + assert_eq!(response.answers[0].ttl, 60); + if let Record::A(a) = &response.answers[0].record { + assert_eq!(a.address, Ipv4Addr::new(192, 168, 1, 3)); + } else { + panic!("Expected A record"); + } + } + + #[test] + fn parse_response_two_records() { + let data = [ + 0xAA, 0xAA, // + 0x81, 0x80, // + 0x00, 0x01, // + 0x00, 0x02, // + 0x00, 0x00, // + 0x00, 0x00, // + // example . com in label format + 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // + // + 0x00, 0x01, // query type + 0x00, 0x01, // query class + // + 0xC0, 0x0C, // pointer + 0x00, 0x01, // + 0x00, 0x01, // + 0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds + 0x00, 0x04, // length of A-record + 0x5D, 0xB8, 0xD8, 0x22, // a-record + // + 0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, // + // + 0x00, 0x10, // TXT + 0x00, 0x01, // IN + // + 0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds + // + 0x00, 0x10, // length of txt record + // (len) "test txt record" + 0x0F, 0x74, 0x65, 0x73, 0x74, 0x20, 0x74, 0x78, 0x74, 0x20, 0x72, 0x65, 0x63, 0x6F, 0x72, + 0x64, + ]; + + let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); + + assert_eq!(response.id, 0xAAAA); + assert_eq!(response.flags.0, 0x8180); + assert_eq!(response.answers.len(), 2); + + // First answer + assert_eq!(response.answers[0].name, "example.com"); + assert_eq!(response.answers[0].atype, QType::A); + assert_eq!(response.answers[0].aclass, QClass::IN); + assert_eq!(response.answers[0].ttl, 60); + if let Record::A(a) = &response.answers[0].record { + assert_eq!(a.address, Ipv4Addr::new(93, 184, 216, 34)); + } else { + panic!("Expected A record"); + } + + // Second answer + assert_eq!(response.answers[1].name, "example.com"); + assert_eq!(response.answers[1].atype, QType::TXT); + assert_eq!(response.answers[1].aclass, QClass::IN); + assert_eq!(response.answers[1].ttl, 60); + if let Record::TXT(txt) = &response.answers[1].record + && let Some(&text) = txt.text.first() + { + assert_eq!(text, "test txt record"); + } else { + panic!("Expected TXT record"); + } + } + + #[test] + fn parse_response_srv() { + let data = [ + // + 0xAA, 0xAA, // id + 0x81, 0x80, // flags + 0x00, 0x01, // one question + 0x00, 0x01, // one answer + 0x00, 0x00, // no authority + 0x00, 0x00, // no extra + // + 0x04, 0x5f, 0x73, 0x69, 0x70, 0x04, 0x5f, 0x74, 0x63, 0x70, 0x07, 0x65, 0x78, 0x61, + 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // + // + 0x00, 0x21, // type SRV + 0x00, 0x01, // IN + // + 0xc0, 0x0c, // + // + 0x00, 0x21, // SRV + 0x00, 0x01, // IN + 0x00, 0x00, 0x00, 0x3C, // ttl 60 + // + 0x00, 0x19, // data len + 0x00, 0x0A, // prio + 0x00, 0x05, // weight + 0x13, 0xC4, // PORT + // + 0x09, 0x73, 0x69, 0x70, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x07, 0x65, 0x78, 0x61, + 0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, // + ]; + + let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); + + assert_eq!(response.id, 0xAAAA); + assert_eq!(response.flags.0, 0x8180); + assert_eq!(response.answers.len(), 1); + + // Answer + assert_eq!(response.answers[0].name, "_sip._tcp.example.com"); + assert_eq!(response.answers[0].atype, QType::SRV); + assert_eq!(response.answers[0].aclass, QClass::IN); + assert_eq!(response.answers[0].ttl, 60); + let Record::SRV(srv) = &response.answers[0].record else { + panic!("Expected SRV record"); + }; + + assert_eq!(srv.priority, 10); + assert_eq!(srv.weight, 5); + assert_eq!(srv.port, 5060); + assert_eq!(srv.target, "sipserver.example.com"); + } + + #[test] + fn parse_response_back_forth() { + let data = [ + 0, 0, // Transaction ID + 132, 0, // Response, Authoritative Answer, No Recursion + 0, 0, // 0 questions + 0, 4, // 4 answers + 0, 0, // 0 authority RRs + 0, 0, // 0 additional RRs + // _midiriff + 9, 95, 109, 105, 100, 105, 114, 105, 102, 102, // + // _udp + 4, 95, 117, 100, 112, // + // local + 5, 108, 111, 99, 97, 108, // + 0, // + // + 0, 12, // PTR + 0, 1, // Class IN + 0, 0, 0, 120, // TTL 120 seconds + 0, 10, // Data Length 10 + // pi35291 + 7, 112, 105, 51, 53, 50, 57, 49, // + // + 192, 12, // Pointer to _midirif._udp._local. + // + 192, 44, // Pointer to instace name: pi35291._midirif._udp._local. + 0, 33, // SRV + 128, 1, // IN (Cache flush bit set) + 0, 0, 0, 120, // TTL 120 seconds + 0, 11, // Data Length 11 + 0, 0, // Priority 0 + 0, 0, // Weight 0 + 137, 219, // Port 35291 + 2, 112, 105, // _pi + 192, 27, // Pointer to: .local. + // TXT (Empty) + 192, 44, 0, 16, 128, 1, 0, 0, 17, 148, 0, 1, 0, + // A (10.1.1.9) + 192, 72, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4, 10, 1, 1, 9, + ]; + + let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap(); + + assert_eq!(response.answers[0].name, "_midiriff._udp.local"); + assert_eq!(response.answers[0].ttl, 120); + let Record::PTR(ptr) = &response.answers[0].record else { + panic!() + }; + assert_eq!(ptr.name, "pi35291._midiriff._udp.local"); + + let mut buffer = [0u8; 256]; + let mut buffer = Encoder::new(&mut buffer); + response.serialize(&mut buffer).unwrap(); + + let buffer = buffer.finish(); + + let response2 = Response::parse(&mut &buffer[..], buffer).unwrap(); + + assert_eq!(response, response2); + } + + #[test] + fn mdns_service_response() { + let mut response = Response { + id: 0x1234, + flags: Flags::standard_response(), + queries: Vec::new(), + answers: Vec::new(), + }; + + let query = Query { + name: Label::from("_test._udp.local"), + qtype: QType::PTR, + qclass: QClass::IN, + }; + response.queries.push(query); + + let ptr_answer = Answer { + name: Label::from("_test._udp.local"), + atype: QType::PTR, + aclass: QClass::IN, + ttl: 4500, + record: Record::PTR(PTR { + name: Label::from("test-service._test._udp.local"), + }), + }; + response.answers.push(ptr_answer); + + let srv_answer = Answer { + name: Label::from("test-service._test._udp.local"), + atype: QType::SRV, + aclass: QClass::IN, + ttl: 120, + record: Record::SRV(SRV { + priority: 0, + weight: 0, + port: 8080, + target: Label::from("host.local"), + }), + }; + response.answers.push(srv_answer); + + let txt_answer = Answer { + name: Label::from("test-service._test._udp.local"), + atype: QType::TXT, + aclass: QClass::IN, + ttl: 120, + record: Record::TXT(TXT { + text: vec!["path=/test"], + }), + }; + response.answers.push(txt_answer); + + let a_answer = Answer { + name: Label::from("host.local"), + atype: QType::A, + aclass: QClass::IN, + ttl: 120, + record: Record::A(A { + address: Ipv4Addr::new(192, 168, 1, 100), + }), + }; + response.answers.push(a_answer); + + let mut buffer = [0u8; 256]; + let mut buffer = Encoder::new(&mut buffer); + response.serialize(&mut buffer).unwrap(); + + let buffer = buffer.finish(); + + let parsed_response = Response::parse(&mut &buffer[..], buffer).unwrap(); + + assert_eq!(response, parsed_response); + } +} diff --git a/sachy-mdns/src/dns/traits.rs b/sachy-mdns/src/dns/traits.rs new file mode 100644 index 0000000..62f686b --- /dev/null +++ b/sachy-mdns/src/dns/traits.rs @@ -0,0 +1,21 @@ +use winnow::ModalResult; + +use crate::encoder::Encoder; + +pub trait DnsParse<'a>: Sized { + fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult; +} + +pub trait DnsParseKind<'a> { + type Output; + + fn parse_kind(&self, input: &mut &'a [u8], context: &'a [u8]) -> ModalResult; +} + +pub trait DnsSerialize<'a> { + type Error; + + fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error>; + #[allow(dead_code)] + fn size(&self) -> usize; +} diff --git a/sachy-mdns/src/encoder.rs b/sachy-mdns/src/encoder.rs new file mode 100644 index 0000000..7fc763a --- /dev/null +++ b/sachy-mdns/src/encoder.rs @@ -0,0 +1,145 @@ +use core::ops::Range; + +use alloc::collections::BTreeMap; + +use crate::dns::traits::DnsSerialize; + +pub(crate) const MAX_STR_LEN: u8 = !PTR_MASK; +pub(crate) const PTR_MASK: u8 = 0b1100_0000; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum DnsError { + LabelTooLong, + InvalidTxt, + Unsupported, +} + +impl core::fmt::Display for DnsError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Self::LabelTooLong => f.write_str("Encoding Error: Segment too long"), + Self::InvalidTxt => f.write_str("Encoding Error: TXT segment is invalid"), + Self::Unsupported => f.write_str("Encoding Error: Unsupported Record Type"), + } + } +} + +impl core::error::Error for DnsError {} + +#[derive(Debug)] +pub struct Encoder<'a, 'b> { + output: &'b mut [u8], + position: usize, + lookup: BTreeMap<&'a str, u16>, + reservation: Option, +} + +impl<'a, 'b> Encoder<'a, 'b> { + pub const fn new(buffer: &'b mut [u8]) -> Self { + Self { + output: buffer, + position: 0, + lookup: BTreeMap::new(), + reservation: None, + } + } + + /// Takes a payload and encodes it, consuming the encoder and yielding the resulting + /// slice. + pub fn encode(mut self, payload: T) -> Result<&'b [u8], E> + where + E: core::error::Error, + T: DnsSerialize<'a, Error = E>, + { + payload.serialize(&mut self)?; + Ok(self.finish()) + } + + pub(crate) fn finish(self) -> &'b [u8] { + &self.output[..self.position] + } + + fn increment(&mut self, amount: usize) { + self.position += amount; + } + + pub(crate) fn write_label(&mut self, mut label: &'a str) -> Result<(), DnsError> { + loop { + if let Some(pos) = self.get_label_position(label) { + let [b1, b2] = u16::to_be_bytes(pos); + self.write(&[b1 | PTR_MASK, b2]); + return Ok(()); + } + + let dot = label.find('.'); + + let end = dot.unwrap_or(label.len()); + let segment = &label[..end]; + let len = u8::try_from(segment.len()).map_err(|_| DnsError::LabelTooLong)?; + + if len > MAX_STR_LEN { + return Err(DnsError::LabelTooLong); + } + + self.store_label_position(label); + self.write(&len.to_be_bytes()); + self.write(segment.as_bytes()); + + match dot { + Some(end) => { + label = &label[end + 1..]; + } + None => { + self.write(&[0]); + return Ok(()); + } + } + } + } + + pub(crate) fn write(&mut self, bytes: &[u8]) { + let len = bytes.len(); + let end = self.position + len; + self.output[self.position..end].copy_from_slice(bytes); + self.increment(len); + } + + fn get_label_position(&mut self, label: &str) -> Option { + self.lookup.get(label).copied() + } + + fn store_label_position(&mut self, label: &'a str) { + self.lookup.insert(label, self.position as u16); + } + + fn reserve_record_length(&mut self) { + if self.reservation.is_none() { + self.reservation = Some(self.position); + self.increment(2); + } + } + + fn distance_from_reservation(&mut self) -> Option<(Range, u16)> { + self.reservation + .take() + .map(|start| (start..(start + 2), (self.position - start - 2) as u16)) + } + + fn write_record_length(&mut self) { + if let Some((reservation, len)) = self.distance_from_reservation() { + self.output[reservation].copy_from_slice(&len.to_be_bytes()); + } + } + + pub(crate) fn with_record_length(&mut self, encoding_scope: F) -> Result<(), E> + where + E: core::error::Error, + F: FnOnce(&mut Self) -> Result<(), E>, + { + self.reserve_record_length(); + encoding_scope(self)?; + self.write_record_length(); + Ok(()) + } +} diff --git a/sachy-mdns/src/lib.rs b/sachy-mdns/src/lib.rs new file mode 100644 index 0000000..39775ef --- /dev/null +++ b/sachy-mdns/src/lib.rs @@ -0,0 +1,60 @@ +#![no_std] + +mod dns; +pub(crate) mod encoder; +pub mod server; +mod service; +mod state; + +extern crate alloc; + +use core::net::{Ipv4Addr, SocketAddrV4}; + +pub use crate::service::Service; +pub use crate::state::MdnsAction; +use crate::{dns::flags::Flags, server::Server, state::MdnsStateMachine}; + +/// Standard port for mDNS (5353). +pub const MDNS_PORT: u16 = 5353; + +/// Standard IPv4 multicast address for mDNS (224.0.0.251). +pub const GROUP_ADDR_V4: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251); +pub const GROUP_SOCK_V4: SocketAddrV4 = SocketAddrV4::new(GROUP_ADDR_V4, MDNS_PORT); + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct MdnsService { + server: Server, + state: MdnsStateMachine, +} + +impl MdnsService { + pub fn new(service: Service) -> Self { + Self { + server: Server::new(service), + state: Default::default(), + } + } + + pub fn next_action(&mut self) -> MdnsAction { + self.state.drive_next_action() + } + + pub fn send_announcement<'buffer>(&self, outgoing: &'buffer mut [u8]) -> Option<&'buffer [u8]> { + self.server.broadcast( + server::ResponseKind::Announcement, + Flags::standard_response(), + 1, + alloc::vec::Vec::new(), + outgoing, + ) + } + + pub fn listen_for_queries<'buffer>( + &mut self, + incoming: &[u8], + outgoing: &'buffer mut [u8], + ) -> Option<&'buffer [u8]> { + self.server.respond(incoming, outgoing) + } +} diff --git a/sachy-mdns/src/server.rs b/sachy-mdns/src/server.rs new file mode 100644 index 0000000..f8c982a --- /dev/null +++ b/sachy-mdns/src/server.rs @@ -0,0 +1,95 @@ +use alloc::vec::Vec; +use sachy_fmt::{error, info}; + +use crate::{ + dns::{ + flags::Flags, + query::{QClass, Query}, + records::QType, + reqres::{Request, Response}, + traits::DnsParse, + }, + encoder::Encoder, + service::Service, +}; + +pub(crate) enum ResponseKind { + Announcement, + QueryResponse(Vec<(QType, QClass)>), +} + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) struct Server { + service: Service, +} + +impl Server { + pub(crate) fn new(service: Service) -> Self { + Self { service } + } + + pub(crate) fn broadcast<'a, 'b>( + &self, + response_kind: ResponseKind, + flags: Flags, + id: u16, + queries: Vec>, + outgoing: &'a mut [u8], + ) -> Option<&'a [u8]> { + let answers: Vec<_> = match response_kind { + ResponseKind::Announcement => self.service.as_answers(QClass::Multicast).collect(), + ResponseKind::QueryResponse(valid) => valid + .iter() + .flat_map(|&(qtype, qclass)| match qtype { + QType::A | QType::AAAA => self.service.ip_answer(qclass), + QType::PTR => self.service.ptr_answer(qclass), + QType::TXT => self.service.txt_answer(qclass), + QType::SRV => self.service.srv_answer(qclass), + QType::Any | QType::Unknown(_) => None, + }) + .collect(), + }; + + if !answers.is_empty() { + let res = Response { + flags, + id, + queries, + answers, + }; + + info!("MDNS RESPONSE: {}", res); + + return Encoder::new(outgoing) + .encode(res) + .inspect_err(|err| error!("Encoder errored: {}", err)) + .ok(); + } + + None + } + + pub(crate) fn respond<'a>(&self, incoming: &[u8], outgoing: &'a mut [u8]) -> Option<&'a [u8]> { + Request::parse(&mut &incoming[..], incoming) + .ok() + .and_then(|req| { + let valid_queries = req + .queries + .iter() + .filter_map(|q| match q.qtype { + QType::A | QType::AAAA | QType::TXT | QType::SRV => { + (q.name == self.service.hostname()).then_some((q.qtype, q.qclass)) + } + QType::PTR => (q.name == self.service.service_type()).then_some((q.qtype, q.qclass)), + QType::Any | QType::Unknown(_) => None, + }).collect::>(); + + if !valid_queries.is_empty() { + self.broadcast(ResponseKind::QueryResponse(valid_queries), req.flags, req.id, req.queries, outgoing) + } else { + None + } + }) + } +} diff --git a/sachy-mdns/src/service.rs b/sachy-mdns/src/service.rs new file mode 100644 index 0000000..3943b36 --- /dev/null +++ b/sachy-mdns/src/service.rs @@ -0,0 +1,193 @@ +use core::net::IpAddr; + +use alloc::{ + string::{String, ToString}, + vec::Vec, +}; + +use crate::dns::{ + label::Label, + query::{Answer, QClass}, + records::{A, AAAA, PTR, QType, Record, SRV, TXT}, +}; + +#[derive(Debug, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Service { + service_type: String, + instance: String, + hostname: String, + ip: Option, + port: u16, +} + +impl Service { + pub fn new( + service_type: impl Into, + instance: impl Into, + hostname: impl Into, + ip: Option, + port: u16, + ) -> Self { + let service_type = service_type.into(); + let mut instance = instance.into(); + let mut hostname = hostname.into(); + + instance.push('.'); + instance.push_str(&service_type); + hostname.push_str(".local"); + + Self { + service_type, + instance, + hostname, + ip, + port, + } + } + + pub fn service_type(&self) -> Label<'_> { + Label::from(self.service_type.as_ref()) + } + + pub fn instance(&self) -> Label<'_> { + Label::from(self.instance.as_ref()) + } + + pub fn hostname(&self) -> Label<'_> { + Label::from(self.hostname.as_ref()) + } + + pub fn ip(&self) -> Option { + self.ip + } + + pub fn port(&self) -> u16 { + self.port + } + + pub(crate) fn ptr_answer(&self, aclass: QClass) -> Option> { + Some(Answer { + name: self.service_type(), + atype: QType::PTR, + aclass, + ttl: 4500, + record: Record::PTR(PTR { + name: self.instance(), + }), + }) + } + + pub(crate) fn srv_answer(&self, aclass: QClass) -> Option> { + Some(Answer { + name: self.instance(), + atype: QType::SRV, + aclass, + ttl: 120, + record: Record::SRV(SRV { + priority: 0, + weight: 0, + port: self.port, + target: self.hostname(), + }), + }) + } + + pub(crate) fn txt_answer(&self, aclass: QClass) -> Option> { + Some(Answer { + name: self.instance(), + atype: QType::TXT, + aclass, + ttl: 120, + record: Record::TXT(TXT { text: Vec::new() }), + }) + } + + pub(crate) fn ip_answer(&self, aclass: QClass) -> Option> { + self.ip().map(|address| match address { + IpAddr::V4(address) => Answer { + name: self.hostname(), + atype: QType::A, + aclass, + ttl: 120, + record: Record::A(A { address }), + }, + IpAddr::V6(address) => Answer { + name: self.hostname(), + atype: QType::AAAA, + aclass, + ttl: 120, + record: Record::AAAA(AAAA { address }), + }, + }) + } + + pub(crate) fn as_answers(&self, aclass: QClass) -> impl Iterator> { + self.ptr_answer(aclass) + .into_iter() + .chain(self.srv_answer(aclass)) + .chain(self.txt_answer(aclass)) + .chain(self.ip_answer(aclass)) + } + + #[allow(dead_code)] + pub(crate) fn from_answers(answers: &[Answer<'_>]) -> Vec { + let mut output = Vec::new(); + + // Step 1: Process PTR records + for answer in answers { + if let Record::PTR(ptr) = &answer.record { + let instance = ptr.name.to_string(); + let service_type = answer.name.to_string(); + output.push(Self { + service_type, + instance, + ..Default::default() + }); + } + } + + // Step 2: Process SRV records, A and AAAA records and merge data + for answer in answers { + match &answer.record { + Record::SRV(srv) => { + if let Some(stub) = output + .iter_mut() + .find(|stub| answer.name == stub.instance.as_ref()) + { + stub.hostname = srv.target.to_string(); + stub.port = srv.port; + } + } + Record::A(a) => { + if let Some(stub) = output + .iter_mut() + .find(|stub| answer.name == stub.hostname.as_ref()) + { + stub.ip = Some(IpAddr::V4(a.address)); + } + } + Record::AAAA(aaaa) => { + if let Some(stub) = output + .iter_mut() + .find(|stub| answer.name == stub.hostname.as_ref()) + { + stub.ip = Some(IpAddr::V6(aaaa.address)); + } + } + _ => {} + } + } + + // Final step: Retain only complete services + output.retain(|stub| { + !stub.service_type.is_empty() + && !stub.instance.is_empty() + && !stub.hostname.is_empty() + && stub.ip.is_some() + && stub.port != 0 + }); + + output + } +} diff --git a/sachy-mdns/src/state.rs b/sachy-mdns/src/state.rs new file mode 100644 index 0000000..5a4dd11 --- /dev/null +++ b/sachy-mdns/src/state.rs @@ -0,0 +1,100 @@ +use embassy_time::{Duration, Instant}; +use sachy_fmt::{debug, unwrap}; + +#[derive(Debug, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) enum MdnsStateMachine { + #[default] + Start, + Announce { + last_sent: Instant, + }, + ListenFor { + last_sent: Instant, + timeout: Duration, + }, + WaitFor { + last_sent: Instant, + duration: Duration, + }, +} + +impl MdnsStateMachine { + /// Set the state to announced, if we have timed out the listening period and need + /// to announce, or if we received a query while listening and have sent a response. + pub(crate) fn announced(&mut self) { + *self = Self::Announce { + last_sent: Instant::now(), + }; + } + + fn next_state(&mut self) { + match self { + Self::Start => self.announced(), + &mut Self::Announce { last_sent } => { + let duration_since = last_sent.elapsed(); + let duration = Duration::from_secs(1) - duration_since; + + *self = Self::WaitFor { + last_sent, + duration, + }; + } + &mut Self::ListenFor { last_sent, .. } => { + let duration_since = last_sent.elapsed(); + let time_limit = Duration::from_secs(120); + + if duration_since >= time_limit { + self.announced(); + } else { + let timeout = time_limit - duration_since; + *self = Self::ListenFor { last_sent, timeout }; + } + } + &mut Self::WaitFor { last_sent, .. } => { + let duration_since = last_sent.elapsed(); + let time_limit = Duration::from_secs(120); + let timeout = time_limit - duration_since; + + *self = Self::ListenFor { last_sent, timeout }; + } + } + } + + pub(crate) fn drive_next_action(&mut self) -> MdnsAction { + self.next_state(); + unwrap!(MdnsAction::try_from(self)) + } +} + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum MdnsAction { + Announce, + ListenFor { timeout: Duration }, + WaitFor { duration: Duration }, +} + +impl TryFrom<&mut MdnsStateMachine> for MdnsAction { + type Error = MdnsStateMachine; + + fn try_from(value: &mut MdnsStateMachine) -> Result { + match value { + // We should start in this state, but never remain nor return to it when + // executing our state machine event loop. + MdnsStateMachine::Start => Err(MdnsStateMachine::Start), + MdnsStateMachine::Announce { .. } => { + debug!("ANNOUNCE"); + Ok(Self::Announce) + } + &mut MdnsStateMachine::ListenFor { timeout, .. } => { + debug!("LISTEN FOR {}ms", timeout.as_millis()); + Ok(Self::ListenFor { timeout }) + } + &mut MdnsStateMachine::WaitFor { duration, .. } => { + debug!("WAIT FOR {}ms", duration.as_millis()); + Ok(Self::WaitFor { duration }) + } + } + } +} -- 2.52.0