const std = @import("std");
const assert = std.debug.assert;
const io = std.io;
const net = std.net;

/// A fixed 512-byte DNS packet together with a cursor (`pos`) into it.
pub const BytePacketBuffer = struct {
    buf: [512]u8 = undefined,
    pos: usize = 0,

    pub const ReadError = error{
        EndOfBuffer,
        JumpLimitExceeded,
    };

    /// Errors returned by `readQname`: any `ReadError`, or `NoSpaceLeft` when the
    /// decoded name does not fit in the caller-supplied output slice.
    pub const QnameError = ReadError || error{NoSpaceLeft};

    pub const Reader = io.Reader(*BytePacketBuffer, ReadError, read);

    /// A single qname may use at most this many compression pointers; this stops
    /// pointer loops in malformed packets from spinning forever.
    const max_jumps: usize = 5;

    /// Advance the buffer position by `pos` bytes.
    pub fn step(self: *BytePacketBuffer, pos: usize) void {
        self.pos += pos;
    }

    /// Move the buffer position to the absolute offset `pos`.
    pub fn seek(self: *BytePacketBuffer, pos: usize) void {
        self.pos = pos;
    }

    /// Return a `std.io` reader that consumes bytes starting at the current position.
    pub fn reader(self: *BytePacketBuffer) Reader {
        return .{ .context = self };
    }

    /// Copy exactly `dest.len` bytes from the current position into `dest` and
    /// advance the position past them.
    ///
    /// Unlike a typical stream `read`, this never returns a short count: if fewer
    /// than `dest.len` bytes remain it fails with `EndOfBuffer` and leaves the
    /// position unchanged.
    pub fn read(self: *BytePacketBuffer, dest: []u8) ReadError!usize {
        if (self.pos + dest.len > self.buf.len) return ReadError.EndOfBuffer;
        const end = self.pos + dest.len;
        @memcpy(dest, self.buf[self.pos..end]);
        self.pos = end;
        return dest.len;
    }

    /// Return the byte at absolute offset `pos` without moving the position.
    pub fn get(self: *const BytePacketBuffer, pos: usize) ReadError!u8 {
        if (pos >= self.buf.len) return ReadError.EndOfBuffer;
        return self.buf[pos];
    }

    /// Return a view of `len` bytes starting at absolute offset `start` without
    /// moving the position. The returned slice aliases the buffer.
    pub fn getRange(self: *const BytePacketBuffer, start: usize, len: usize) ReadError![]const u8 {
        // Two comparisons so `start + len` cannot overflow, and so a range that ends
        // exactly at the last byte is accepted.
        if (start > self.buf.len or len > self.buf.len - start) return ReadError.EndOfBuffer;
        return self.buf[start .. start + len];
    }

    /// Output sink for `readQname`: appends into a caller-provided slice.
    const FixedSink = struct {
        out: []u8,
        len: usize = 0,

        fn appendSlice(sink: *FixedSink, bytes: []const u8) error{NoSpaceLeft}!void {
            if (bytes.len > sink.out.len - sink.len) return error.NoSpaceLeft;
            @memcpy(sink.out[sink.len..][0..bytes.len], bytes);
            sink.len += bytes.len;
        }
    };

    /// Decode the qname at the current position. Each label is passed to
    /// `sink.appendSlice`, with a "." passed between consecutive labels.
    ///
    /// A name is a sequence of length-prefixed labels ending in a zero-length
    /// label, e.g. [3]www[6]google[3]com[0]. A length byte with both high bits set
    /// is instead a two-byte compression pointer. Its low 14 bits give the offset
    /// at which the rest of the name continues.
    ///
    /// On success `self.pos` points just past the name as stored at the starting
    /// position: after the terminating zero byte, or after the first pointer.
    /// On error the position may already have moved.
    fn walkQname(self: *BytePacketBuffer, sink: anytype) !void {
        // Jumps move the read cursor around the packet, so track it locally and only
        // commit to `self.pos` for the part of the name stored inline.
        var pos = self.pos;
        var jumped = false;
        var jumps_performed: usize = 0;
        var first_label = true;

        while (true) {
            if (jumps_performed > max_jumps) return ReadError.JumpLimitExceeded;

            // Each label starts with a length byte.
            const len = try self.get(pos);

            if ((len & 0xC0) == 0xC0) {
                const low = try self.get(pos + 1);
                // The caller resumes after the two pointer bytes, but only for the
                // first pointer: any later ones live elsewhere in the packet.
                if (!jumped) self.seek(pos + 2);
                pos = (@as(u16, len & 0x3F) << 8) | low;
                jumped = true;
                jumps_performed += 1;
                continue;
            }

            pos += 1; // skip the length byte
            if (len == 0) break; // the empty label terminates the name

            if (!first_label) try sink.appendSlice(".");
            try sink.appendSlice(try self.getRange(pos, len));
            first_label = false;
            pos += len;
        }

        if (!jumped) self.seek(pos);
    }

    /// Read the qname at the current position into `outstr` as dotted text.
    ///
    /// For example, [3]www[6]google[3]com[0] is written as "www.google.com".
    /// Only the first N bytes of `outstr` are written, where N is the decoded
    /// length; the rest is left untouched. Returns `error.NoSpaceLeft` if the
    /// name does not fit. Use `readQnameAlloc` to have the output sized for you.
    pub fn readQname(self: *BytePacketBuffer, outstr: []u8) QnameError!void {
        var sink: FixedSink = .{ .out = outstr };
        try self.walkQname(&sink);
    }

    /// Read the qname at the current position and return it as dotted text.
    /// For example, [3]www[6]google[3]com[0] becomes "www.google.com".
    ///
    /// The caller owns the returned slice and must free it with `alloc`.
    pub fn readQnameAlloc(self: *BytePacketBuffer, alloc: std.mem.Allocator) ![]u8 {
        var buffer: std.ArrayList(u8) = .init(alloc);
        errdefer buffer.deinit();
        try self.walkQname(&buffer);
        return buffer.toOwnedSlice();
    }
};

