this repo has no description
at main 10 kB view raw
const std = @import("std");
const assert = std.debug.assert;
const io = std.io;
const net = std.net;

/// Fixed-size (512-byte) buffer holding a DNS packet, plus a read cursor `pos`.
pub const BytePacketBuffer = struct {
    buf: [512]u8 = undefined,
    pos: usize = 0,

    pub const ReadError = error{
        EndOfBuffer,
        JumpLimitExceeded,
    };

    pub const Reader = io.Reader(*BytePacketBuffer, ReadError, read);

    /// Move the buffer position forward by `pos` bytes.
    pub fn step(self: *BytePacketBuffer, pos: usize) void {
        self.pos += pos;
    }

    /// Set the buffer position to the absolute offset `pos`.
    pub fn seek(self: *BytePacketBuffer, pos: usize) void {
        self.pos = pos;
    }

    /// Return a `std.io` reader that consumes bytes starting at the current position.
    pub fn reader(self: *BytePacketBuffer) Reader {
        return .{ .context = self };
    }

    /// Fill `dest` with the next `dest.len` bytes and advance the position by that amount.
    ///
    /// Returns `dest.len`. If fewer than `dest.len` bytes remain, fails with
    /// `EndOfBuffer` and leaves the position unchanged; no partial read is performed.
    pub fn read(self: *BytePacketBuffer, dest: []u8) ReadError!usize {
        const src = try self.getRange(self.pos, dest.len);
        @memcpy(dest, src);
        self.pos += dest.len;
        return dest.len;
    }

    /// Get the byte at absolute offset `pos` without changing the buffer position.
    pub fn get(self: *const BytePacketBuffer, pos: usize) ReadError!u8 {
        if (pos >= self.buf.len) return ReadError.EndOfBuffer;
        return self.buf[pos];
    }

    /// Get `len` bytes starting at absolute offset `start` without changing the buffer position.
    ///
    /// The returned slice aliases the buffer. A range that ends exactly at the end
    /// of the buffer is valid.
    pub fn getRange(self: *const BytePacketBuffer, start: usize, len: usize) ReadError![]const u8 {
        // Two comparisons instead of `start + len > buf.len` so the check can never overflow.
        if (start > self.buf.len or len > self.buf.len - start) return ReadError.EndOfBuffer;
        return self.buf[start .. start + len];
    }

    /// Iterator over the labels of an encoded domain name that starts at `pos`.
    /// It follows compression pointers and updates `packet.pos` once the name ends.
    const QnameLabels = struct {
        packet: *BytePacketBuffer,
        /// Local read offset; diverges from `packet.pos` once a jump has been taken.
        pos: usize,
        jumped: bool = false,
        jumps_performed: usize = 0,

        /// More jumps than this fail with `JumpLimitExceeded`, so pointer loops terminate.
        const max_jumps: usize = 5;

        /// Return the next label, or null once the terminating zero-length label is consumed.
        ///
        /// A length byte with both high bits set (0xC0) is a jump. Its low 6 bits and
        /// the following byte form a 14-bit offset into the packet. On the first jump,
        /// `packet.pos` is moved past the two pointer bytes. If no jump was taken,
        /// `packet.pos` is moved past the zero terminator when the name ends.
        /// NOTE(review): length bytes with only one high bit set (0x40-0xBF) are not
        /// rejected and are treated as plain label lengths.
        fn next(it: *QnameLabels) ReadError!?[]const u8 {
            while (true) {
                if (it.jumps_performed > max_jumps) return ReadError.JumpLimitExceeded;

                // Each label starts with a length byte.
                const len = try it.packet.get(it.pos);

                if ((len & 0xC0) == 0xC0) {
                    const b2: u16 = try it.packet.get(it.pos + 1);
                    // Only the first pointer determines where the caller's cursor ends up.
                    if (!it.jumped) it.packet.seek(it.pos + 2);
                    it.pos = ((@as(u16, len) ^ 0xC0) << 8) | b2;
                    it.jumped = true;
                    it.jumps_performed += 1;
                    continue;
                }

                // Step past the length byte.
                it.pos += 1;

                // Domain names are terminated by an empty label.
                if (len == 0) {
                    if (!it.jumped) it.packet.seek(it.pos);
                    return null;
                }

                const label = try it.packet.getRange(it.pos, len);
                it.pos += len;
                return label;
            }
        }
    };

    /// Read a (possibly compressed) domain name at the current position into `outstr`.
    ///
    /// Labels are joined with '.', so [3]www[6]google[3]com[0] becomes "www.google.com".
    /// Compression pointers are followed. More than five jumps fail with
    /// `JumpLimitExceeded`. On success the position is left just past the name as
    /// encoded at the starting offset: after the zero terminator, or after the first
    /// two-byte pointer.
    ///
    /// `outstr` must be large enough for the decoded name. A too-small buffer is an
    /// out-of-bounds slice access, which panics in safe build modes. Bytes of `outstr`
    /// past the name are left untouched; the decoded length is not reported.
    pub fn readQname(self: *BytePacketBuffer, outstr: []u8) ReadError!void {
        var labels: QnameLabels = .{ .packet = self, .pos = self.pos };
        var out_pos: usize = 0;
        while (try labels.next()) |label| {
            // Labels are never empty, so a non-zero offset means a label was already written.
            if (out_pos != 0) {
                outstr[out_pos] = '.';
                out_pos += 1;
            }
            @memcpy(outstr[out_pos..][0..label.len], label);
            out_pos += label.len;
        }
    }

    /// Read a (possibly compressed) domain name at the current position into a newly allocated slice.
    ///
    /// Same decoding and position semantics as `readQname`. The caller owns the
    /// returned slice and must free it with `alloc`. On error nothing is leaked.
    pub fn readQnameAlloc(self: *BytePacketBuffer, alloc: std.mem.Allocator) ![]u8 {
        var name: std.ArrayList(u8) = .init(alloc);
        errdefer name.deinit();

        var labels: QnameLabels = .{ .packet = self, .pos = self.pos };
        while (try labels.next()) |label| {
            if (name.items.len != 0) try name.append('.');
            try name.appendSlice(label);
        }

        return name.toOwnedSlice();
    }
};

