Adds a mdns resolver/state-machine crate for providing MDNS-SD functionality for an embedded device.
+21
.tangled/workflows/miri.yml
+21
.tangled/workflows/miri.yml
···
···
1
+
when:
2
+
- event: ["push", "pull_request"]
3
+
branch: main
4
+
5
+
engine: nixery
6
+
7
+
dependencies:
8
+
nixpkgs:
9
+
- clang
10
+
- rustup
11
+
12
+
steps:
13
+
- name: Install Nightly
14
+
command: |
15
+
rustup toolchain install nightly --component miri
16
+
rustup override set nightly
17
+
cargo miri setup
18
+
- name: Miri Test
19
+
command: cargo miri test --locked -p sachy-mdns
20
+
environment:
21
+
RUSTFLAGS: -Zrandomize-layout
+19
-3
Cargo.lock
+19
-3
Cargo.lock
···
420
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
421
dependencies = [
422
"libc",
423
-
"windows-sys 0.52.0",
424
]
425
426
[[package]]
···
1065
"errno",
1066
"libc",
1067
"linux-raw-sys 0.11.0",
1068
-
"windows-sys 0.52.0",
1069
]
1070
1071
[[package]]
···
1119
name = "sachy-fnv"
1120
version = "0.1.0"
1121
1122
[[package]]
1123
name = "sachy-shtc3"
1124
version = "0.1.0"
···
1292
"getrandom",
1293
"once_cell",
1294
"rustix 1.1.2",
1295
-
"windows-sys 0.52.0",
1296
]
1297
1298
[[package]]
···
1483
source = "registry+https://github.com/rust-lang/crates.io-index"
1484
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
1485
1486
[[package]]
1487
name = "wit-bindgen"
1488
version = "0.46.0"
···
420
checksum = "39cab71617ae0d63f51a36d69f866391735b51691dbda63cf6f96d042b63efeb"
421
dependencies = [
422
"libc",
423
+
"windows-sys 0.61.2",
424
]
425
426
[[package]]
···
1065
"errno",
1066
"libc",
1067
"linux-raw-sys 0.11.0",
1068
+
"windows-sys 0.61.2",
1069
]
1070
1071
[[package]]
···
1119
name = "sachy-fnv"
1120
version = "0.1.0"
1121
1122
+
[[package]]
1123
+
name = "sachy-mdns"
1124
+
version = "0.1.0"
1125
+
dependencies = [
1126
+
"defmt 1.0.1",
1127
+
"embassy-time",
1128
+
"sachy-fmt",
1129
+
"winnow",
1130
+
]
1131
+
1132
[[package]]
1133
name = "sachy-shtc3"
1134
version = "0.1.0"
···
1302
"getrandom",
1303
"once_cell",
1304
"rustix 1.1.2",
1305
+
"windows-sys 0.61.2",
1306
]
1307
1308
[[package]]
···
1493
source = "registry+https://github.com/rust-lang/crates.io-index"
1494
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
1495
1496
+
[[package]]
1497
+
name = "winnow"
1498
+
version = "0.7.14"
1499
+
source = "registry+https://github.com/rust-lang/crates.io-index"
1500
+
checksum = "5a5364e9d77fcdeeaa6062ced926ee3381faa2ee02d3eb83a5c27a8825540829"
1501
+
1502
[[package]]
1503
name = "wit-bindgen"
1504
version = "0.46.0"
+1
Cargo.toml
+1
Cargo.toml
+22
sachy-mdns/Cargo.toml
+22
sachy-mdns/Cargo.toml
···
···
1
+
[package]
2
+
name = "sachy-mdns"
3
+
authors.workspace = true
4
+
edition.workspace = true
5
+
repository.workspace = true
6
+
license.workspace = true
7
+
version.workspace = true
8
+
rust-version.workspace = true
9
+
10
+
[dependencies]
11
+
defmt = { workspace = true, optional = true, features = ["alloc"] }
12
+
embassy-time = { workspace = true }
13
+
sachy-fmt = { path = "../sachy-fmt" }
14
+
winnow = { version = "0.7.12", default-features = false }
15
+
16
+
[features]
17
+
default = []
18
+
std = []
19
+
defmt = ["dep:defmt"]
20
+
21
+
[dev-dependencies]
22
+
winnow = { version = "0.7.12", default-features = false, features = ["alloc"] }
+6
sachy-mdns/src/dns.rs
+6
sachy-mdns/src/dns.rs
+231
sachy-mdns/src/dns/flags.rs
+231
sachy-mdns/src/dns/flags.rs
···
···
1
+
#![allow(dead_code)]
2
+
3
+
use core::{convert::Infallible, fmt};
4
+
use winnow::{ModalResult, Parser, binary::be_u16};
5
+
6
+
use crate::{
7
+
dns::traits::{DnsParse, DnsSerialize},
8
+
encoder::Encoder,
9
+
};
10
+
11
+
#[derive(Default, Clone, Copy, PartialEq, Eq)]
12
+
pub struct Flags(pub u16);
13
+
14
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
15
+
#[repr(u8)]
16
+
pub enum Opcode {
17
+
Query = 0,
18
+
IQuery = 1,
19
+
Status = 2,
20
+
Reserved = 3,
21
+
Notify = 4,
22
+
Update = 5,
23
+
// Other values are reserved
24
+
}
25
+
26
+
impl Opcode {
27
+
const fn cast(value: u8) -> Self {
28
+
match value {
29
+
0 => Opcode::Query,
30
+
1 => Opcode::IQuery,
31
+
2 => Opcode::Status,
32
+
4 => Opcode::Notify,
33
+
5 => Opcode::Update,
34
+
_ => Opcode::Reserved,
35
+
}
36
+
}
37
+
}
38
+
39
+
impl From<u8> for Opcode {
40
+
fn from(value: u8) -> Self {
41
+
Self::cast(value)
42
+
}
43
+
}
44
+
45
+
impl From<Opcode> for u8 {
46
+
fn from(opcode: Opcode) -> Self {
47
+
opcode as u8
48
+
}
49
+
}
50
+
51
+
impl Flags {
52
+
const fn new() -> Self {
53
+
Flags(0)
54
+
}
55
+
56
+
pub const fn standard_request() -> Self {
57
+
let mut flags = Flags::new();
58
+
flags.set_query(true);
59
+
flags.set_opcode(Opcode::Query);
60
+
flags.set_recursion_desired(true);
61
+
flags
62
+
}
63
+
64
+
pub const fn standard_response() -> Self {
65
+
let mut flags = Flags::new();
66
+
flags.set_query(false);
67
+
flags.set_opcode(Opcode::Query);
68
+
flags.set_authoritative(true);
69
+
flags.set_recursion_available(false);
70
+
flags
71
+
}
72
+
73
+
// QR: Query/Response Flag
74
+
pub const fn is_query(&self) -> bool {
75
+
(self.0 & 0x8000) == 0
76
+
}
77
+
78
+
pub const fn set_query(&mut self, is_query: bool) {
79
+
if is_query {
80
+
self.0 &= !0x8000;
81
+
} else {
82
+
self.0 |= 0x8000;
83
+
}
84
+
}
85
+
86
+
// Opcode (bits 1-4)
87
+
pub const fn get_opcode(&self) -> Opcode {
88
+
Opcode::cast(((self.0 >> 11) & 0x0F) as u8)
89
+
}
90
+
91
+
pub const fn set_opcode(&mut self, opcode: Opcode) {
92
+
self.0 = (self.0 & !0x7800) | (((opcode as u8) as u16 & 0x0F) << 11);
93
+
}
94
+
95
+
// AA: Authoritative Answer
96
+
pub const fn is_authoritative(&self) -> bool {
97
+
(self.0 & 0x0400) != 0
98
+
}
99
+
100
+
pub const fn set_authoritative(&mut self, authoritative: bool) {
101
+
if authoritative {
102
+
self.0 |= 0x0400;
103
+
} else {
104
+
self.0 &= !0x0400;
105
+
}
106
+
}
107
+
108
+
// TC: Truncated
109
+
pub const fn is_truncated(&self) -> bool {
110
+
(self.0 & 0x0200) != 0
111
+
}
112
+
113
+
pub const fn set_truncated(&mut self, truncated: bool) {
114
+
if truncated {
115
+
self.0 |= 0x0200;
116
+
} else {
117
+
self.0 &= !0x0200;
118
+
}
119
+
}
120
+
121
+
// RD: Recursion Desired
122
+
pub const fn is_recursion_desired(&self) -> bool {
123
+
(self.0 & 0x0100) != 0
124
+
}
125
+
126
+
pub const fn set_recursion_desired(&mut self, recursion_desired: bool) {
127
+
if recursion_desired {
128
+
self.0 |= 0x0100;
129
+
} else {
130
+
self.0 &= !0x0100;
131
+
}
132
+
}
133
+
134
+
// RA: Recursion Available
135
+
pub const fn is_recursion_available(&self) -> bool {
136
+
(self.0 & 0x0080) != 0
137
+
}
138
+
139
+
pub const fn set_recursion_available(&mut self, recursion_available: bool) {
140
+
if recursion_available {
141
+
self.0 |= 0x0080;
142
+
} else {
143
+
self.0 &= !0x0080;
144
+
}
145
+
}
146
+
147
+
// Z: Reserved for future use (bits 9-11)
148
+
pub const fn get_reserved(&self) -> u8 {
149
+
((self.0 >> 4) & 0x07) as u8
150
+
}
151
+
152
+
pub const fn set_reserved(&mut self, reserved: u8) {
153
+
self.0 = (self.0 & !0x0070) | ((reserved as u16 & 0x07) << 4);
154
+
}
155
+
156
+
// RCODE: Response Code (bits 12-15)
157
+
pub const fn get_rcode(&self) -> u8 {
158
+
(self.0 & 0x000F) as u8
159
+
}
160
+
161
+
pub const fn set_rcode(&mut self, rcode: u8) {
162
+
self.0 = (self.0 & !0x000F) | (rcode as u16 & 0x0F);
163
+
}
164
+
}
165
+
166
+
impl<'a> DnsParse<'a> for Flags {
167
+
fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
168
+
be_u16.map(Flags).parse_next(input)
169
+
}
170
+
}
171
+
172
+
impl<'a> DnsSerialize<'a> for Flags {
173
+
type Error = Infallible;
174
+
175
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
176
+
encoder.write(&self.0.to_be_bytes());
177
+
Ok(())
178
+
}
179
+
180
+
fn size(&self) -> usize {
181
+
core::mem::size_of::<u16>()
182
+
}
183
+
}
184
+
185
+
impl fmt::Debug for Flags {
186
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
187
+
f.debug_struct("Flags")
188
+
.field("query", &self.is_query())
189
+
.field("opcode", &self.get_opcode())
190
+
.field("authoritative", &self.is_authoritative())
191
+
.field("truncated", &self.is_truncated())
192
+
.field("recursion_desired", &self.is_recursion_desired())
193
+
.field("recursion_available", &self.is_recursion_available())
194
+
.field("reserved", &self.get_reserved())
195
+
.field("rcode", &self.get_rcode())
196
+
.finish()
197
+
}
198
+
}
199
+
200
+
#[cfg(feature = "defmt")]
201
+
impl defmt::Format for Flags {
202
+
fn format(&self, fmt: defmt::Formatter) {
203
+
defmt::write!(
204
+
fmt,
205
+
"Flags {{ query: {}, opcode: {:?}, authoritative: {}, truncated: {}, recursion_desired: {}, recursion_available: {}, reserved: {}, rcode: {} }}",
206
+
self.is_query(),
207
+
self.get_opcode(),
208
+
self.is_authoritative(),
209
+
self.is_truncated(),
210
+
self.is_recursion_desired(),
211
+
self.is_recursion_available(),
212
+
self.get_reserved(),
213
+
self.get_rcode()
214
+
);
215
+
}
216
+
}
217
+
218
+
#[cfg(feature = "defmt")]
219
+
impl defmt::Format for Opcode {
220
+
fn format(&self, fmt: defmt::Formatter) {
221
+
let opcode_str = match self {
222
+
Opcode::Query => "Query",
223
+
Opcode::IQuery => "IQuery",
224
+
Opcode::Status => "Status",
225
+
Opcode::Reserved => "Reserved",
226
+
Opcode::Notify => "Notify",
227
+
Opcode::Update => "Update",
228
+
};
229
+
defmt::write!(fmt, "Opcode({=str})", opcode_str);
230
+
}
231
+
}
+511
sachy-mdns/src/dns/label.rs
+511
sachy-mdns/src/dns/label.rs
···
···
1
+
use core::{fmt, str};
2
+
3
+
use winnow::{
4
+
ModalResult, Parser,
5
+
binary::be_u8,
6
+
error::{ContextError, ErrMode, FromExternalError},
7
+
stream::Offset,
8
+
token::take,
9
+
};
10
+
11
+
use crate::{
12
+
dns::traits::{DnsParse, DnsSerialize},
13
+
encoder::{DnsError, Encoder, MAX_STR_LEN, PTR_MASK},
14
+
};
15
+
16
+
#[derive(Clone, Copy)]
17
+
pub struct Label<'a> {
18
+
repr: LabelRepr<'a>,
19
+
}
20
+
21
+
impl<'a> From<&'a str> for Label<'a> {
22
+
fn from(value: &'a str) -> Self {
23
+
Self {
24
+
repr: LabelRepr::Str(value),
25
+
}
26
+
}
27
+
}
28
+
29
+
impl<'a> DnsSerialize<'a> for Label<'a> {
30
+
type Error = DnsError;
31
+
32
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
33
+
match self.repr {
34
+
LabelRepr::Bytes {
35
+
context,
36
+
start,
37
+
end,
38
+
} => {
39
+
encoder.write(&context[start..end]);
40
+
41
+
Ok(())
42
+
}
43
+
LabelRepr::Str(label) => encoder.write_label(label),
44
+
}
45
+
}
46
+
47
+
fn size(&self) -> usize {
48
+
match self.repr {
49
+
LabelRepr::Bytes {
50
+
context,
51
+
start,
52
+
end,
53
+
} => core::mem::size_of_val(&context[start..end]),
54
+
LabelRepr::Str(label) => core::mem::size_of_val(label) + 1,
55
+
}
56
+
}
57
+
}
58
+
59
+
impl<'a> DnsParse<'a> for Label<'a> {
60
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
61
+
let start = input.offset_from(&context);
62
+
let mut end = start;
63
+
64
+
loop {
65
+
match LabelSegment::parse(input)? {
66
+
LabelSegment::Empty => {
67
+
end += 1;
68
+
break;
69
+
}
70
+
LabelSegment::String(label) => {
71
+
end += 1 + label.len();
72
+
}
73
+
LabelSegment::Pointer(_) => {
74
+
end += 2;
75
+
break;
76
+
}
77
+
}
78
+
}
79
+
80
+
Ok(Self {
81
+
repr: LabelRepr::Bytes {
82
+
context,
83
+
start,
84
+
end,
85
+
},
86
+
})
87
+
}
88
+
}
89
+
90
+
impl Label<'_> {
91
+
pub fn segments(&self) -> impl Iterator<Item = LabelSegment<'_>> {
92
+
self.repr.iter()
93
+
}
94
+
95
+
pub fn names(&self) -> impl Iterator<Item = &'_ str> {
96
+
match self.repr {
97
+
LabelRepr::Str(view) => Either::A(view.split('.')),
98
+
LabelRepr::Bytes { context, start, .. } => Either::B(
99
+
LabelSegmentBytesIter::new(context, start).flat_map(|label| label.as_str()),
100
+
),
101
+
}
102
+
}
103
+
104
+
pub fn is_empty(&self) -> bool {
105
+
self.repr.iter().next().is_none()
106
+
}
107
+
}
108
+
109
+
#[derive(Clone, Copy, PartialEq, Eq)]
110
+
enum LabelRepr<'a> {
111
+
Bytes {
112
+
context: &'a [u8],
113
+
start: usize,
114
+
end: usize,
115
+
},
116
+
Str(&'a str),
117
+
}
118
+
119
+
/// A DNS-compatible label segment.
120
+
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
121
+
pub enum LabelSegment<'a> {
122
+
/// The empty terminator.
123
+
Empty,
124
+
125
+
/// A string label.
126
+
String(&'a str),
127
+
128
+
/// A pointer to a previous name.
129
+
Pointer(u16),
130
+
}
131
+
132
+
impl<'a> LabelSegment<'a> {
133
+
fn parse(input: &mut &'a [u8]) -> ModalResult<Self> {
134
+
let b1 = be_u8(input)?;
135
+
136
+
match b1 {
137
+
0 => Ok(Self::Empty),
138
+
b1 if b1 & PTR_MASK == PTR_MASK => {
139
+
let b2 = be_u8(input)?;
140
+
141
+
let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]);
142
+
143
+
Ok(Self::Pointer(ptr))
144
+
}
145
+
len => {
146
+
if len > MAX_STR_LEN {
147
+
return Err(ErrMode::Cut(ContextError::from_external_error(
148
+
input, DnsError::LabelTooLong,
149
+
)));
150
+
}
151
+
152
+
let segment = take(len).try_map(core::str::from_utf8).parse_next(input)?;
153
+
154
+
Ok(Self::String(segment))
155
+
}
156
+
}
157
+
}
158
+
159
+
/// ## Safety
160
+
/// The caller upholds that this function is not called when parsing from newly received data. Data that
161
+
/// has yet to be determined to be a valid [`Label`] should be parsed and validated with [`Label::parse`],
162
+
/// and that the entire data/context has been validated, not just a portion of it.
163
+
#[inline]
164
+
unsafe fn parse_unchecked(input: &'a [u8]) -> Option<Self> {
165
+
input.split_first().map(|(b1, input)| match *b1 {
166
+
0 => Self::Empty,
167
+
b1 if b1 & PTR_MASK == PTR_MASK => {
168
+
// SAFETY: The caller has already validated that a second byte is available for a
169
+
// Pointer segment.
170
+
let b2 = unsafe { *input.get_unchecked(0) };
171
+
172
+
let ptr = u16::from_be_bytes([b1 & !PTR_MASK, b2]);
173
+
174
+
Self::Pointer(ptr)
175
+
}
176
+
len => {
177
+
// SAFETY: The caller has validated that this length value is correct and will only
178
+
// access within the bounds of the provided slice.
179
+
let segment = unsafe { input.get_unchecked(0..(len as usize)) };
180
+
// SAFETY: The caller has upheld the validity of the bytes as valid UTF-8 once before.
181
+
let segment = unsafe { core::str::from_utf8_unchecked(segment) };
182
+
183
+
Self::String(segment)
184
+
}
185
+
})
186
+
}
187
+
188
+
fn as_str(&self) -> Option<&'a str> {
189
+
match self {
190
+
Self::String(label) => Some(*label),
191
+
_ => None,
192
+
}
193
+
}
194
+
}
195
+
196
+
pub struct LabelSegmentBytesIter<'a> {
197
+
context: &'a [u8],
198
+
start: usize,
199
+
}
200
+
201
+
impl<'a> LabelSegmentBytesIter<'a> {
202
+
pub(crate) fn new(context: &'a [u8], start: usize) -> Self {
203
+
Self { context, start }
204
+
}
205
+
}
206
+
207
+
impl<'a> Iterator for LabelSegmentBytesIter<'a> {
208
+
type Item = LabelSegment<'a>;
209
+
210
+
fn next(&mut self) -> Option<Self::Item> {
211
+
loop {
212
+
let view = &self.context[self.start..];
213
+
214
+
// SAFETY: The segment has already been validated, so they should be all valid variants and UTF-8 bytes
215
+
let segment = unsafe { LabelSegment::parse_unchecked(view)? };
216
+
217
+
match segment {
218
+
LabelSegment::String(label) => {
219
+
self.start = self.start.saturating_add(label.len() + 1);
220
+
return Some(segment);
221
+
}
222
+
LabelSegment::Pointer(ptr) => {
223
+
self.start = ptr as usize;
224
+
}
225
+
LabelSegment::Empty => {
226
+
// Set the index offset to be len() so that the view is empty and terminates the loop
227
+
self.start = self.context.len();
228
+
return Some(LabelSegment::Empty);
229
+
}
230
+
}
231
+
}
232
+
}
233
+
}
234
+
235
+
impl<'a> LabelRepr<'a> {
236
+
fn iter(&self) -> impl Iterator<Item = LabelSegment<'a>> {
237
+
match *self {
238
+
LabelRepr::Bytes { context, start, .. } => {
239
+
Either::A(LabelSegmentBytesIter::new(context, start))
240
+
}
241
+
LabelRepr::Str(view) => Either::B(
242
+
view.split('.')
243
+
.map(LabelSegment::String)
244
+
.chain(Some(LabelSegment::Empty)),
245
+
),
246
+
}
247
+
}
248
+
}
249
+
250
+
impl fmt::Debug for Label<'_> {
251
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
252
+
struct LabelFmt<'a>(&'a Label<'a>);
253
+
254
+
impl fmt::Debug for LabelFmt<'_> {
255
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256
+
fmt::Display::fmt(self.0, f)
257
+
}
258
+
}
259
+
260
+
f.debug_tuple("Label").field(&LabelFmt(self)).finish()
261
+
}
262
+
}
263
+
264
+
impl fmt::Display for Label<'_> {
265
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
266
+
let mut names = self.names();
267
+
268
+
if let Some(name) = names.next() {
269
+
f.write_str(name)?;
270
+
271
+
names.try_for_each(|name| {
272
+
f.write_str(".")?;
273
+
f.write_str(name)
274
+
})
275
+
} else {
276
+
Ok(())
277
+
}
278
+
}
279
+
}
280
+
281
+
impl<'a, 'b> PartialEq<Label<'a>> for Label<'b> {
282
+
fn eq(&self, other: &Label<'a>) -> bool {
283
+
self.segments().eq(other.segments())
284
+
}
285
+
}
286
+
287
+
impl Eq for Label<'_> {}
288
+
289
+
impl PartialEq<&str> for Label<'_> {
290
+
fn eq(&self, other: &&str) -> bool {
291
+
let mut self_iter = self.names();
292
+
let mut other_iter = other.split('.');
293
+
294
+
loop {
295
+
match (self_iter.next(), other_iter.next()) {
296
+
(Some(self_part), Some(other_part)) => {
297
+
if self_part != other_part {
298
+
return false;
299
+
}
300
+
}
301
+
(None, None) => return true,
302
+
_ => return false,
303
+
}
304
+
}
305
+
}
306
+
}
307
+
308
+
#[cfg(feature = "defmt")]
309
+
impl defmt::Format for Label<'_> {
310
+
fn format(&self, fmt: defmt::Formatter) {
311
+
defmt::write!(fmt, "Label(");
312
+
let mut iter = self.names();
313
+
if let Some(first) = iter.next() {
314
+
defmt::write!(fmt, "{}", first);
315
+
316
+
iter.for_each(|part| defmt::write!(fmt, ".{}", part));
317
+
}
318
+
defmt::write!(fmt, ")");
319
+
}
320
+
}
321
+
322
+
/// One iterator or another.
323
+
enum Either<A, B> {
324
+
A(A),
325
+
B(B),
326
+
}
327
+
328
+
impl<A: Iterator, Other: Iterator<Item = A::Item>> Iterator for Either<A, Other> {
329
+
type Item = A::Item;
330
+
331
+
fn next(&mut self) -> Option<Self::Item> {
332
+
match self {
333
+
Either::A(a) => a.next(),
334
+
Either::B(b) => b.next(),
335
+
}
336
+
}
337
+
338
+
fn size_hint(&self) -> (usize, Option<usize>) {
339
+
match self {
340
+
Either::A(a) => a.size_hint(),
341
+
Either::B(b) => b.size_hint(),
342
+
}
343
+
}
344
+
345
+
fn fold<B, F>(self, init: B, f: F) -> B
346
+
where
347
+
Self: Sized,
348
+
F: FnMut(B, Self::Item) -> B,
349
+
{
350
+
match self {
351
+
Either::A(a) => a.fold(init, f),
352
+
Either::B(b) => b.fold(init, f),
353
+
}
354
+
}
355
+
}
356
+
357
+
#[cfg(test)]
358
+
mod test {
359
+
use super::*;
360
+
361
+
#[test]
362
+
fn segments_iter_test() {
363
+
let label: Label<'static> = Label::from("_service._udp.local");
364
+
let mut segments = label.segments();
365
+
366
+
assert_eq!(segments.next(), Some(LabelSegment::String("_service")));
367
+
assert_eq!(segments.next(), Some(LabelSegment::String("_udp")));
368
+
assert_eq!(segments.next(), Some(LabelSegment::String("local")));
369
+
assert_eq!(segments.next(), Some(LabelSegment::Empty));
370
+
assert_eq!(segments.next(), None);
371
+
372
+
// example.com with a pointer to the start
373
+
let data = b"\x07example\x03com\x00\xC0\x00";
374
+
let context = &data[..];
375
+
// The data here is entirely valid, even though we parse only a portion of it.
376
+
let label = Label::parse(&mut &data[13..], context).unwrap();
377
+
378
+
let mut segments = label.segments();
379
+
assert_eq!(segments.next(), Some(LabelSegment::String("example")));
380
+
assert_eq!(segments.next(), Some(LabelSegment::String("com")));
381
+
assert_eq!(segments.next(), Some(LabelSegment::Empty));
382
+
assert_eq!(segments.next(), None);
383
+
}
384
+
385
+
#[test]
386
+
fn names_iter_test() {
387
+
let label: Label<'static> = Label::from("_service._udp.local");
388
+
let mut names = label.names();
389
+
390
+
assert_eq!(names.next(), Some("_service"));
391
+
assert_eq!(names.next(), Some("_udp"));
392
+
assert_eq!(names.next(), Some("local"));
393
+
assert_eq!(names.next(), None);
394
+
395
+
let data = b"\x07example\x03com\x00\xC0\x00";
396
+
let context = &data[..];
397
+
// The data here is entirely valid, even though we parse only a portion of it.
398
+
let label = Label::parse(&mut &data[13..], context).unwrap();
399
+
400
+
let mut names = label.names();
401
+
assert_eq!(names.next(), Some("example"));
402
+
assert_eq!(names.next(), Some("com"));
403
+
assert_eq!(names.next(), None);
404
+
}
405
+
406
+
#[test]
407
+
fn serialize_str_label() {
408
+
let label: Label<'static> = Label::from("_service._udp.local");
409
+
let mut buffer = [0u8; 256];
410
+
let mut buffer = Encoder::new(&mut buffer);
411
+
label.serialize(&mut buffer).unwrap();
412
+
assert_eq!(buffer.finish(), b"\x08_service\x04_udp\x05local\x00");
413
+
}
414
+
415
+
#[test]
416
+
fn serialize_compressed_str_label() {
417
+
let label: Label<'static> = Label::from("_service._udp.local");
418
+
let label2: Label<'static> = Label::from("other._udp.local");
419
+
let mut buffer = [0u8; 256];
420
+
let mut buffer = Encoder::new(&mut buffer);
421
+
label.serialize(&mut buffer).unwrap();
422
+
label2.serialize(&mut buffer).unwrap();
423
+
assert_eq!(
424
+
buffer.finish(),
425
+
b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09"
426
+
);
427
+
}
428
+
429
+
#[test]
430
+
fn round_trip_compressed_str_label() {
431
+
let label: Label<'static> = Label::from("_service._udp.local");
432
+
let label2: Label<'static> = Label::from("other._udp.local");
433
+
let mut buffer = [0u8; 256];
434
+
let mut buffer = Encoder::new(&mut buffer);
435
+
label.serialize(&mut buffer).unwrap();
436
+
label2.serialize(&mut buffer).unwrap();
437
+
let context = buffer.finish();
438
+
assert_eq!(
439
+
context,
440
+
b"\x08_service\x04_udp\x05local\x00\x05other\xC0\x09"
441
+
);
442
+
let view = &mut &context[..];
443
+
444
+
let parsed_label = Label::parse(view, context).unwrap();
445
+
let parsed_label2 = Label::parse(view, context).unwrap();
446
+
447
+
let parsed_label_count = parsed_label.segments().count();
448
+
let parsed_label2_count = parsed_label2.segments().count();
449
+
450
+
// Both have same amount of segments.
451
+
assert_eq!(parsed_label_count, parsed_label2_count);
452
+
453
+
assert_eq!(parsed_label, label);
454
+
assert_eq!(parsed_label2, label2);
455
+
}
456
+
457
+
#[test]
458
+
fn label_byte_repr_serialization_quick_path() {
459
+
let data = b"\x07example\x03com\x00\xC0\x00";
460
+
let context = &data[..];
461
+
// The data here is entirely valid, even though we parse only a portion of it.
462
+
let label = Label::parse(&mut &data[13..], context).unwrap();
463
+
464
+
let mut buffer = [0u8; 256];
465
+
let mut buffer = Encoder::new(&mut buffer);
466
+
label.serialize(&mut buffer).unwrap();
467
+
// If the original Label is just a pointer, the new output will be a pointer, assuming
468
+
// the original data is also present in the output
469
+
assert_eq!(buffer.finish(), b"\xC0\x00");
470
+
}
471
+
472
+
#[test]
473
+
fn parse_and_eq_created_label() {
474
+
let data = b"\x07example\x03com\x00\xC0\x00";
475
+
let context = &data[..];
476
+
477
+
// The data here is entirely valid, even though we parse only a portion of it.
478
+
let parsed_label = Label::parse(&mut &data[13..], context).unwrap();
479
+
480
+
let created_label = Label::from("example.com");
481
+
482
+
assert_eq!(parsed_label, created_label);
483
+
}
484
+
485
+
#[test]
486
+
fn parse_and_eq_label_with_str() {
487
+
let data = b"\x07example\x03com\x00";
488
+
let context = &data[..];
489
+
490
+
let parsed_label = Label::parse(&mut &data[..], context).unwrap();
491
+
492
+
assert_eq!(parsed_label, "example.com");
493
+
}
494
+
495
+
#[test]
496
+
fn parse_ptr_label_and_eq_with_str() {
497
+
let data = b"\x07example\x03com\x00\xC0\x00";
498
+
let context = &data[..];
499
+
500
+
// The data here is entirely valid, even though we parse only a portion of it.
501
+
let parsed_label = Label::parse(&mut &data[13..], context).unwrap();
502
+
503
+
assert_eq!(parsed_label, "example.com");
504
+
}
505
+
506
+
#[test]
507
+
fn label_new_without_dot_is_not_empty() {
508
+
let label: Label = Label::from("example");
509
+
assert!(!label.is_empty());
510
+
}
511
+
}
+228
sachy-mdns/src/dns/query.rs
+228
sachy-mdns/src/dns/query.rs
···
···
1
+
use winnow::binary::{be_u16, be_u32};
2
+
use winnow::{ModalResult, Parser};
3
+
4
+
use super::label::Label;
5
+
use super::records::Record;
6
+
use crate::encoder::Encoder;
7
+
use crate::{
8
+
dns::{
9
+
records::QType,
10
+
traits::{DnsParse, DnsParseKind, DnsSerialize},
11
+
},
12
+
encoder::DnsError,
13
+
};
14
+
15
+
#[derive(Debug, PartialEq, Eq)]
16
+
pub struct Query<'a> {
17
+
pub name: Label<'a>,
18
+
pub qtype: QType,
19
+
pub qclass: QClass,
20
+
}
21
+
22
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
23
+
#[repr(u16)]
24
+
pub enum QClass {
25
+
IN = 1,
26
+
Multicast = 32769, // (IN + Cache flush bit)
27
+
Unknown(u16),
28
+
}
29
+
30
+
impl<'a> DnsParse<'a> for Query<'a> {
31
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
32
+
let name = Label::parse(input, context)?;
33
+
let qtype = QType::parse(input, context)?;
34
+
let qclass = be_u16.map(QClass::from_u16).parse_next(input)?;
35
+
36
+
Ok(Query {
37
+
name,
38
+
qtype,
39
+
qclass,
40
+
})
41
+
}
42
+
}
43
+
44
+
impl<'a> DnsSerialize<'a> for Query<'a> {
45
+
type Error = DnsError;
46
+
47
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
48
+
self.name.serialize(encoder)?;
49
+
self.qtype.serialize(encoder).ok();
50
+
encoder.write(&self.qclass.to_u16().to_be_bytes());
51
+
Ok(())
52
+
}
53
+
54
+
fn size(&self) -> usize {
55
+
self.name.size() + self.qtype.size() + core::mem::size_of::<QClass>()
56
+
}
57
+
}
58
+
59
+
#[derive(Debug, PartialEq, Eq)]
60
+
pub struct Answer<'a> {
61
+
pub name: Label<'a>,
62
+
pub atype: QType,
63
+
pub aclass: QClass,
64
+
pub ttl: u32,
65
+
pub record: Record<'a>,
66
+
}
67
+
68
+
impl QClass {
69
+
fn from_u16(value: u16) -> Self {
70
+
match value {
71
+
1 => QClass::IN,
72
+
32769 => QClass::Multicast,
73
+
_ => QClass::Unknown(value),
74
+
}
75
+
}
76
+
77
+
fn to_u16(self) -> u16 {
78
+
match self {
79
+
QClass::IN => 1,
80
+
QClass::Multicast => 32769,
81
+
QClass::Unknown(value) => value,
82
+
}
83
+
}
84
+
}
85
+
86
+
impl<'a> DnsParse<'a> for Answer<'a> {
87
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
88
+
let name = Label::parse(input, context)?;
89
+
let atype = QType::parse(input, context)?;
90
+
let aclass = be_u16.map(QClass::from_u16).parse_next(input)?;
91
+
92
+
let ttl = be_u32.parse_next(input)?;
93
+
let record = atype.parse_kind(input, context)?;
94
+
95
+
Ok(Answer {
96
+
name,
97
+
atype,
98
+
aclass,
99
+
ttl,
100
+
record,
101
+
})
102
+
}
103
+
}
104
+
105
+
impl<'a> DnsSerialize<'a> for Answer<'a> {
106
+
type Error = DnsError;
107
+
108
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
109
+
self.name.serialize(encoder)?;
110
+
self.atype.serialize(encoder).ok();
111
+
encoder.write(&self.aclass.to_u16().to_be_bytes());
112
+
encoder.write(&self.ttl.to_be_bytes());
113
+
self.record.serialize(encoder)
114
+
}
115
+
116
+
fn size(&self) -> usize {
117
+
self.name.size()
118
+
+ self.atype.size()
119
+
+ core::mem::size_of::<QClass>()
120
+
+ core::mem::size_of::<u32>()
121
+
+ self.record.size()
122
+
}
123
+
}
124
+
125
+
#[cfg(feature = "defmt")]
126
+
impl<'a> defmt::Format for Query<'a> {
127
+
fn format(&self, fmt: defmt::Formatter) {
128
+
defmt::write!(
129
+
fmt,
130
+
"Query {{ name: {:?}, qtype: {:?}, qclass: {:?} }}",
131
+
self.name,
132
+
self.qtype,
133
+
self.qclass
134
+
);
135
+
}
136
+
}
137
+
138
+
#[cfg(feature = "defmt")]
139
+
impl defmt::Format for QType {
140
+
fn format(&self, fmt: defmt::Formatter) {
141
+
let qtype_str = match self {
142
+
QType::A => "A",
143
+
QType::AAAA => "AAAA",
144
+
QType::PTR => "PTR",
145
+
QType::TXT => "TXT",
146
+
QType::SRV => "SRV",
147
+
QType::Any => "Any",
148
+
QType::Unknown(_) => "Unknown",
149
+
};
150
+
defmt::write!(fmt, "QType({=str})", qtype_str);
151
+
}
152
+
}
153
+
154
+
#[cfg(feature = "defmt")]
155
+
impl defmt::Format for QClass {
156
+
fn format(&self, fmt: defmt::Formatter) {
157
+
let qclass_str = match self {
158
+
QClass::IN => "IN",
159
+
QClass::Multicast => "Multicast",
160
+
QClass::Unknown(_) => "Unknown",
161
+
};
162
+
defmt::write!(fmt, "QClass({=str})", qclass_str);
163
+
}
164
+
}
165
+
166
+
#[cfg(feature = "defmt")]
167
+
impl<'a> defmt::Format for Answer<'a> {
168
+
fn format(&self, fmt: defmt::Formatter) {
169
+
defmt::write!(
170
+
fmt,
171
+
"Answer {{ name: {:?}, atype: {:?}, aclass: {:?}, ttl: {}, record: {:?} }}",
172
+
self.name,
173
+
self.atype,
174
+
self.aclass,
175
+
self.ttl,
176
+
self.record
177
+
);
178
+
}
179
+
}
180
+
181
+
#[cfg(test)]
182
+
mod tests {
183
+
use super::*;
184
+
use crate::dns::records::A;
185
+
use core::net::Ipv4Addr;
186
+
187
+
#[test]
188
+
fn roundtrip_query() {
189
+
let name = Label::from("example.local");
190
+
191
+
let query = Query {
192
+
name,
193
+
qtype: QType::A,
194
+
qclass: QClass::IN,
195
+
};
196
+
197
+
let mut buffer = [0u8; 256];
198
+
let mut buffer = Encoder::new(&mut buffer);
199
+
query.serialize(&mut buffer).unwrap();
200
+
let buffer = buffer.finish();
201
+
let parsed_query = Query::parse(&mut &buffer[..], buffer).unwrap();
202
+
203
+
assert_eq!(query, parsed_query);
204
+
}
205
+
206
+
#[test]
207
+
fn roundtrip_answer() {
208
+
let name = Label::from("example.local");
209
+
210
+
let answer: Answer = Answer {
211
+
name,
212
+
atype: QType::A,
213
+
aclass: QClass::IN,
214
+
ttl: 120,
215
+
record: Record::A(A {
216
+
address: Ipv4Addr::new(192, 168, 1, 1),
217
+
}),
218
+
};
219
+
220
+
let mut buffer = [0u8; 256];
221
+
let mut buffer = Encoder::new(&mut buffer);
222
+
answer.serialize(&mut buffer).unwrap();
223
+
let buffer = buffer.finish();
224
+
let parsed_answer = Answer::parse(&mut &buffer[..], buffer).unwrap();
225
+
226
+
assert_eq!(answer, parsed_answer);
227
+
}
228
+
}
+413
sachy-mdns/src/dns/records.rs
+413
sachy-mdns/src/dns/records.rs
···
···
1
+
use core::{
2
+
convert::Infallible,
3
+
net::{Ipv4Addr, Ipv6Addr},
4
+
str,
5
+
};
6
+
7
+
use alloc::vec::Vec;
8
+
use winnow::token::take;
9
+
use winnow::{ModalResult, Parser};
10
+
use winnow::{binary::be_u8, error::ContextError};
11
+
use winnow::{binary::be_u16, error::FromExternalError};
12
+
13
+
use super::label::Label;
14
+
use crate::{
15
+
dns::traits::{DnsParse, DnsParseKind, DnsSerialize},
16
+
encoder::{DnsError, Encoder},
17
+
};
18
+
19
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
20
+
#[repr(u16)]
21
+
#[allow(clippy::upper_case_acronyms)]
22
+
pub enum QType {
23
+
A = 1,
24
+
AAAA = 28,
25
+
PTR = 12,
26
+
TXT = 16,
27
+
SRV = 33,
28
+
Any = 255,
29
+
Unknown(u16),
30
+
}
31
+
32
+
impl<'a> DnsParse<'a> for QType {
33
+
fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
34
+
be_u16.map(QType::from_u16).parse_next(input)
35
+
}
36
+
}
37
+
38
+
impl<'a> DnsSerialize<'a> for QType {
39
+
type Error = Infallible;
40
+
41
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
42
+
encoder.write(&self.to_u16().to_be_bytes());
43
+
Ok(())
44
+
}
45
+
46
+
fn size(&self) -> usize {
47
+
core::mem::size_of::<QType>()
48
+
}
49
+
}
50
+
51
+
impl<'a> DnsParseKind<'a> for QType {
52
+
type Output = Record<'a>;
53
+
54
+
fn parse_kind(&self, input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self::Output> {
55
+
match self {
56
+
QType::A => {
57
+
let record = A::parse(input, context)?;
58
+
Ok(Record::A(record))
59
+
}
60
+
QType::AAAA => {
61
+
let record = AAAA::parse(input, context)?;
62
+
Ok(Record::AAAA(record))
63
+
}
64
+
QType::PTR => {
65
+
let record = PTR::parse(input, context)?;
66
+
Ok(Record::PTR(record))
67
+
}
68
+
QType::TXT => {
69
+
let record = TXT::parse(input, context)?;
70
+
Ok(Record::TXT(record))
71
+
}
72
+
QType::SRV => {
73
+
let record = SRV::parse(input, context)?;
74
+
Ok(Record::SRV(record))
75
+
}
76
+
QType::Any => Err(winnow::error::ErrMode::Backtrack(
77
+
ContextError::from_external_error(input, DnsError::Unsupported),
78
+
)),
79
+
QType::Unknown(_) => Err(winnow::error::ErrMode::Backtrack(
80
+
ContextError::from_external_error(input, DnsError::Unsupported),
81
+
)),
82
+
}
83
+
}
84
+
}
85
+
86
+
impl QType {
87
+
fn from_u16(value: u16) -> Self {
88
+
match value {
89
+
1 => QType::A,
90
+
28 => QType::AAAA,
91
+
12 => QType::PTR,
92
+
16 => QType::TXT,
93
+
33 => QType::SRV,
94
+
255 => QType::Any,
95
+
_ => QType::Unknown(value),
96
+
}
97
+
}
98
+
99
+
fn to_u16(self) -> u16 {
100
+
match self {
101
+
QType::A => 1,
102
+
QType::AAAA => 28,
103
+
QType::PTR => 12,
104
+
QType::TXT => 16,
105
+
QType::SRV => 33,
106
+
QType::Any => 255,
107
+
QType::Unknown(value) => value,
108
+
}
109
+
}
110
+
}
111
+
112
+
#[derive(Debug, PartialEq, Eq)]
113
+
#[allow(clippy::upper_case_acronyms)]
114
+
// Enum for DNS-SD records
115
+
pub enum Record<'a> {
116
+
A(A),
117
+
AAAA(AAAA),
118
+
PTR(PTR<'a>),
119
+
TXT(TXT<'a>),
120
+
SRV(SRV<'a>),
121
+
}
122
+
123
+
impl<'a> DnsSerialize<'a> for Record<'a> {
124
+
type Error = DnsError;
125
+
126
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
127
+
match self {
128
+
Record::A(record) => {
129
+
record.serialize(encoder).ok();
130
+
}
131
+
Record::AAAA(record) => {
132
+
record.serialize(encoder).ok();
133
+
}
134
+
Record::PTR(record) => {
135
+
record.serialize(encoder)?;
136
+
}
137
+
Record::TXT(record) => {
138
+
record.serialize(encoder).ok();
139
+
}
140
+
Record::SRV(record) => {
141
+
record.serialize(encoder)?;
142
+
}
143
+
};
144
+
145
+
Ok(())
146
+
}
147
+
148
+
fn size(&self) -> usize {
149
+
match self {
150
+
Self::A(a) => a.size(),
151
+
Self::AAAA(aaaa) => aaaa.size(),
152
+
Self::PTR(ptr) => ptr.size(),
153
+
Self::TXT(txt) => txt.size(),
154
+
Self::SRV(srv) => srv.size(),
155
+
}
156
+
}
157
+
}
158
+
159
+
// Struct for A record
160
+
#[derive(Debug, PartialEq, Eq)]
161
+
pub struct A {
162
+
pub address: Ipv4Addr,
163
+
}
164
+
165
+
impl<'a> DnsParse<'a> for A {
166
+
fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
167
+
let len = be_u16.parse_next(input)?;
168
+
let address = take(len)
169
+
.try_map(<[u8; 4]>::try_from)
170
+
.map(Ipv4Addr::from)
171
+
.parse_next(input)?;
172
+
173
+
Ok(A { address })
174
+
}
175
+
}
176
+
177
+
impl<'a> DnsSerialize<'a> for A {
178
+
type Error = Infallible;
179
+
180
+
fn serialize(&self, writer: &mut Encoder<'_, '_>) -> Result<(), Self::Error> {
181
+
let len = 4u16.to_be_bytes();
182
+
writer.write(&len);
183
+
writer.write(&self.address.octets());
184
+
Ok(())
185
+
}
186
+
187
+
fn size(&self) -> usize {
188
+
core::mem::size_of::<Ipv4Addr>() + core::mem::size_of::<u16>()
189
+
}
190
+
}
191
+
192
+
// Struct for AAAA record
193
+
#[derive(Debug, PartialEq, Eq)]
194
+
#[allow(clippy::upper_case_acronyms)]
195
+
pub struct AAAA {
196
+
pub address: Ipv6Addr,
197
+
}
198
+
199
+
impl<'a> DnsParse<'a> for AAAA {
200
+
fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
201
+
let len = be_u16.parse_next(input)?;
202
+
let address = take(len)
203
+
.try_map(<[u8; 16]>::try_from)
204
+
.map(Ipv6Addr::from)
205
+
.parse_next(input)?;
206
+
207
+
Ok(AAAA { address })
208
+
}
209
+
}
210
+
211
+
impl<'a> DnsSerialize<'a> for AAAA {
212
+
type Error = Infallible;
213
+
214
+
fn serialize(&self, writer: &mut Encoder<'_, '_>) -> Result<(), Self::Error> {
215
+
let len = 16u16.to_be_bytes();
216
+
writer.write(&len);
217
+
writer.write(&self.address.octets());
218
+
Ok(())
219
+
}
220
+
221
+
fn size(&self) -> usize {
222
+
core::mem::size_of::<Ipv6Addr>() + core::mem::size_of::<u16>()
223
+
}
224
+
}
225
+
226
+
// Struct for PTR record
227
+
#[derive(Debug, PartialEq, Eq)]
228
+
#[allow(clippy::upper_case_acronyms)]
229
+
pub struct PTR<'a> {
230
+
pub name: Label<'a>,
231
+
}
232
+
233
+
impl<'a> DnsParse<'a> for PTR<'a> {
234
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
235
+
let _ = be_u16.parse_next(input)?;
236
+
let name = Label::parse(input, context)?;
237
+
Ok(PTR { name })
238
+
}
239
+
}
240
+
241
+
impl<'a> DnsSerialize<'a> for PTR<'a> {
242
+
type Error = DnsError;
243
+
244
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
245
+
encoder.with_record_length(|enc| self.name.serialize(enc))
246
+
}
247
+
248
+
fn size(&self) -> usize {
249
+
self.name.size() + core::mem::size_of::<u16>()
250
+
}
251
+
}
252
+
253
+
// Struct for TXT record
254
+
#[derive(Debug, PartialEq, Eq)]
255
+
#[allow(clippy::upper_case_acronyms)]
256
+
pub struct TXT<'a> {
257
+
pub text: Vec<&'a str>,
258
+
}
259
+
260
+
impl<'a> DnsParse<'a> for TXT<'a> {
261
+
fn parse(input: &mut &'a [u8], _context: &'a [u8]) -> ModalResult<Self> {
262
+
let text_len = be_u16.parse_next(input)?;
263
+
264
+
let mut total = 0u16;
265
+
let mut text = Vec::new();
266
+
267
+
while total < text_len {
268
+
let len = be_u8(input)?;
269
+
270
+
total += 1 + len as u16;
271
+
272
+
if len > 0 {
273
+
let part = take(len).try_map(core::str::from_utf8).parse_next(input)?;
274
+
text.push(part);
275
+
}
276
+
}
277
+
278
+
Ok(TXT { text })
279
+
}
280
+
}
281
+
282
+
impl<'a> DnsSerialize<'a> for TXT<'a> {
283
+
type Error = DnsError;
284
+
285
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
286
+
encoder.with_record_length(|enc| {
287
+
self.text.iter().try_for_each(|&part| {
288
+
let text_len = u8::try_from(part.len())
289
+
.map_err(|_| DnsError::InvalidTxt)
290
+
.map(u8::to_be_bytes)?;
291
+
292
+
enc.write(&text_len);
293
+
enc.write(part.as_bytes());
294
+
295
+
Ok(())
296
+
})
297
+
})
298
+
}
299
+
300
+
fn size(&self) -> usize {
301
+
let len_size = core::mem::size_of::<u16>();
302
+
303
+
let text_size = if self.text.is_empty() {
304
+
1
305
+
} else {
306
+
self.text.iter().map(|part| part.len() + 1).sum()
307
+
};
308
+
309
+
len_size + text_size
310
+
}
311
+
}
312
+
313
+
// Struct for SRV record
314
+
#[derive(Debug, PartialEq, Eq)]
315
+
#[allow(clippy::upper_case_acronyms)]
316
+
pub struct SRV<'a> {
317
+
pub priority: u16,
318
+
pub weight: u16,
319
+
pub port: u16,
320
+
pub target: Label<'a>,
321
+
}
322
+
323
+
impl<'a> DnsParse<'a> for SRV<'a> {
324
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
325
+
let _ = be_u16.parse_next(input)?;
326
+
let priority = be_u16.parse_next(input)?;
327
+
let weight = be_u16.parse_next(input)?;
328
+
let port = be_u16.parse_next(input)?;
329
+
let target = Label::parse(input, context)?;
330
+
331
+
Ok(SRV {
332
+
priority,
333
+
weight,
334
+
port,
335
+
target,
336
+
})
337
+
}
338
+
}
339
+
340
+
impl<'a> DnsSerialize<'a> for SRV<'a> {
341
+
type Error = DnsError;
342
+
343
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
344
+
encoder.with_record_length(|enc| {
345
+
enc.write(&self.priority.to_be_bytes());
346
+
enc.write(&self.weight.to_be_bytes());
347
+
enc.write(&self.port.to_be_bytes());
348
+
349
+
self.target.serialize(enc)
350
+
})
351
+
}
352
+
353
+
fn size(&self) -> usize {
354
+
(core::mem::size_of::<u16>() * 4) + self.target.size()
355
+
}
356
+
}
357
+
358
+
#[cfg(feature = "defmt")]
359
+
impl defmt::Format for A {
360
+
fn format(&self, fmt: defmt::Formatter) {
361
+
// use crate::format::FormatIpv4Addr;
362
+
defmt::write!(fmt, "A({})", self.address)
363
+
}
364
+
}
365
+
366
+
#[cfg(feature = "defmt")]
367
+
impl defmt::Format for AAAA {
368
+
fn format(&self, fmt: defmt::Formatter) {
369
+
// use crate::format::FormatIpv6Addr;
370
+
defmt::write!(fmt, "AAAA({})", self.address)
371
+
}
372
+
}
373
+
374
+
#[cfg(feature = "defmt")]
375
+
impl<'a> defmt::Format for Record<'a> {
376
+
fn format(&self, fmt: defmt::Formatter) {
377
+
match self {
378
+
Record::A(record) => defmt::write!(fmt, "Record::A({:?})", record),
379
+
Record::AAAA(record) => defmt::write!(fmt, "Record::AAAA({:?})", record),
380
+
Record::PTR(record) => defmt::write!(fmt, "Record::PTR({:?})", record),
381
+
Record::TXT(record) => defmt::write!(fmt, "Record::TXT({:?})", record),
382
+
Record::SRV(record) => defmt::write!(fmt, "Record::SRV({:?})", record),
383
+
}
384
+
}
385
+
}
386
+
387
+
#[cfg(feature = "defmt")]
388
+
impl<'a> defmt::Format for PTR<'a> {
389
+
fn format(&self, fmt: defmt::Formatter) {
390
+
defmt::write!(fmt, "PTR {{ name: {:?} }}", self.name);
391
+
}
392
+
}
393
+
394
+
#[cfg(feature = "defmt")]
395
+
impl<'a> defmt::Format for TXT<'a> {
396
+
fn format(&self, fmt: defmt::Formatter) {
397
+
defmt::write!(fmt, "TXT {{ text: {:?} }}", self.text);
398
+
}
399
+
}
400
+
401
+
#[cfg(feature = "defmt")]
402
+
impl<'a> defmt::Format for SRV<'a> {
403
+
fn format(&self, fmt: defmt::Formatter) {
404
+
defmt::write!(
405
+
fmt,
406
+
"SRV {{ priority: {}, weight: {}, port: {}, target: {:?} }}",
407
+
self.priority,
408
+
self.weight,
409
+
self.port,
410
+
self.target
411
+
);
412
+
}
413
+
}
+481
sachy-mdns/src/dns/reqres.rs
+481
sachy-mdns/src/dns/reqres.rs
···
···
1
+
use alloc::vec::Vec;
2
+
use winnow::ModalResult;
3
+
use winnow::binary::be_u16;
4
+
5
+
use super::flags::Flags;
6
+
use super::query::{Answer, Query};
7
+
use crate::{
8
+
dns::traits::{DnsParse, DnsSerialize},
9
+
encoder::{DnsError, Encoder},
10
+
};
11
+
12
+
const ZERO_U16: [u8; 2] = 0u16.to_be_bytes();
13
+
14
+
#[derive(Debug, PartialEq, Eq)]
15
+
pub struct Request<'a> {
16
+
pub id: u16,
17
+
pub flags: Flags,
18
+
pub(crate) queries: Vec<Query<'a>>,
19
+
}
20
+
21
+
impl<'a> DnsParse<'a> for Request<'a> {
22
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
23
+
let id = be_u16(input)?;
24
+
let flags = Flags::parse(input, context)?;
25
+
let qdcount = be_u16(input)?;
26
+
let _ancount = be_u16(input)?;
27
+
let _nscount = be_u16(input)?;
28
+
let _arcount = be_u16(input)?;
29
+
let queries = (0..qdcount)
30
+
.map(|_| Query::parse(input, context))
31
+
.collect::<Result<Vec<_>, _>>()?;
32
+
Ok(Request { id, flags, queries })
33
+
}
34
+
}
35
+
36
+
impl<'a> DnsSerialize<'a> for Request<'a> {
37
+
type Error = DnsError;
38
+
39
+
fn serialize<'b>(&self, writer: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
40
+
writer.write(&self.id.to_be_bytes());
41
+
self.flags.serialize(writer).ok();
42
+
writer.write(&(self.queries.len() as u16).to_be_bytes());
43
+
writer.write(&ZERO_U16);
44
+
writer.write(&ZERO_U16);
45
+
writer.write(&ZERO_U16);
46
+
47
+
self.queries
48
+
.iter()
49
+
.try_for_each(|query| query.serialize(writer))
50
+
}
51
+
52
+
fn size(&self) -> usize {
53
+
let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum();
54
+
55
+
core::mem::size_of::<u16>()
56
+
+ self.flags.size()
57
+
+ (core::mem::size_of::<u16>() * 4)
58
+
+ total_query_size
59
+
}
60
+
}
61
+
62
+
#[derive(Debug, PartialEq, Eq)]
63
+
pub struct Response<'a> {
64
+
pub id: u16,
65
+
pub flags: Flags,
66
+
pub queries: Vec<Query<'a>>,
67
+
pub answers: Vec<Answer<'a>>,
68
+
}
69
+
70
+
impl<'a> DnsParse<'a> for Response<'a> {
71
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self> {
72
+
let id = be_u16(input)?;
73
+
let flags = Flags::parse(input, context)?;
74
+
let qdcount = be_u16(input)?;
75
+
let ancount = be_u16(input)?;
76
+
let _nscount = be_u16(input)?;
77
+
let _arcount = be_u16(input)?;
78
+
79
+
let queries = (0..qdcount)
80
+
.map(|_| Query::parse(input, context))
81
+
.collect::<Result<Vec<_>, _>>()?;
82
+
83
+
let answers = (0..ancount)
84
+
.map(|_| Answer::parse(input, context))
85
+
.collect::<Result<Vec<_>, _>>()?;
86
+
87
+
Ok(Response {
88
+
id,
89
+
flags,
90
+
queries,
91
+
answers,
92
+
})
93
+
}
94
+
}
95
+
96
+
impl<'a> DnsSerialize<'a> for Response<'a> {
97
+
type Error = DnsError;
98
+
99
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error> {
100
+
encoder.write(&self.id.to_be_bytes());
101
+
self.flags.serialize(encoder).ok();
102
+
encoder.write(&(self.queries.len() as u16).to_be_bytes());
103
+
encoder.write(&(self.answers.len() as u16).to_be_bytes());
104
+
encoder.write(&ZERO_U16);
105
+
encoder.write(&ZERO_U16);
106
+
107
+
self.queries
108
+
.iter()
109
+
.try_for_each(|query| query.serialize(encoder))?;
110
+
self.answers
111
+
.iter()
112
+
.try_for_each(|answer| answer.serialize(encoder))
113
+
}
114
+
115
+
fn size(&self) -> usize {
116
+
let total_query_size: usize = self.queries.iter().map(DnsSerialize::size).sum();
117
+
let total_answer_size: usize = self.answers.iter().map(DnsSerialize::size).sum();
118
+
119
+
core::mem::size_of::<u16>()
120
+
+ self.flags.size()
121
+
+ (core::mem::size_of::<u16>() * 4)
122
+
+ total_query_size
123
+
+ total_answer_size
124
+
}
125
+
}
126
+
127
+
#[cfg(feature = "defmt")]
128
+
impl<'a> defmt::Format for Request<'a> {
129
+
fn format(&self, fmt: defmt::Formatter) {
130
+
defmt::write!(
131
+
fmt,
132
+
"Request {{ id: {}, flags: {:?}, queries: {:?} }}",
133
+
self.id,
134
+
self.flags,
135
+
self.queries
136
+
);
137
+
}
138
+
}
139
+
140
+
#[cfg(feature = "defmt")]
141
+
impl<'a> defmt::Format for Response<'a> {
142
+
fn format(&self, fmt: defmt::Formatter) {
143
+
defmt::write!(
144
+
fmt,
145
+
"Response {{ id: {}, flags: {:?}, queries: {:?}, answers: {:?} }}",
146
+
self.id,
147
+
self.flags,
148
+
self.queries,
149
+
self.answers
150
+
);
151
+
}
152
+
}
153
+
154
+
#[cfg(test)]
155
+
mod tests {
156
+
use alloc::vec;
157
+
158
+
use super::*;
159
+
use crate::dns::{
160
+
label::Label,
161
+
query::QClass,
162
+
records::{A, PTR, QType, Record, SRV, TXT},
163
+
};
164
+
use core::net::Ipv4Addr;
165
+
166
+
#[test]
167
+
fn parse_query() {
168
+
let data = [
169
+
0xAA, 0xAA, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x07, 0x65,
170
+
// example . com in label format
171
+
0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, //
172
+
//
173
+
0x00, 0x01, 0x00, 0x01,
174
+
];
175
+
176
+
let request = Request::parse(&mut data.as_slice(), data.as_slice()).unwrap();
177
+
178
+
assert_eq!(request.id, 0xAAAA);
179
+
assert_eq!(request.flags.0, 0x0100);
180
+
assert_eq!(request.queries.len(), 1);
181
+
assert_eq!(request.queries[0].name, "example.com");
182
+
assert_eq!(request.queries[0].qtype, QType::A);
183
+
assert_eq!(request.queries[0].qclass, QClass::IN);
184
+
}
185
+
186
+
#[test]
187
+
fn parse_response() {
188
+
let data = [
189
+
0xAA, 0xAA, // transaction ID
190
+
0x81, 0x80, // flags
191
+
0x00, 0x01, // 1 question
192
+
0x00, 0x01, // 1 A-answer
193
+
0x00, 0x00, // no authority
194
+
0x00, 0x00, // no additional answers
195
+
// example . com in label format
196
+
0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, //
197
+
//
198
+
0x00, 0x01, 0x00, 0x01, //
199
+
//
200
+
0xC0, 0x0C, // ptr to question section
201
+
//
202
+
0x00, 0x01, 0x00, 0x01, // A and IN
203
+
//
204
+
0x00, 0x00, 0x00, 0x3C, // TTL 60 seconds
205
+
//
206
+
0x00, 0x04, // length of address
207
+
// IP address:
208
+
192, 168, 1, 3,
209
+
];
210
+
211
+
let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap();
212
+
213
+
assert_eq!(response.id, 0xAAAA);
214
+
assert_eq!(response.flags.0, 0x8180);
215
+
assert_eq!(response.answers.len(), 1);
216
+
assert_eq!(response.answers[0].name, "example.com");
217
+
assert_eq!(response.answers[0].atype, QType::A);
218
+
assert_eq!(response.answers[0].aclass, QClass::IN);
219
+
assert_eq!(response.answers[0].ttl, 60);
220
+
if let Record::A(a) = &response.answers[0].record {
221
+
assert_eq!(a.address, Ipv4Addr::new(192, 168, 1, 3));
222
+
} else {
223
+
panic!("Expected A record");
224
+
}
225
+
}
226
+
227
+
#[test]
228
+
fn parse_response_two_records() {
229
+
let data = [
230
+
0xAA, 0xAA, //
231
+
0x81, 0x80, //
232
+
0x00, 0x01, //
233
+
0x00, 0x02, //
234
+
0x00, 0x00, //
235
+
0x00, 0x00, //
236
+
// example . com in label format
237
+
0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, //
238
+
//
239
+
0x00, 0x01, // query type
240
+
0x00, 0x01, // query class
241
+
//
242
+
0xC0, 0x0C, // pointer
243
+
0x00, 0x01, //
244
+
0x00, 0x01, //
245
+
0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds
246
+
0x00, 0x04, // length of A-record
247
+
0x5D, 0xB8, 0xD8, 0x22, // a-record
248
+
//
249
+
0x07, 0x65, 0x78, 0x61, 0x6D, 0x70, 0x6C, 0x65, 0x03, 0x63, 0x6F, 0x6D, 0x00, //
250
+
//
251
+
0x00, 0x10, // TXT
252
+
0x00, 0x01, // IN
253
+
//
254
+
0x00, 0x00, 0x00, 0x3C, // ttl 60 seconds
255
+
//
256
+
0x00, 0x10, // length of txt record
257
+
// (len) "test txt record"
258
+
0x0F, 0x74, 0x65, 0x73, 0x74, 0x20, 0x74, 0x78, 0x74, 0x20, 0x72, 0x65, 0x63, 0x6F, 0x72,
259
+
0x64,
260
+
];
261
+
262
+
let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap();
263
+
264
+
assert_eq!(response.id, 0xAAAA);
265
+
assert_eq!(response.flags.0, 0x8180);
266
+
assert_eq!(response.answers.len(), 2);
267
+
268
+
// First answer
269
+
assert_eq!(response.answers[0].name, "example.com");
270
+
assert_eq!(response.answers[0].atype, QType::A);
271
+
assert_eq!(response.answers[0].aclass, QClass::IN);
272
+
assert_eq!(response.answers[0].ttl, 60);
273
+
if let Record::A(a) = &response.answers[0].record {
274
+
assert_eq!(a.address, Ipv4Addr::new(93, 184, 216, 34));
275
+
} else {
276
+
panic!("Expected A record");
277
+
}
278
+
279
+
// Second answer
280
+
assert_eq!(response.answers[1].name, "example.com");
281
+
assert_eq!(response.answers[1].atype, QType::TXT);
282
+
assert_eq!(response.answers[1].aclass, QClass::IN);
283
+
assert_eq!(response.answers[1].ttl, 60);
284
+
if let Record::TXT(txt) = &response.answers[1].record
285
+
&& let Some(&text) = txt.text.first()
286
+
{
287
+
assert_eq!(text, "test txt record");
288
+
} else {
289
+
panic!("Expected TXT record");
290
+
}
291
+
}
292
+
293
+
#[test]
294
+
fn parse_response_srv() {
295
+
let data = [
296
+
//
297
+
0xAA, 0xAA, // id
298
+
0x81, 0x80, // flags
299
+
0x00, 0x01, // one question
300
+
0x00, 0x01, // one answer
301
+
0x00, 0x00, // no authority
302
+
0x00, 0x00, // no extra
303
+
//
304
+
0x04, 0x5f, 0x73, 0x69, 0x70, 0x04, 0x5f, 0x74, 0x63, 0x70, 0x07, 0x65, 0x78, 0x61,
305
+
0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, //
306
+
//
307
+
0x00, 0x21, // type SRV
308
+
0x00, 0x01, // IN
309
+
//
310
+
0xc0, 0x0c, //
311
+
//
312
+
0x00, 0x21, // SRV
313
+
0x00, 0x01, // IN
314
+
0x00, 0x00, 0x00, 0x3C, // ttl 60
315
+
//
316
+
0x00, 0x19, // data len
317
+
0x00, 0x0A, // prio
318
+
0x00, 0x05, // weight
319
+
0x13, 0xC4, // PORT
320
+
//
321
+
0x09, 0x73, 0x69, 0x70, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x07, 0x65, 0x78, 0x61,
322
+
0x6d, 0x70, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, //
323
+
];
324
+
325
+
let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap();
326
+
327
+
assert_eq!(response.id, 0xAAAA);
328
+
assert_eq!(response.flags.0, 0x8180);
329
+
assert_eq!(response.answers.len(), 1);
330
+
331
+
// Answer
332
+
assert_eq!(response.answers[0].name, "_sip._tcp.example.com");
333
+
assert_eq!(response.answers[0].atype, QType::SRV);
334
+
assert_eq!(response.answers[0].aclass, QClass::IN);
335
+
assert_eq!(response.answers[0].ttl, 60);
336
+
let Record::SRV(srv) = &response.answers[0].record else {
337
+
panic!("Expected SRV record");
338
+
};
339
+
340
+
assert_eq!(srv.priority, 10);
341
+
assert_eq!(srv.weight, 5);
342
+
assert_eq!(srv.port, 5060);
343
+
assert_eq!(srv.target, "sipserver.example.com");
344
+
}
345
+
346
+
#[test]
347
+
fn parse_response_back_forth() {
348
+
let data = [
349
+
0, 0, // Transaction ID
350
+
132, 0, // Response, Authoritative Answer, No Recursion
351
+
0, 0, // 0 questions
352
+
0, 4, // 4 answers
353
+
0, 0, // 0 authority RRs
354
+
0, 0, // 0 additional RRs
355
+
// _midiriff
356
+
9, 95, 109, 105, 100, 105, 114, 105, 102, 102, //
357
+
// _udp
358
+
4, 95, 117, 100, 112, //
359
+
// local
360
+
5, 108, 111, 99, 97, 108, //
361
+
0, // <end>
362
+
//
363
+
0, 12, // PTR
364
+
0, 1, // Class IN
365
+
0, 0, 0, 120, // TTL 120 seconds
366
+
0, 10, // Data Length 10
367
+
// pi35291
368
+
7, 112, 105, 51, 53, 50, 57, 49, //
369
+
//
370
+
192, 12, // Pointer to _midirif._udp._local.
371
+
//
372
+
192, 44, // Pointer to instace name: pi35291._midirif._udp._local.
373
+
0, 33, // SRV
374
+
128, 1, // IN (Cache flush bit set)
375
+
0, 0, 0, 120, // TTL 120 seconds
376
+
0, 11, // Data Length 11
377
+
0, 0, // Priority 0
378
+
0, 0, // Weight 0
379
+
137, 219, // Port 35291
380
+
2, 112, 105, // _pi
381
+
192, 27, // Pointer to: .local.
382
+
// TXT (Empty)
383
+
192, 44, 0, 16, 128, 1, 0, 0, 17, 148, 0, 1, 0,
384
+
// A (10.1.1.9)
385
+
192, 72, 0, 1, 128, 1, 0, 0, 0, 120, 0, 4, 10, 1, 1, 9,
386
+
];
387
+
388
+
let response = Response::parse(&mut data.as_slice(), data.as_slice()).unwrap();
389
+
390
+
assert_eq!(response.answers[0].name, "_midiriff._udp.local");
391
+
assert_eq!(response.answers[0].ttl, 120);
392
+
let Record::PTR(ptr) = &response.answers[0].record else {
393
+
panic!()
394
+
};
395
+
assert_eq!(ptr.name, "pi35291._midiriff._udp.local");
396
+
397
+
let mut buffer = [0u8; 256];
398
+
let mut buffer = Encoder::new(&mut buffer);
399
+
response.serialize(&mut buffer).unwrap();
400
+
401
+
let buffer = buffer.finish();
402
+
403
+
let response2 = Response::parse(&mut &buffer[..], buffer).unwrap();
404
+
405
+
assert_eq!(response, response2);
406
+
}
407
+
408
+
#[test]
409
+
fn mdns_service_response() {
410
+
let mut response = Response {
411
+
id: 0x1234,
412
+
flags: Flags::standard_response(),
413
+
queries: Vec::new(),
414
+
answers: Vec::new(),
415
+
};
416
+
417
+
let query = Query {
418
+
name: Label::from("_test._udp.local"),
419
+
qtype: QType::PTR,
420
+
qclass: QClass::IN,
421
+
};
422
+
response.queries.push(query);
423
+
424
+
let ptr_answer = Answer {
425
+
name: Label::from("_test._udp.local"),
426
+
atype: QType::PTR,
427
+
aclass: QClass::IN,
428
+
ttl: 4500,
429
+
record: Record::PTR(PTR {
430
+
name: Label::from("test-service._test._udp.local"),
431
+
}),
432
+
};
433
+
response.answers.push(ptr_answer);
434
+
435
+
let srv_answer = Answer {
436
+
name: Label::from("test-service._test._udp.local"),
437
+
atype: QType::SRV,
438
+
aclass: QClass::IN,
439
+
ttl: 120,
440
+
record: Record::SRV(SRV {
441
+
priority: 0,
442
+
weight: 0,
443
+
port: 8080,
444
+
target: Label::from("host.local"),
445
+
}),
446
+
};
447
+
response.answers.push(srv_answer);
448
+
449
+
let txt_answer = Answer {
450
+
name: Label::from("test-service._test._udp.local"),
451
+
atype: QType::TXT,
452
+
aclass: QClass::IN,
453
+
ttl: 120,
454
+
record: Record::TXT(TXT {
455
+
text: vec!["path=/test"],
456
+
}),
457
+
};
458
+
response.answers.push(txt_answer);
459
+
460
+
let a_answer = Answer {
461
+
name: Label::from("host.local"),
462
+
atype: QType::A,
463
+
aclass: QClass::IN,
464
+
ttl: 120,
465
+
record: Record::A(A {
466
+
address: Ipv4Addr::new(192, 168, 1, 100),
467
+
}),
468
+
};
469
+
response.answers.push(a_answer);
470
+
471
+
let mut buffer = [0u8; 256];
472
+
let mut buffer = Encoder::new(&mut buffer);
473
+
response.serialize(&mut buffer).unwrap();
474
+
475
+
let buffer = buffer.finish();
476
+
477
+
let parsed_response = Response::parse(&mut &buffer[..], buffer).unwrap();
478
+
479
+
assert_eq!(response, parsed_response);
480
+
}
481
+
}
+21
sachy-mdns/src/dns/traits.rs
+21
sachy-mdns/src/dns/traits.rs
···
···
1
+
use winnow::ModalResult;
2
+
3
+
use crate::encoder::Encoder;
4
+
5
+
pub trait DnsParse<'a>: Sized {
6
+
fn parse(input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self>;
7
+
}
8
+
9
+
pub trait DnsParseKind<'a> {
10
+
type Output;
11
+
12
+
fn parse_kind(&self, input: &mut &'a [u8], context: &'a [u8]) -> ModalResult<Self::Output>;
13
+
}
14
+
15
+
pub trait DnsSerialize<'a> {
16
+
type Error;
17
+
18
+
fn serialize<'b>(&self, encoder: &mut Encoder<'a, 'b>) -> Result<(), Self::Error>;
19
+
#[allow(dead_code)]
20
+
fn size(&self) -> usize;
21
+
}
+145
sachy-mdns/src/encoder.rs
+145
sachy-mdns/src/encoder.rs
···
···
1
+
use core::ops::Range;
2
+
3
+
use alloc::collections::BTreeMap;
4
+
5
+
use crate::dns::traits::DnsSerialize;
6
+
7
+
pub(crate) const MAX_STR_LEN: u8 = !PTR_MASK;
8
+
pub(crate) const PTR_MASK: u8 = 0b1100_0000;
9
+
10
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
11
+
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
12
+
pub enum DnsError {
13
+
LabelTooLong,
14
+
InvalidTxt,
15
+
Unsupported,
16
+
}
17
+
18
+
impl core::fmt::Display for DnsError {
19
+
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
20
+
match self {
21
+
Self::LabelTooLong => f.write_str("Encoding Error: Segment too long"),
22
+
Self::InvalidTxt => f.write_str("Encoding Error: TXT segment is invalid"),
23
+
Self::Unsupported => f.write_str("Encoding Error: Unsupported Record Type"),
24
+
}
25
+
}
26
+
}
27
+
28
+
impl core::error::Error for DnsError {}
29
+
30
+
#[derive(Debug)]
31
+
pub struct Encoder<'a, 'b> {
32
+
output: &'b mut [u8],
33
+
position: usize,
34
+
lookup: BTreeMap<&'a str, u16>,
35
+
reservation: Option<usize>,
36
+
}
37
+
38
+
impl<'a, 'b> Encoder<'a, 'b> {
39
+
pub const fn new(buffer: &'b mut [u8]) -> Self {
40
+
Self {
41
+
output: buffer,
42
+
position: 0,
43
+
lookup: BTreeMap::new(),
44
+
reservation: None,
45
+
}
46
+
}
47
+
48
+
/// Takes a payload and encodes it, consuming the encoder and yielding the resulting
49
+
/// slice.
50
+
pub fn encode<T, E>(mut self, payload: T) -> Result<&'b [u8], E>
51
+
where
52
+
E: core::error::Error,
53
+
T: DnsSerialize<'a, Error = E>,
54
+
{
55
+
payload.serialize(&mut self)?;
56
+
Ok(self.finish())
57
+
}
58
+
59
+
pub(crate) fn finish(self) -> &'b [u8] {
60
+
&self.output[..self.position]
61
+
}
62
+
63
+
fn increment(&mut self, amount: usize) {
64
+
self.position += amount;
65
+
}
66
+
67
+
pub(crate) fn write_label(&mut self, mut label: &'a str) -> Result<(), DnsError> {
68
+
loop {
69
+
if let Some(pos) = self.get_label_position(label) {
70
+
let [b1, b2] = u16::to_be_bytes(pos);
71
+
self.write(&[b1 | PTR_MASK, b2]);
72
+
return Ok(());
73
+
}
74
+
75
+
let dot = label.find('.');
76
+
77
+
let end = dot.unwrap_or(label.len());
78
+
let segment = &label[..end];
79
+
let len = u8::try_from(segment.len()).map_err(|_| DnsError::LabelTooLong)?;
80
+
81
+
if len > MAX_STR_LEN {
82
+
return Err(DnsError::LabelTooLong);
83
+
}
84
+
85
+
self.store_label_position(label);
86
+
self.write(&len.to_be_bytes());
87
+
self.write(segment.as_bytes());
88
+
89
+
match dot {
90
+
Some(end) => {
91
+
label = &label[end + 1..];
92
+
}
93
+
None => {
94
+
self.write(&[0]);
95
+
return Ok(());
96
+
}
97
+
}
98
+
}
99
+
}
100
+
101
+
pub(crate) fn write(&mut self, bytes: &[u8]) {
102
+
let len = bytes.len();
103
+
let end = self.position + len;
104
+
self.output[self.position..end].copy_from_slice(bytes);
105
+
self.increment(len);
106
+
}
107
+
108
+
fn get_label_position(&mut self, label: &str) -> Option<u16> {
109
+
self.lookup.get(label).copied()
110
+
}
111
+
112
+
fn store_label_position(&mut self, label: &'a str) {
113
+
self.lookup.insert(label, self.position as u16);
114
+
}
115
+
116
+
fn reserve_record_length(&mut self) {
117
+
if self.reservation.is_none() {
118
+
self.reservation = Some(self.position);
119
+
self.increment(2);
120
+
}
121
+
}
122
+
123
+
fn distance_from_reservation(&mut self) -> Option<(Range<usize>, u16)> {
124
+
self.reservation
125
+
.take()
126
+
.map(|start| (start..(start + 2), (self.position - start - 2) as u16))
127
+
}
128
+
129
+
fn write_record_length(&mut self) {
130
+
if let Some((reservation, len)) = self.distance_from_reservation() {
131
+
self.output[reservation].copy_from_slice(&len.to_be_bytes());
132
+
}
133
+
}
134
+
135
+
pub(crate) fn with_record_length<E, F>(&mut self, encoding_scope: F) -> Result<(), E>
136
+
where
137
+
E: core::error::Error,
138
+
F: FnOnce(&mut Self) -> Result<(), E>,
139
+
{
140
+
self.reserve_record_length();
141
+
encoding_scope(self)?;
142
+
self.write_record_length();
143
+
Ok(())
144
+
}
145
+
}
+60
sachy-mdns/src/lib.rs
+60
sachy-mdns/src/lib.rs
···
···
1
+
#![no_std]
2
+
3
+
mod dns;
4
+
pub(crate) mod encoder;
5
+
pub mod server;
6
+
mod service;
7
+
mod state;
8
+
9
+
extern crate alloc;
10
+
11
+
use core::net::{Ipv4Addr, SocketAddrV4};
12
+
13
+
pub use crate::service::Service;
14
+
pub use crate::state::MdnsAction;
15
+
use crate::{dns::flags::Flags, server::Server, state::MdnsStateMachine};
16
+
17
+
/// Standard port for mDNS (5353).
18
+
pub const MDNS_PORT: u16 = 5353;
19
+
20
+
/// Standard IPv4 multicast address for mDNS (224.0.0.251).
21
+
pub const GROUP_ADDR_V4: Ipv4Addr = Ipv4Addr::new(224, 0, 0, 251);
22
+
pub const GROUP_SOCK_V4: SocketAddrV4 = SocketAddrV4::new(GROUP_ADDR_V4, MDNS_PORT);
23
+
24
+
#[derive(Debug)]
25
+
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
26
+
pub struct MdnsService {
27
+
server: Server,
28
+
state: MdnsStateMachine,
29
+
}
30
+
31
+
impl MdnsService {
32
+
pub fn new(service: Service) -> Self {
33
+
Self {
34
+
server: Server::new(service),
35
+
state: Default::default(),
36
+
}
37
+
}
38
+
39
+
pub fn next_action(&mut self) -> MdnsAction {
40
+
self.state.drive_next_action()
41
+
}
42
+
43
+
pub fn send_announcement<'buffer>(&self, outgoing: &'buffer mut [u8]) -> Option<&'buffer [u8]> {
44
+
self.server.broadcast(
45
+
server::ResponseKind::Announcement,
46
+
Flags::standard_response(),
47
+
1,
48
+
alloc::vec::Vec::new(),
49
+
outgoing,
50
+
)
51
+
}
52
+
53
+
pub fn listen_for_queries<'buffer>(
54
+
&mut self,
55
+
incoming: &[u8],
56
+
outgoing: &'buffer mut [u8],
57
+
) -> Option<&'buffer [u8]> {
58
+
self.server.respond(incoming, outgoing)
59
+
}
60
+
}
+95
sachy-mdns/src/server.rs
+95
sachy-mdns/src/server.rs
···
···
1
+
use alloc::vec::Vec;
2
+
use sachy_fmt::{error, info};
3
+
4
+
use crate::{
5
+
dns::{
6
+
flags::Flags,
7
+
query::{QClass, Query},
8
+
records::QType,
9
+
reqres::{Request, Response},
10
+
traits::DnsParse,
11
+
},
12
+
encoder::Encoder,
13
+
service::Service,
14
+
};
15
+
16
+
pub(crate) enum ResponseKind {
17
+
Announcement,
18
+
QueryResponse(Vec<(QType, QClass)>),
19
+
}
20
+
21
+
#[derive(Debug)]
22
+
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
23
+
pub(crate) struct Server {
24
+
service: Service,
25
+
}
26
+
27
+
impl Server {
28
+
pub(crate) fn new(service: Service) -> Self {
29
+
Self { service }
30
+
}
31
+
32
+
pub(crate) fn broadcast<'a, 'b>(
33
+
&self,
34
+
response_kind: ResponseKind,
35
+
flags: Flags,
36
+
id: u16,
37
+
queries: Vec<Query<'b>>,
38
+
outgoing: &'a mut [u8],
39
+
) -> Option<&'a [u8]> {
40
+
let answers: Vec<_> = match response_kind {
41
+
ResponseKind::Announcement => self.service.as_answers(QClass::Multicast).collect(),
42
+
ResponseKind::QueryResponse(valid) => valid
43
+
.iter()
44
+
.flat_map(|&(qtype, qclass)| match qtype {
45
+
QType::A | QType::AAAA => self.service.ip_answer(qclass),
46
+
QType::PTR => self.service.ptr_answer(qclass),
47
+
QType::TXT => self.service.txt_answer(qclass),
48
+
QType::SRV => self.service.srv_answer(qclass),
49
+
QType::Any | QType::Unknown(_) => None,
50
+
})
51
+
.collect(),
52
+
};
53
+
54
+
if !answers.is_empty() {
55
+
let res = Response {
56
+
flags,
57
+
id,
58
+
queries,
59
+
answers,
60
+
};
61
+
62
+
info!("MDNS RESPONSE: {}", res);
63
+
64
+
return Encoder::new(outgoing)
65
+
.encode(res)
66
+
.inspect_err(|err| error!("Encoder errored: {}", err))
67
+
.ok();
68
+
}
69
+
70
+
None
71
+
}
72
+
73
+
pub(crate) fn respond<'a>(&self, incoming: &[u8], outgoing: &'a mut [u8]) -> Option<&'a [u8]> {
74
+
Request::parse(&mut &incoming[..], incoming)
75
+
.ok()
76
+
.and_then(|req| {
77
+
let valid_queries = req
78
+
.queries
79
+
.iter()
80
+
.filter_map(|q| match q.qtype {
81
+
QType::A | QType::AAAA | QType::TXT | QType::SRV => {
82
+
(q.name == self.service.hostname()).then_some((q.qtype, q.qclass))
83
+
}
84
+
QType::PTR => (q.name == self.service.service_type()).then_some((q.qtype, q.qclass)),
85
+
QType::Any | QType::Unknown(_) => None,
86
+
}).collect::<Vec<_>>();
87
+
88
+
if !valid_queries.is_empty() {
89
+
self.broadcast(ResponseKind::QueryResponse(valid_queries), req.flags, req.id, req.queries, outgoing)
90
+
} else {
91
+
None
92
+
}
93
+
})
94
+
}
95
+
}
+193
sachy-mdns/src/service.rs
+193
sachy-mdns/src/service.rs
···
···
1
+
use core::net::IpAddr;
2
+
3
+
use alloc::{
4
+
string::{String, ToString},
5
+
vec::Vec,
6
+
};
7
+
8
+
use crate::dns::{
9
+
label::Label,
10
+
query::{Answer, QClass},
11
+
records::{A, AAAA, PTR, QType, Record, SRV, TXT},
12
+
};
13
+
14
+
#[derive(Debug, Default)]
15
+
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
16
+
pub struct Service {
17
+
service_type: String,
18
+
instance: String,
19
+
hostname: String,
20
+
ip: Option<IpAddr>,
21
+
port: u16,
22
+
}
23
+
24
+
impl Service {
25
+
pub fn new(
26
+
service_type: impl Into<String>,
27
+
instance: impl Into<String>,
28
+
hostname: impl Into<String>,
29
+
ip: Option<IpAddr>,
30
+
port: u16,
31
+
) -> Self {
32
+
let service_type = service_type.into();
33
+
let mut instance = instance.into();
34
+
let mut hostname = hostname.into();
35
+
36
+
instance.push('.');
37
+
instance.push_str(&service_type);
38
+
hostname.push_str(".local");
39
+
40
+
Self {
41
+
service_type,
42
+
instance,
43
+
hostname,
44
+
ip,
45
+
port,
46
+
}
47
+
}
48
+
49
+
pub fn service_type(&self) -> Label<'_> {
50
+
Label::from(self.service_type.as_ref())
51
+
}
52
+
53
+
pub fn instance(&self) -> Label<'_> {
54
+
Label::from(self.instance.as_ref())
55
+
}
56
+
57
+
pub fn hostname(&self) -> Label<'_> {
58
+
Label::from(self.hostname.as_ref())
59
+
}
60
+
61
+
pub fn ip(&self) -> Option<IpAddr> {
62
+
self.ip
63
+
}
64
+
65
+
pub fn port(&self) -> u16 {
66
+
self.port
67
+
}
68
+
69
+
pub(crate) fn ptr_answer(&self, aclass: QClass) -> Option<Answer<'_>> {
70
+
Some(Answer {
71
+
name: self.service_type(),
72
+
atype: QType::PTR,
73
+
aclass,
74
+
ttl: 4500,
75
+
record: Record::PTR(PTR {
76
+
name: self.instance(),
77
+
}),
78
+
})
79
+
}
80
+
81
+
pub(crate) fn srv_answer(&self, aclass: QClass) -> Option<Answer<'_>> {
82
+
Some(Answer {
83
+
name: self.instance(),
84
+
atype: QType::SRV,
85
+
aclass,
86
+
ttl: 120,
87
+
record: Record::SRV(SRV {
88
+
priority: 0,
89
+
weight: 0,
90
+
port: self.port,
91
+
target: self.hostname(),
92
+
}),
93
+
})
94
+
}
95
+
96
+
pub(crate) fn txt_answer(&self, aclass: QClass) -> Option<Answer<'_>> {
97
+
Some(Answer {
98
+
name: self.instance(),
99
+
atype: QType::TXT,
100
+
aclass,
101
+
ttl: 120,
102
+
record: Record::TXT(TXT { text: Vec::new() }),
103
+
})
104
+
}
105
+
106
+
pub(crate) fn ip_answer(&self, aclass: QClass) -> Option<Answer<'_>> {
107
+
self.ip().map(|address| match address {
108
+
IpAddr::V4(address) => Answer {
109
+
name: self.hostname(),
110
+
atype: QType::A,
111
+
aclass,
112
+
ttl: 120,
113
+
record: Record::A(A { address }),
114
+
},
115
+
IpAddr::V6(address) => Answer {
116
+
name: self.hostname(),
117
+
atype: QType::AAAA,
118
+
aclass,
119
+
ttl: 120,
120
+
record: Record::AAAA(AAAA { address }),
121
+
},
122
+
})
123
+
}
124
+
125
+
pub(crate) fn as_answers(&self, aclass: QClass) -> impl Iterator<Item = Answer<'_>> {
126
+
self.ptr_answer(aclass)
127
+
.into_iter()
128
+
.chain(self.srv_answer(aclass))
129
+
.chain(self.txt_answer(aclass))
130
+
.chain(self.ip_answer(aclass))
131
+
}
132
+
133
+
#[allow(dead_code)]
134
+
pub(crate) fn from_answers(answers: &[Answer<'_>]) -> Vec<Self> {
135
+
let mut output = Vec::new();
136
+
137
+
// Step 1: Process PTR records
138
+
for answer in answers {
139
+
if let Record::PTR(ptr) = &answer.record {
140
+
let instance = ptr.name.to_string();
141
+
let service_type = answer.name.to_string();
142
+
output.push(Self {
143
+
service_type,
144
+
instance,
145
+
..Default::default()
146
+
});
147
+
}
148
+
}
149
+
150
+
// Step 2: Process SRV records, A and AAAA records and merge data
151
+
for answer in answers {
152
+
match &answer.record {
153
+
Record::SRV(srv) => {
154
+
if let Some(stub) = output
155
+
.iter_mut()
156
+
.find(|stub| answer.name == stub.instance.as_ref())
157
+
{
158
+
stub.hostname = srv.target.to_string();
159
+
stub.port = srv.port;
160
+
}
161
+
}
162
+
Record::A(a) => {
163
+
if let Some(stub) = output
164
+
.iter_mut()
165
+
.find(|stub| answer.name == stub.hostname.as_ref())
166
+
{
167
+
stub.ip = Some(IpAddr::V4(a.address));
168
+
}
169
+
}
170
+
Record::AAAA(aaaa) => {
171
+
if let Some(stub) = output
172
+
.iter_mut()
173
+
.find(|stub| answer.name == stub.hostname.as_ref())
174
+
{
175
+
stub.ip = Some(IpAddr::V6(aaaa.address));
176
+
}
177
+
}
178
+
_ => {}
179
+
}
180
+
}
181
+
182
+
// Final step: Retain only complete services
183
+
output.retain(|stub| {
184
+
!stub.service_type.is_empty()
185
+
&& !stub.instance.is_empty()
186
+
&& !stub.hostname.is_empty()
187
+
&& stub.ip.is_some()
188
+
&& stub.port != 0
189
+
});
190
+
191
+
output
192
+
}
193
+
}
+100
sachy-mdns/src/state.rs
+100
sachy-mdns/src/state.rs
···
···
1
+
use embassy_time::{Duration, Instant};
2
+
use sachy_fmt::{debug, unwrap};
3
+
4
+
#[derive(Debug, Default)]
5
+
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
6
+
pub(crate) enum MdnsStateMachine {
7
+
#[default]
8
+
Start,
9
+
Announce {
10
+
last_sent: Instant,
11
+
},
12
+
ListenFor {
13
+
last_sent: Instant,
14
+
timeout: Duration,
15
+
},
16
+
WaitFor {
17
+
last_sent: Instant,
18
+
duration: Duration,
19
+
},
20
+
}
21
+
22
+
impl MdnsStateMachine {
23
+
/// Set the state to announced, if we have timed out the listening period and need
24
+
/// to announce, or if we received a query while listening and have sent a response.
25
+
pub(crate) fn announced(&mut self) {
26
+
*self = Self::Announce {
27
+
last_sent: Instant::now(),
28
+
};
29
+
}
30
+
31
+
fn next_state(&mut self) {
32
+
match self {
33
+
Self::Start => self.announced(),
34
+
&mut Self::Announce { last_sent } => {
35
+
let duration_since = last_sent.elapsed();
36
+
let duration = Duration::from_secs(1) - duration_since;
37
+
38
+
*self = Self::WaitFor {
39
+
last_sent,
40
+
duration,
41
+
};
42
+
}
43
+
&mut Self::ListenFor { last_sent, .. } => {
44
+
let duration_since = last_sent.elapsed();
45
+
let time_limit = Duration::from_secs(120);
46
+
47
+
if duration_since >= time_limit {
48
+
self.announced();
49
+
} else {
50
+
let timeout = time_limit - duration_since;
51
+
*self = Self::ListenFor { last_sent, timeout };
52
+
}
53
+
}
54
+
&mut Self::WaitFor { last_sent, .. } => {
55
+
let duration_since = last_sent.elapsed();
56
+
let time_limit = Duration::from_secs(120);
57
+
let timeout = time_limit - duration_since;
58
+
59
+
*self = Self::ListenFor { last_sent, timeout };
60
+
}
61
+
}
62
+
}
63
+
64
+
pub(crate) fn drive_next_action(&mut self) -> MdnsAction {
65
+
self.next_state();
66
+
unwrap!(MdnsAction::try_from(self))
67
+
}
68
+
}
69
+
70
+
#[derive(Debug)]
71
+
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
72
+
pub enum MdnsAction {
73
+
Announce,
74
+
ListenFor { timeout: Duration },
75
+
WaitFor { duration: Duration },
76
+
}
77
+
78
+
impl TryFrom<&mut MdnsStateMachine> for MdnsAction {
79
+
type Error = MdnsStateMachine;
80
+
81
+
fn try_from(value: &mut MdnsStateMachine) -> Result<Self, Self::Error> {
82
+
match value {
83
+
// We should start in this state, but never remain nor return to it when
84
+
// executing our state machine event loop.
85
+
MdnsStateMachine::Start => Err(MdnsStateMachine::Start),
86
+
MdnsStateMachine::Announce { .. } => {
87
+
debug!("ANNOUNCE");
88
+
Ok(Self::Announce)
89
+
}
90
+
&mut MdnsStateMachine::ListenFor { timeout, .. } => {
91
+
debug!("LISTEN FOR {}ms", timeout.as_millis());
92
+
Ok(Self::ListenFor { timeout })
93
+
}
94
+
&mut MdnsStateMachine::WaitFor { duration, .. } => {
95
+
debug!("WAIT FOR {}ms", duration.as_millis());
96
+
Ok(Self::WaitFor { duration })
97
+
}
98
+
}
99
+
}
100
+
}