test "BytePacketBuffer.read" {
    const testing = std.testing;
    var buf = BytePacketBuffer{};
    buf.buf[0] = 0x1;
    try testing.expectEqual(0x1, try buf.reader().readInt(u8, .big));
}

test "BytePacketBuffer.read_u16" {
    const testing = std.testing;
    var buf = BytePacketBuffer{};
    buf.buf[0] = 0x1;
    buf.buf[1] = 0x1;
    try testing.expectEqual(0x101, try buf.reader().readInt(u16, .big));
}

test "BytePacketBuffer.read_u32" {
    const testing = std.testing;
    var buf = BytePacketBuffer{};
    buf.buf[0] = 0x1;
    buf.buf[1] = 0x1;
    buf.buf[2] = 0x1;
    buf.buf[3] = 0x1;
    try testing.expectEqual(0x1010101, try buf.reader().readInt(u32, .big));
}

test "BytePacketBuffer.read last byte" {
    const testing = std.testing;
    var buf = BytePacketBuffer{};
    buf.buf[buf.buf.len - 1] = 0x1;
    buf.pos = buf.buf.len - 1;
    try testing.expectEqual(0x1, try buf.reader().readInt(u8, .big));
    try testing.expectError(
        BytePacketBuffer.ReadError.EndOfBuffer,
        buf.reader().readInt(u8, .big),
    );
}

test "BytePacketBuffer.read_qname google.com" {
    const testing = std.testing;
    const allocator = testing.allocator;
    const input = [_]u8{
        0x06, // [6]
        0x67, // g
        0x6f, // o
        0x6f, // o
        0x67, // g
        0x6c, // l
        0x65, // e
        0x03, // [3]
        0x63, // c
        0x6f, // o
        0x6d, // m
        0x00, // [0]
    };
    const expected = "google.com";
    var buf = BytePacketBuffer{};
    for (input, 0..) |char, idx| {
        buf.buf[idx] = char;
    }
    const outstr = try allocator.alloc(u8, expected.len);
    defer allocator.free(outstr);
    try buf.readQname(outstr);
    try testing.expectEqualStrings(expected, outstr);
    try testing.expectEqual(@as(usize, input.len), buf.pos);
}

test "BytePacketBuffer.read_qname_alloc google.com" {
    const testing = std.testing;
    const allocator = testing.allocator;
    const input = [_]u8{
        0x06, // [6]
        0x67, // g
        0x6f, // o
        0x6f, // o
        0x67, // g
        0x6c, // l
        0x65, // e
        0x03, // [3]
        0x63, // c
        0x6f, // o
        0x6d, // m
        0x00, // [0]
    };
    const expected = "google.com";
    var buf = BytePacketBuffer{};
    for (input, 0..) |char, idx| {
        buf.buf[idx] = char;
    }
    const outstr = try buf.readQnameAlloc(allocator);
    defer allocator.free(outstr);
    try testing.expectEqualStrings(expected, outstr);
    try testing.expectEqual(@as(usize, input.len), buf.pos);
}

test "BytePacketBuffer.read_qname follows compression pointers" {
    const testing = std.testing;
    var buf = BytePacketBuffer{};
    // "google.com" at offset 0, then a pointer back to offset 0 at offset 12.
    const input = [_]u8{ 0x06, 'g', 'o', 'o', 'g', 'l', 'e', 0x03, 'c', 'o', 'm', 0x00, 0xC0, 0x00 };
    @memcpy(buf.buf[0..input.len], &input);

    const first = try buf.readQnameAlloc(testing.allocator);
    defer testing.allocator.free(first);
    try testing.expectEqualStrings("google.com", first);
    try testing.expectEqual(@as(usize, 12), buf.pos);

    var out: [10]u8 = undefined;
    try buf.readQname(&out);
    try testing.expectEqualStrings("google.com", &out);
    try testing.expectEqual(@as(usize, 14), buf.pos);

    var small: [3]u8 = undefined;
    buf.seek(0);
    try testing.expectError(error.NoSpaceLeft, buf.readQname(&small));
}