test "BytePacketBuffer.read" {
    const testing = std.testing;
    var buf = BytePacketBuffer{};
    buf.buf[0] = 0x1;
    try testing.expectEqual(0x1, try buf.reader().readInt(u8, .big));
}

test "BytePacketBuffer.read_u16" {
    const testing = std.testing;
    var buf = BytePacketBuffer{};
    buf.buf[0] = 0x1;
    buf.buf[1] = 0x1;
    try testing.expectEqual(0x101, try buf.reader().readInt(u16, .big));
}

test "BytePacketBuffer.read_u32" {
    const testing = std.testing;
    var buf = BytePacketBuffer{};
    buf.buf[0] = 0x1;
    buf.buf[1] = 0x1;
    buf.buf[2] = 0x1;
    buf.buf[3] = 0x1;
    try testing.expectEqual(0x1010101, try buf.reader().readInt(u32, .big));
}

test "BytePacketBuffer.read last byte" {
    const testing = std.testing;
    var buf = BytePacketBuffer{};
    buf.buf[buf.buf.len - 1] = 0x1;
    buf.pos = buf.buf.len - 1;
    try testing.expectEqual(0x1, try buf.reader().readInt(u8, .big));
    try testing.expectError(
        BytePacketBuffer.ReadError.EndOfBuffer,
        buf.reader().readInt(u8, .big),
    );
}

