this repo has no description
1const std = @import("std");
2const assert = std.debug.assert;
3const io = std.io;
4const net = std.net;
5
/// A fixed 512-byte packet buffer with a read cursor (`pos`).
pub const BytePacketBuffer = struct {
    buf: [512]u8 = undefined,
    pos: usize = 0,

    pub const ReadError = error{
        EndOfBuffer,
        JumpLimitExceeded,
    };

    pub const Reader = io.Reader(*BytePacketBuffer, ReadError, read);

    /// Maximum number of compression pointers followed while decoding one name.
    /// Bounds the work done on packets whose pointers form a cycle.
    const max_qname_jumps = 5;

    /// Move the buffer position forward by `pos` bytes.
    pub fn step(self: *BytePacketBuffer, pos: usize) void {
        self.pos += pos;
    }

    /// Set the buffer position to the absolute offset `pos`.
    pub fn seek(self: *BytePacketBuffer, pos: usize) void {
        self.pos = pos;
    }

    /// Return a `Reader` that consumes bytes starting at the current position.
    pub fn reader(self: *BytePacketBuffer) Reader {
        return .{ .context = self };
    }

    /// Copy exactly `dest.len` bytes from the current position into `dest` and
    /// advance the position past them. Backs `Reader`.
    ///
    /// Returns `error.EndOfBuffer`, without reading anything, if fewer than
    /// `dest.len` bytes remain.
    pub fn read(self: *BytePacketBuffer, dest: []u8) ReadError!usize {
        const src = try self.getRange(self.pos, dest.len);
        @memcpy(dest, src);
        self.pos += dest.len;
        return dest.len;
    }

    /// Get the byte at absolute offset `pos` without changing the buffer position.
    /// Returns `error.EndOfBuffer` if `pos` is outside `buf`.
    pub fn get(self: *const BytePacketBuffer, pos: usize) ReadError!u8 {
        if (pos >= self.buf.len) return ReadError.EndOfBuffer;
        return self.buf[pos];
    }

    /// Get the `len` bytes starting at absolute offset `start` without changing
    /// the buffer position. The returned slice aliases `buf`.
    /// Returns `error.EndOfBuffer` unless the whole range lies inside `buf`.
    pub fn getRange(self: *const BytePacketBuffer, start: usize, len: usize) ReadError![]const u8 {
        // A range ending exactly at `buf.len` is valid. Checking `start` first
        // keeps the subtraction from underflowing.
        if (start > self.buf.len or len > self.buf.len - start) return ReadError.EndOfBuffer;
        return self.buf[start .. start + len];
    }

    /// Decode the domain name (qname) at the current position into `outstr`.
    ///
    /// Labels such as `[3]www[6]google[3]com[0]` are written as `www.google.com`.
    /// A length byte with its two high bits set is a compression pointer. It is
    /// followed, up to `max_qname_jumps` times.
    ///
    /// On success the buffer position is just past the name as it appears at the
    /// starting offset: after the terminating zero byte, or after the first
    /// two-byte pointer if one was followed.
    ///
    /// The decoded length is not reported, so `outstr` should be sized exactly.
    /// Bytes past the name are left untouched.
    ///
    /// Errors:
    /// - `error.EndOfBuffer` if the packet ends mid-name, or if `outstr` is too small.
    /// - `error.JumpLimitExceeded` if too many pointers are encountered.
    pub fn readQname(self: *BytePacketBuffer, outstr: []u8) ReadError!void {
        var labels: QnameLabels = .{ .packet = self, .pos = self.pos };
        var out_pos: usize = 0;

        while (try labels.next()) |label| {
            // Labels are never empty, so a non-zero `out_pos` means this is not
            // the first label and needs a '.' separator.
            if (out_pos != 0) {
                if (out_pos == outstr.len) return ReadError.EndOfBuffer;
                outstr[out_pos] = '.';
                out_pos += 1;
            }
            if (label.len > outstr.len - out_pos) return ReadError.EndOfBuffer;
            @memcpy(outstr[out_pos..][0..label.len], label);
            out_pos += label.len;
        }
    }

    /// Decode the domain name (qname) at the current position into a newly
    /// allocated string. The caller owns the result and frees it with `alloc`.
    ///
    /// Decoding, final buffer position and `ReadError` cases are the same as for
    /// `readQname`. Allocation failures are also returned. Nothing is leaked on
    /// error.
    pub fn readQnameAlloc(self: *BytePacketBuffer, alloc: std.mem.Allocator) ![]u8 {
        var buffer: std.ArrayList(u8) = .init(alloc);
        errdefer buffer.deinit();

        var labels: QnameLabels = .{ .packet = self, .pos = self.pos };
        while (try labels.next()) |label| {
            if (buffer.items.len != 0) try buffer.append('.');
            try buffer.appendSlice(label);
        }

        return try buffer.toOwnedSlice();
    }

    /// Iterator over the labels of a (possibly compressed) name.
    /// `readQname` and `readQnameAlloc` share it so both decode names identically.
    const QnameLabels = struct {
        packet: *BytePacketBuffer,
        /// Offset of the next length byte to decode. It diverges from
        /// `packet.pos` once a compression pointer has been followed.
        pos: usize,
        jumps: usize = 0,

        /// Return the next label as a slice into `packet.buf`, or null at the
        /// terminating zero byte.
        ///
        /// Moves `packet.pos` exactly once: past the first pointer, or past the
        /// terminator if no pointer was followed.
        fn next(self: *QnameLabels) ReadError!?[]const u8 {
            while (true) {
                const len = try self.packet.get(self.pos);

                if ((len & 0xC0) == 0xC0) {
                    // Two-byte pointer. The low 14 bits are an absolute offset.
                    if (self.jumps == max_qname_jumps) return ReadError.JumpLimitExceeded;
                    const low = try self.packet.get(self.pos + 1);
                    // Only the first pointer decides where the caller resumes.
                    if (self.jumps == 0) self.packet.seek(self.pos + 2);
                    self.jumps += 1;
                    self.pos = (@as(usize, len & 0x3F) << 8) | low;
                    continue;
                }

                const label_start = self.pos + 1;
                if (len == 0) {
                    if (self.jumps == 0) self.packet.seek(label_start);
                    return null;
                }

                const label = try self.packet.getRange(label_start, len);
                self.pos = label_start + len;
                return label;
            }
        }
    };

    test "readQname follows compression pointers and leaves pos after the name" {
        var packet = BytePacketBuffer{};
        const wire = "\x06google\x03com\x00" ++ "\x03www\xC0\x00";
        @memcpy(packet.buf[0..wire.len], wire);

        // Uncompressed name: position ends right after the zero terminator.
        var plain: ["google.com".len]u8 = undefined;
        try packet.readQname(&plain);
        try std.testing.expectEqualStrings("google.com", &plain);
        try std.testing.expectEqual(@as(usize, 12), packet.pos);

        // Compressed name: position ends right after the two-byte pointer.
        const name = try packet.readQnameAlloc(std.testing.allocator);
        defer std.testing.allocator.free(name);
        try std.testing.expectEqualStrings("www.google.com", name);
        try std.testing.expectEqual(@as(usize, 18), packet.pos);
    }
};
186
test "BytePacketBuffer.read" {
    // A single big-endian byte at offset 0 reads back unchanged.
    var packet: BytePacketBuffer = .{};
    packet.buf[0] = 0x1;
    const value = try packet.reader().readInt(u8, .big);
    try std.testing.expectEqual(0x1, value);
}
193
test "BytePacketBuffer.read_u16" {
    // Two 0x01 bytes decode as 0x0101 in big-endian order.
    var packet: BytePacketBuffer = .{};
    @memset(packet.buf[0..2], 0x1);
    const value = try packet.reader().readInt(u16, .big);
    try std.testing.expectEqual(0x101, value);
}
201
test "BytePacketBuffer.read_u32" {
    // Four 0x01 bytes decode as 0x01010101 in big-endian order.
    var packet: BytePacketBuffer = .{};
    @memset(packet.buf[0..4], 0x1);
    const value = try packet.reader().readInt(u32, .big);
    try std.testing.expectEqual(0x1010101, value);
}
211
test "BytePacketBuffer.read last byte" {
    var packet: BytePacketBuffer = .{};
    const last = packet.buf.len - 1;
    packet.buf[last] = 0x1;
    packet.seek(last);

    const r = packet.reader();
    try std.testing.expectEqual(0x1, try r.readInt(u8, .big));
    // The buffer is now exhausted, so any further read must fail.
    try std.testing.expectError(BytePacketBuffer.ReadError.EndOfBuffer, r.readInt(u8, .big));
}
223
test "BytePacketBuffer.read_qname google.com" {
    const expected = "google.com";
    // [6]google[3]com[0]
    const wire = "\x06google\x03com\x00";

    var packet: BytePacketBuffer = .{};
    @memcpy(packet.buf[0..wire.len], wire);

    var outstr: [expected.len]u8 = undefined;
    try packet.readQname(&outstr);

    try std.testing.expectEqualStrings(expected, &outstr);
}
256
test "BytePacketBuffer.read_qname_alloc google.com" {
    const allocator = std.testing.allocator;
    // [6]google[3]com[0]
    const wire = "\x06google\x03com\x00";

    var packet: BytePacketBuffer = .{};
    @memcpy(packet.buf[0..wire.len], wire);

    const name = try packet.readQnameAlloc(allocator);
    defer allocator.free(name);

    try std.testing.expectEqualStrings("google.com", name);
}
287
/// Fixed-size header of a DNS message.
///
/// As a `packed struct`, fields occupy consecutive bits of the 96-bit backing
/// integer, starting at the least significant bit and in declaration order:
/// - `id` is bits 0-15.
/// - The first flags byte is bits 16-23, with `recursion_desired` lowest and
///   `response` highest.
/// - The second flags byte is bits 24-31, with `rescode` lowest and
///   `recursion_available` highest.
/// - The four 16-bit counters follow.
///
/// Field order is the layout, so do not reorder fields.
///
/// NOTE(review): the u16 fields use host integer order. Mapping this struct onto
/// big-endian wire bytes presumably needs byte swapping; confirm at the decode site.
/// NOTE(review): `rescode` is an exhaustive enum over 0-5 only. Producing a header
/// with any other 4-bit code (e.g. via `@bitCast`) is illegal behavior.
pub const DnsHeader = packed struct(u96) {
    id: u16 = 0,

    // First flags byte, least significant bit first.
    recursion_desired: bool = false,
    truncated_message: bool = false,
    authoritative_answer: bool = false,
    opcode: u4 = 0,
    response: bool = false,

    // Second flags byte, least significant bits first.
    rescode: enum(u4) {
        noerror,
        formerr,
        servfail,
        nxdomain,
        notimp,
        refused,
    } = .noerror,
    checking_disabled: bool = false,
    authed_data: bool = false,
    z: bool = false,
    recursion_available: bool = false,

    // Section sizes (presumably record counts, as in the DNS header).
    questions: u16 = 0,
    answers: u16 = 0,
    authoritative_entries: u16 = 0,
    resource_entries: u16 = 0,
};
315
/// Record/question type codes known to this module, stored as their u16 wire value.
pub const QueryType = enum(u16) {
    /// IPv4 address record (see `DnsRecord.a`).
    a = 1,
    /// Catch-all for types this module does not model.
    /// NOTE(review): this value (2) is also the wire code of NS records; confirm it
    /// is intended before ever encoding `.unknown` back onto the wire.
    unknown = 2,
};
320
/// A single entry of a DNS message's question section.
pub const DnsQuestion = struct {
    /// Decoded domain name. Allocated by `read` and owned by the caller.
    name: []const u8,
    qtype: QueryType,

    /// Parse a question at `buffer`'s current position. The wire layout is the
    /// qname, then a big-endian u16 type, then a big-endian u16 class; the class
    /// is read and discarded.
    ///
    /// On success, `self.name` is allocated with `alloc` and the caller must free
    /// it. On error, nothing is left allocated and `self` is not modified.
    /// Type codes that `QueryType` does not name are mapped to `.unknown`.
    pub fn read(self: *DnsQuestion, buffer: *BytePacketBuffer, alloc: std.mem.Allocator) !void {
        const name = try buffer.readQnameAlloc(alloc);
        errdefer alloc.free(name);

        const raw_qtype = try buffer.reader().readInt(u16, .big);
        _ = try buffer.reader().readInt(u16, .big); // class

        self.name = name;
        // `@enumFromInt` on an unlisted value is illegal behavior. Packet data is
        // untrusted, so fall back to `.unknown` instead.
        self.qtype = std.meta.intToEnum(QueryType, raw_qtype) catch .unknown;
    }
};
331
332// TODO: add DnsQuestion.read test
333
/// A resource record, tagged by its `QueryType`.
pub const DnsRecord = union(QueryType) {
    /// IPv4 address record.
    a: struct {
        domain: []const u8,
        addr: net.Ip4Address,
        ttl: u32,
    },
    /// Record of a type this module does not model.
    /// NOTE(review): `qtype` presumably keeps the raw wire type code and
    /// `data_len` the size of the skipped payload; confirm against the decoder.
    unknown: struct {
        domain: []const u8,
        qtype: u16,
        data_len: u16,
        ttl: u32,
    },
};