test "BytePacketBuffer.read_qname rejects pointer loops" {
    const testing = std.testing;
    var buf = BytePacketBuffer{};
    // Label "a" followed by a pointer to itself.
    const input = [_]u8{ 0x01, 'a', 0xC0, 0x02 };
    @memcpy(buf.buf[0..input.len], &input);
    try testing.expectError(
        BytePacketBuffer.ReadError.JumpLimitExceeded,
        buf.readQnameAlloc(testing.allocator),
    );
}

test "BytePacketBuffer.getRange bounds" {
    const testing = std.testing;
    const buf = BytePacketBuffer{};
    try testing.expectEqual(@as(usize, 2), (try buf.getRange(buf.buf.len - 2, 2)).len);
    try testing.expectError(BytePacketBuffer.ReadError.EndOfBuffer, buf.getRange(buf.buf.len - 1, 2));
    try testing.expectError(BytePacketBuffer.ReadError.EndOfBuffer, buf.getRange(buf.buf.len + 1, 0));
}

/// DNS message header fields.
///
/// NOTE(review): Zig packed structs place the first field in the least
/// significant bits. The flag fields appear ordered to match the two DNS flag
/// bytes, but how this struct is filled from network-order bytes (in particular
/// the byte order of `id` and the counts) is not visible here; confirm at the
/// call site. `rescode` is an exhaustive enum, so bit-casting a header that
/// carries an unlisted response code (6-15) would be illegal behavior — verify
/// the decoding path handles that.
pub const DnsHeader = packed struct {
    id: u16 = 0,

    recursion_desired: bool = false,
    truncated_message: bool = false,
    authoritative_answer: bool = false,
    opcode: u4 = 0,
    response: bool = false,

    rescode: enum(u4) {
        noerror = 0,
        formerr = 1,
        servfail = 2,
        nxdomain = 3,
        notimp = 4,
        refused = 5,
    } = .noerror,
    checking_disabled: bool = false,
    authed_data: bool = false,
    z: bool = false,
    recursion_available: bool = false,

    questions: u16 = 0,
    answers: u16 = 0,
    authoritative_entries: u16 = 0,
    resource_entries: u16 = 0,
};

/// Record types known to this module.
///
/// `unknown` has no explicit value, so Zig assigns it 2, the integer after `a`.
/// `DnsQuestion.read` maps every wire value without a matching tag to `unknown`.
pub const QueryType = enum(u16) {
    a = 1,
    unknown,
};

pub const DnsQuestion = struct {
    name: []const u8,
    qtype: QueryType,

    /// Parse a question (qname, qtype, class) at the buffer's current position.
    ///
    /// `name` is allocated with `alloc` and owned by the caller afterwards. A
    /// qtype without a matching `QueryType` tag is stored as `.unknown`. The class
    /// field is consumed and discarded. On error `self` is left unchanged and
    /// nothing allocated here is leaked.
    pub fn read(self: *DnsQuestion, buffer: *BytePacketBuffer, alloc: std.mem.Allocator) !void {
        const name = try buffer.readQnameAlloc(alloc);
        errdefer alloc.free(name);

        const r = buffer.reader();
        const raw_qtype = try r.readInt(u16, .big);
        _ = try r.readInt(u16, .big); // class

        self.* = .{
            .name = name,
            // Packet data is untrusted. `@enumFromInt` on an unlisted value would be
            // illegal behavior, so map anything unrecognized to `.unknown`.
            .qtype = std.meta.intToEnum(QueryType, raw_qtype) catch .unknown,
        };
    }
};

test "DnsQuestion.read" {
    const testing = std.testing;
    var buf = BytePacketBuffer{};
    // qname "x", qtype 28 (no matching QueryType tag), class 1.
    const input = [_]u8{ 0x01, 'x', 0x00, 0x00, 0x1C, 0x00, 0x01 };
    @memcpy(buf.buf[0..input.len], &input);

    var question: DnsQuestion = undefined;
    try question.read(&buf, testing.allocator);
    defer testing.allocator.free(question.name);

    try testing.expectEqualStrings("x", question.name);
    try testing.expectEqual(QueryType.unknown, question.qtype);
    try testing.expectEqual(@as(usize, input.len), buf.pos);
}

/// A resource record, tagged by its `QueryType`.
pub const DnsRecord = union(QueryType) {
    /// A record whose type this module does not decode.
    unknown: struct {
        domain: []const u8,
        qtype: u16,
        data_len: u16,
        ttl: u32,
    },
    /// An IPv4 address record.
    a: struct {
        domain: []const u8,
        addr: net.Ip4Address,
        ttl: u32,
    },
};