test "BytePacketBuffer.getRange at end of buffer" {
    const testing = std.testing;
    var buf = BytePacketBuffer{};
    buf.buf[buf.buf.len - 1] = 0x7;
    const tail = try buf.getRange(buf.buf.len - 1, 1);
    try testing.expectEqualSlices(u8, &.{0x7}, tail);
    try testing.expectError(
        BytePacketBuffer.ReadError.EndOfBuffer,
        buf.getRange(buf.buf.len, 1),
    );
}

test "BytePacketBuffer.read_qname google.com" {
    const testing = std.testing;
    const allocator = testing.allocator;

    const input = [_]u8{
        0x06, // [6]
        0x67, // g
        0x6f, // o
        0x6f, // o
        0x67, // g
        0x6c, // l
        0x65, // e
        0x03, // [3]
        0x63, // c
        0x6f, // o
        0x6d, // m
        0x00, // [0]
    };
    const expected = "google.com";

    var buf = BytePacketBuffer{};
    for (input, 0..) |char, idx| {
        buf.buf[idx] = char;
    }

    const outstr = try allocator.alloc(u8, expected.len);
    defer allocator.free(outstr);

    try buf.readQname(outstr);

    try testing.expectEqualStrings(expected, outstr);
    // The cursor ends just past the terminating zero byte.
    try testing.expectEqual(@as(usize, input.len), buf.pos);
}

test "BytePacketBuffer.read_qname_alloc google.com" {
    const testing = std.testing;
    const allocator = testing.allocator;

    const input = [_]u8{
        0x06, // [6]
        0x67, // g
        0x6f, // o
        0x6f, // o
        0x67, // g
        0x6c, // l
        0x65, // e
        0x03, // [3]
        0x63, // c
        0x6f, // o
        0x6d, // m
        0x00, // [0]
    };
    const expected = "google.com";

    var buf = BytePacketBuffer{};
    for (input, 0..) |char, idx| {
        buf.buf[idx] = char;
    }

    const outstr = try buf.readQnameAlloc(allocator);
    defer allocator.free(outstr);

    try testing.expectEqualStrings(expected, outstr);
    try testing.expectEqual(@as(usize, input.len), buf.pos);
}

test "BytePacketBuffer.read_qname follows jumps" {
    const testing = std.testing;

    // "google.com" at offset 0, then "www" followed by a pointer back to offset 0.
    const input = [_]u8{
        0x06, 'g',  'o', 'o', 'g', 'l', 'e',
        0x03, 'c',  'o', 'm',
        0x00,
        0x03, 'w',  'w', 'w',
        0xC0, 0x00,
    };
    var buf = BytePacketBuffer{};
    @memcpy(buf.buf[0..input.len], &input);

    var outstr: [14]u8 = undefined;
    buf.seek(12);
    try buf.readQname(&outstr);
    try testing.expectEqualStrings("www.google.com", &outstr);
    // The cursor ends just past the two-byte pointer, not at the jump target.
    try testing.expectEqual(@as(usize, input.len), buf.pos);

    buf.seek(12);
    const allocated = try buf.readQnameAlloc(testing.allocator);
    defer testing.allocator.free(allocated);
    try testing.expectEqualStrings("www.google.com", allocated);
    try testing.expectEqual(@as(usize, input.len), buf.pos);
}

test "BytePacketBuffer.read_qname_alloc rejects pointer loops without leaking" {
    var buf = BytePacketBuffer{};
    // Label "google", then a pointer at offset 7 that points to itself.
    const input = [_]u8{ 0x06, 'g', 'o', 'o', 'g', 'l', 'e', 0xC0, 0x07 };
    @memcpy(buf.buf[0..input.len], &input);
    try std.testing.expectError(
        BytePacketBuffer.ReadError.JumpLimitExceeded,
        buf.readQnameAlloc(std.testing.allocator),
    );
}

/// DNS message header: id, flag bits, and section counts.
///
/// In a Zig `packed struct`, fields are laid out starting from the least significant
/// bit, so each group of flag fields below is listed low bit first.
/// NOTE(review): how this struct is converted to or from the big-endian wire
/// format is not shown here — confirm before relying on a direct byte cast.
pub const DnsHeader = packed struct {
    id: u16 = 0,

    recursion_desired: bool = false,
    truncated_message: bool = false,
    authoritative_answer: bool = false,
    opcode: u4 = 0,
    response: bool = false,

    rescode: enum(u4) {
        noerror = 0,
        formerr = 1,
        servfail = 2,
        nxdomain = 3,
        notimp = 4,
        refused = 5,
    } = .noerror,
    checking_disabled: bool = false,
    authed_data: bool = false,
    z: bool = false,
    recursion_available: bool = false,

    questions: u16 = 0,
    answers: u16 = 0,
    authoritative_entries: u16 = 0,
    resource_entries: u16 = 0,
};

/// DNS record/query type.
///
/// `unknown` is the catch-all that `DnsQuestion.read` uses for unsupported type codes.
/// NOTE(review): `unknown` implicitly takes the value 2, which is also the wire
/// code of NS records — confirm that is intended.
pub const QueryType = enum(u16) {
    a = 1,
    unknown,
};

/// A single entry from the question section.
pub const DnsQuestion = struct {
    name: []const u8,
    qtype: QueryType,

    /// Read a question (name, type, class) from `buffer` at its current position.
    ///
    /// `self.name` is allocated with `alloc` and is owned by the caller. The class
    /// field is read and discarded. Type codes that are not members of `QueryType`
    /// are mapped to `.unknown` rather than trapping. On error, `self` is left
    /// unmodified and nothing is leaked.
    pub fn read(self: *DnsQuestion, buffer: *BytePacketBuffer, alloc: std.mem.Allocator) !void {
        const name = try buffer.readQnameAlloc(alloc);
        errdefer alloc.free(name);

        const raw_qtype = try buffer.reader().readInt(u16, .big);
        _ = try buffer.reader().readInt(u16, .big); // class

        self.name = name;
        // The type code comes from the packet; @enumFromInt on an unlisted value is illegal behavior.
        self.qtype = std.meta.intToEnum(QueryType, raw_qtype) catch .unknown;
    }
};

test "DnsQuestion.read" {
    const testing = std.testing;
    const input = [_]u8{
        0x06, 'g',  'o', 'o', 'g', 'l', 'e',
        0x03, 'c',  'o', 'm',
        0x00,
        0x00, 0x01, // type A
        0x00, 0x01, // class IN
    };
    var buf = BytePacketBuffer{};
    @memcpy(buf.buf[0..input.len], &input);

    var question: DnsQuestion = undefined;
    try question.read(&buf, testing.allocator);
    defer testing.allocator.free(question.name);

    try testing.expectEqualStrings("google.com", question.name);
    try testing.expectEqual(QueryType.a, question.qtype);
    try testing.expectEqual(@as(usize, input.len), buf.pos);
}

test "DnsQuestion.read maps unsupported types to unknown" {
    const testing = std.testing;
    const input = [_]u8{ 0x01, 'x', 0x00, 0x00, 0xFF, 0x00, 0x01 };
    var buf = BytePacketBuffer{};
    @memcpy(buf.buf[0..input.len], &input);

    var question: DnsQuestion = undefined;
    try question.read(&buf, testing.allocator);
    defer testing.allocator.free(question.name);

    try testing.expectEqualStrings("x", question.name);
    try testing.expectEqual(QueryType.unknown, question.qtype);
}

/// A resource record, tagged by its `QueryType`.
pub const DnsRecord = union(QueryType) {
    /// Record whose type is not decoded; `qtype` keeps the raw type code.
    unknown: struct {
        domain: []const u8,
        qtype: u16,
        data_len: u16,
        ttl: u32,
    },
    /// IPv4 address record.
    a: struct {
        domain: []const u8,
        addr: net.Ip4Address,
        ttl: u32,
    },
};