const std = @import("std");
const stda = @import("stda");
const ourio = @import("ourio");
const atproto = @import("root.zig");

const Allocator = std.mem.Allocator;
const dns = stda.net.dns;
const json = std.json;
const posix = std.posix;

/// A stream of events from the jetstream
pub const Stream = struct {
    gpa: Allocator,
    fd: posix.fd_t = -1,
    bundle: std.crypto.Certificate.Bundle,
    /// index into `hosts` of the server picked in `init`
    host: u2,
    /// URL query string built from `Options`; owned by `gpa`
    query: []const u8,
    ctx: ourio.Context,
    /// the expected Sec-WebSocket-Accept value
    challenge: []const u8,
    state: union(enum) {
        dns,
        connect: *stda.net.ConnectTask,
        handshake: *stda.tls.Client.HandshakeTask,
        conn: *stda.tls.Client,
    },
    /// bytes received but not yet decoded into complete websocket frames
    buffer: std.ArrayListUnmanaged(u8) = .empty,

    const hosts = [_][]const u8{
        "jetstream1.us-east.bsky.network",
        "jetstream2.us-east.bsky.network",
        "jetstream1.us-west.bsky.network",
        "jetstream2.us-west.bsky.network",
    };

    /// Constant appended to the Sec-WebSocket-Key before hashing to produce the
    /// expected Sec-WebSocket-Accept value
    const guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /// Completion messages routed through `onCompletion`
    const Msg = enum {
        dns,
        connect,
        handshake,
        upgrade,
        recv,
        write,
        close,
    };

    pub const Options = struct {
        /// array of NSIDs to filter for. Can be wildcard prefixes
        collections: []const []const u8 = &.{},
        /// array of DIDs to filter for
        dids: []const []const u8 = &.{},
        /// maximum message size this client would like to receive
        max_msg_size: usize = 0,
        /// microsecond timestamp of when to begin replay of events. Now or in the future results in
        /// a live tail of events
        cursor_us: ?i64 = null,
    };

    /// Calls ctx.cb with a .userptr pointing to an Event
    ///
    /// Picks one of `hosts` at random, builds the subscription query string from `opts`,
    /// initializes `self`, and only then submits the DNS query. `self` is fully initialized
    /// before any completion can reference it. If submitting the DNS query fails, the query
    /// string is released and `self.query` is reset to an empty slice.
    pub fn init(
        self: *Stream,
        gpa: Allocator,
        io: *ourio.Ring,
        bundle: std.crypto.Certificate.Bundle,
        opts: Options,
        resolver: *dns.Resolver,
        ctx: ourio.Context,
    ) !void {
        var seed: u64 = undefined;
        try std.posix.getrandom(std.mem.asBytes(&seed));
        var prng = std.Random.DefaultPrng.init(seed);
        const idx = prng.random().int(u2);

        const query = try buildQuery(gpa, opts);

        self.* = .{
            .gpa = gpa,
            .fd = -1,
            .bundle = bundle,
            .state = .dns,
            .challenge = "",
            .host = idx,
            .query = query,
            .ctx = ctx,
        };
        errdefer {
            gpa.free(self.query);
            self.query = "";
        }

        const question: dns.Question = .{
            .host = hosts[idx],
            .type = .A,
        };
        try resolver.resolveQuery(io, question, .{
            .ptr = self,
            .msg = @intFromEnum(Msg.dns),
            .cb = Stream.onCompletion,
        });
    }

    /// Appends '&' when `list` already holds at least one parameter.
    fn appendParamSeparator(list: *std.ArrayList(u8)) !void {
        if (list.items.len > 0) try list.append('&');
    }

    /// Builds the `/subscribe` query string from `opts`. Caller owns the returned slice.
    fn buildQuery(gpa: Allocator, opts: Options) ![]u8 {
        var list: std.ArrayList(u8) = .init(gpa);
        defer list.deinit();

        for (opts.collections) |collection| {
            try appendParamSeparator(&list);
            try list.appendSlice("wantedCollections=");
            if (std.mem.indexOfScalar(u8, collection, '*')) |_| {
                // Trailing wildcards are sent percent-encoded
                try list.appendSlice(std.mem.trimRight(u8, collection, "*"));
                try list.appendSlice("%2A");
            } else {
                try list.appendSlice(collection);
            }
        }

        for (opts.dids) |did| {
            try appendParamSeparator(&list);
            try list.appendSlice("wantedDids=");
            try list.appendSlice(did);
        }

        if (opts.max_msg_size > 0) {
            try appendParamSeparator(&list);
            try list.writer().print("maxMessageSizeBytes={d}", .{opts.max_msg_size});
        }

        if (opts.cursor_us) |cursor| {
            try appendParamSeparator(&list);
            try list.writer().print("cursor={d}", .{cursor});
        }

        return list.toOwnedSlice();
    }

    /// Releases the query string, the challenge, the receive buffer and, when a TLS
    /// connection has been established, the connection object.
    pub fn deinit(self: *Stream) void {
        self.gpa.free(self.query);
        self.gpa.free(self.challenge);
        self.buffer.deinit(self.gpa);
        switch (self.state) {
            .dns => {},
            .connect => {},
            .handshake => {},
            .conn => |conn| {
                conn.deinit(self.gpa);
                self.gpa.destroy(conn);
            },
        }
    }

    /// Delivers `err` to the user callback as the .userptr result.
    fn reportError(self: *Stream, io: *ourio.Ring, err: anyerror) !void {
        const task: ourio.Task = .{
            .callback = self.ctx.cb,
            .userdata = self.ctx.ptr,
            .msg = self.ctx.msg,
            .result = .{ .userptr = err },
        };
        try self.ctx.cb(io, task);
    }

    /// Single completion handler for every stage of the connection: DNS resolution,
    /// TCP connect, TLS handshake, websocket upgrade, and steady-state recv/write/close.
    /// I/O failures are forwarded to the user callback through `reportError`.
    pub fn onCompletion(io: *ourio.Ring, task: ourio.Task) anyerror!void {
        const self = task.userdataCast(Stream);
        const result = task.result.?;
        switch (task.msgToEnum(Msg)) {
            .dns => {
                const bytes = result.userbytes catch |err| return self.reportError(io, err);
                const resp: dns.Response = .{ .bytes = bytes };
                var iter = try resp.answerIterator();
                while (iter.next()) |answer| {
                    switch (answer) {
                        .A => |ip4| {
                            const addr = std.net.Ip4Address.init(ip4, 443);
                            const conn_task = try stda.net.tcpConnectToAddr(io, .{ .in = addr }, .{
                                .ptr = self,
                                .msg = @intFromEnum(Msg.connect),
                                .cb = Stream.onCompletion,
                            });
                            self.state = .{ .connect = conn_task };
                            // Connect to the first address only; starting a connect per A
                            // record would race several sockets over the same `state`/`fd`.
                            return;
                        },
                        .CNAME => {},
                        else => return error.Unexpected,
                    }
                }
                // No A record in the response: surface it instead of stalling forever.
                return self.reportError(io, error.UnknownHostName);
            },

            .connect => {
                self.fd = result.userfd catch |err| return self.reportError(io, err);
                const hs_task = try stda.tls.Client.init(
                    io,
                    self.fd,
                    .{ .host = hosts[self.host], .root_ca = self.bundle },
                    .{
                        .ptr = self,
                        .msg = @intFromEnum(Msg.handshake),
                        .cb = Stream.onCompletion,
                    },
                );
                self.state = .{ .handshake = hs_task };
            },

            .handshake => {
                const ptr = result.userptr catch |err| return self.reportError(io, err);
                const conn: *stda.tls.Client = @ptrCast(@alignCast(ptr));
                conn.userdata = self;
                conn.callback = Stream.onCompletion;
                // The first read carries the HTTP upgrade response
                conn.recv_msg = @intFromEnum(Msg.upgrade);
                conn.write_msg = @intFromEnum(Msg.write);
                conn.close_msg = @intFromEnum(Msg.close);
                self.state = .{ .conn = conn };

                var key: [16]u8 = undefined;
                std.crypto.random.bytes(&key);
                var base64_buf: [24]u8 = undefined;
                const base64_key = std.base64.standard.Encoder.encode(&base64_buf, &key);

                // Precompute the Accept response for the challenge
                {
                    var accept_buf: [24 + guid.len]u8 = undefined;
                    @memcpy(accept_buf[0..24], base64_key);
                    @memcpy(accept_buf[24..], guid);
                    var hash_buf: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
                    std.crypto.hash.Sha1.hash(&accept_buf, &hash_buf, .{});

                    // Allocate the replacement before releasing the old value so an
                    // allocation failure never leaves `challenge` dangling.
                    const encoder = std.base64.standard.Encoder;
                    const base64_hash = try self.gpa.alloc(u8, encoder.calcSize(hash_buf.len));
                    const encoded = encoder.encode(base64_hash, &hash_buf);
                    self.gpa.free(self.challenge);
                    self.challenge = encoded;
                }

                // Write an upgrade request
                var writer = conn.cleartext_buf.writer(self.gpa);
                try writer.print("GET /subscribe?{s} HTTP/1.1\r\n", .{self.query});
                try writer.print("Host: {s}\r\n", .{hosts[self.host]});
                try writer.writeAll("Upgrade: websocket\r\n");
                try writer.writeAll("Connection: Upgrade\r\n");
                try writer.print("Sec-WebSocket-Key: {s}\r\n", .{base64_key});
                try writer.writeAll("Sec-WebSocket-Version: 13\r\n");
                try writer.writeAll("\r\n");
                try conn.flush(self.gpa, io);

                // Initiate receiving messages
                try conn.recv(io);
            },

            .upgrade => {
                // Assume we'll always get the header in the first read
                const n = result.recv catch |err| return self.reportError(io, err);
                const buf = self.state.conn.read_buf[0..n];
                var iter = std.mem.splitSequence(u8, buf, "\r\n");
                if (!std.mem.startsWith(u8, iter.first(), "HTTP/1.1 101")) {
                    return error.InvalidUpgrade;
                }

                var valid_accept = false;
                while (iter.next()) |line| {
                    // A blank line terminates the HTTP header block
                    if (line.len == 0) break;
                    var line_iter = std.mem.splitScalar(u8, line, ':');
                    if (std.ascii.eqlIgnoreCase("Sec-WebSocket-Accept", line_iter.first())) {
                        const value = std.mem.trim(u8, line_iter.rest(), &std.ascii.whitespace);
                        valid_accept = std.mem.eql(u8, value, self.challenge);
                    }
                }
                if (!valid_accept) {
                    return error.InvalidSecWebSocketAccept;
                }

                // Subsequent reads are websocket frames; anything after the headers in
                // this read is already frame data.
                self.state.conn.recv_msg = @intFromEnum(Msg.recv);
                try self.decodeFrames(io, iter.rest());
            },

            .recv => {
                const n = result.recv catch |err| return self.reportError(io, err);
                const buf = self.state.conn.read_buf[0..n];
                try self.decodeFrames(io, buf);
            },

            .write => {
                _ = result.write catch |err| return self.reportError(io, err);
            },

            .close => {
                _ = result.close catch |err| return self.reportError(io, err);
            },
        }
    }

    /// Appends `bytes` to the pending buffer and dispatches every complete frame in it.
    /// Text frames are parsed as JSON and handed to the user callback as an `Event` that is
    /// only valid for the duration of the callback. Close frames are reported as
    /// `error.ConnectionResetByPeer`. Ping frames are answered with a masked pong. Bytes
    /// belonging to an incomplete trailing frame stay buffered for the next call.
    fn decodeFrames(self: *Stream, io: *ourio.Ring, bytes: []const u8) !void {
        try self.buffer.appendSlice(self.gpa, bytes);

        var arena: std.heap.ArenaAllocator = .init(self.gpa);
        defer arena.deinit();

        var iter: FrameIterator = .{ .bytes = self.buffer.items };
        while (try iter.next()) |data| {
            // Per-frame JSON allocations are released after each frame
            defer _ = arena.reset(.retain_capacity);
            switch (data) {
                .text => |s| {
                    var event: Event = .{
                        .raw = s,
                        .value = try json.parseFromSliceLeaky(
                            json.Value,
                            arena.allocator(),
                            s,
                            .{ .allocate = .alloc_if_needed },
                        ),
                    };
                    const task: ourio.Task = .{
                        .callback = self.ctx.cb,
                        .userdata = self.ctx.ptr,
                        .msg = self.ctx.msg,
                        .result = .{ .userptr = &event },
                    };
                    try self.ctx.cb(io, task);
                },
                .close => {
                    const task: ourio.Task = .{
                        .callback = self.ctx.cb,
                        .userdata = self.ctx.ptr,
                        .msg = self.ctx.msg,
                        .result = .{ .userptr = error.ConnectionResetByPeer },
                    };
                    try self.ctx.cb(io, task);
                },
                .ping => |payload| {
                    // The 7-bit length field of a short frame cannot describe more
                    // than 125 bytes (126/127 select extended lengths).
                    if (payload.len > 125) return error.InvalidControlFrame;

                    const byte1: FrameIterator.Byte1 = .{
                        .final = true,
                        .reserved = 0,
                        .opcode = .pong,
                    };
                    const byte2: FrameIterator.Byte2 = .{
                        .len = @intCast(payload.len),
                        .mask = true,
                    };

                    // The mask bit is set, so the payload must be XORed with the key
                    // that is sent alongside it.
                    var mask: [4]u8 = undefined;
                    std.crypto.random.bytes(&mask);
                    var masked: [125]u8 = undefined;
                    for (payload, 0..) |byte, i| {
                        masked[i] = byte ^ mask[i % mask.len];
                    }

                    const conn = self.state.conn;
                    try conn.write(self.gpa, &.{ @bitCast(byte1), @bitCast(byte2) });
                    try conn.write(self.gpa, &mask);
                    try conn.write(self.gpa, masked[0..payload.len]);
                    try conn.flush(self.gpa, io);
                },
                else => {
                    return error.UnsupportedOp;
                },
            }
        }

        // Drop everything that was decoded; keep the incomplete tail
        self.buffer.replaceRangeAssumeCapacity(0, iter.idx, "");
    }
};

/// A single decoded jetstream message. `raw` is the frame payload and `value` the JSON
/// parsed from it. Accessors unwrap with `.?` and therefore assume the field is present.
pub const Event = struct {
    raw: []const u8,
    value: json.Value,

    pub const KindEnum = enum {
        commit,
        identity,
        account,
    };

    pub const Kind = union(KindEnum) {
        commit: Commit,
        identity: Identity,
        account: Account,
    };

    /// Returns the top-level "did" field.
    pub fn did(self: Event) atproto.Did {
        const value = self.value.object.get("did").?;
        return atproto.Did.init(value.string) catch unreachable;
    }

    /// Returns the top-level "time_us" field.
    pub fn time(self: Event) i64 {
        const value = self.value.object.get("time_us").?;
        return value.integer;
    }

    /// Returns the payload selected by the "kind" field, wrapping the JSON object stored
    /// under the key of the same name.
    pub fn kind(self: Event) Kind {
        const value = self.value.object.get("kind").?;
        const k = std.meta.stringToEnum(KindEnum, value.string).?;
        switch (k) {
            .commit => return .{ .commit = .{ .value = self.value.object.get("commit").? } },
            .identity => return .{ .identity = .{ .value = self.value.object.get("identity").? } },
            .account => return .{ .account = .{ .value = self.value.object.get("account").? } },
        }
    }

    pub const Commit = struct {
        value: json.Value,

        pub const Operation = enum {
            create,
            update,
            delete,
        };

        /// Returns the "rev" field as a string.
        pub fn rev(self: Commit) []const u8 {
            return self.value.object.get("rev").?.string;
        }

        /// Returns the "operation" field mapped onto `Operation`.
        pub fn operation(self: Commit) Operation {
            const op = self.value.object.get("operation").?;
            return std.meta.stringToEnum(Operation, op.string).?;
        }

        /// Returns the "collection" field.
        pub fn collection(self: Commit) []const u8 {
            return self.value.object.get("collection").?.string;
        }

        /// Returns the "rkey" field.
        pub fn rkey(self: Commit) []const u8 {
            return self.value.object.get("rkey").?.string;
        }

        /// Returns the "cid" field.
        pub fn cid(self: Commit) []const u8 {
            return self.value.object.get("cid").?.string;
        }

        /// Returns the "record" field as raw JSON.
        pub fn record(self: Commit) json.Value {
            return self.value.object.get("record").?;
        }
    };

    pub const Identity = struct {
        value: json.Value,

        /// Returns the "did" field.
        pub fn did(self: Identity) atproto.Did {
            const value = self.value.object.get("did").?;
            return atproto.Did.init(value.string) catch unreachable;
        }

        /// Returns the "seq" field.
        pub fn seq(self: Identity) i64 {
            return self.value.object.get("seq").?.integer;
        }

        /// Returns the "time" field as a string.
        pub fn time(self: Identity) []const u8 {
            return self.value.object.get("time").?.string;
        }

        /// Returns the "handle" field, or null when it is absent.
        pub fn handle(self: Identity) ?[]const u8 {
            const value = self.value.object.get("handle") orelse return null;
            return value.string;
        }
    };

    pub const Account = struct {
        value: json.Value,

        /// Returns the "active" field.
        pub fn active(self: Account) bool {
            return self.value.object.get("active").?.bool;
        }

        /// Returns the "did" field.
        pub fn did(self: Account) atproto.Did {
            const value = self.value.object.get("did").?;
            return atproto.Did.init(value.string) catch unreachable;
        }

        /// Returns the "seq" field.
        pub fn seq(self: Account) i64 {
            return self.value.object.get("seq").?.integer;
        }

        /// Returns the "time" field as a string.
        pub fn time(self: Account) []const u8 {
            return self.value.object.get("time").?.string;
        }
    };
};

/// Walks the unmasked, unfragmented websocket frames contained in `bytes`.
/// `idx` always points at the first byte that has not been consumed as part of a
/// complete frame.
const FrameIterator = struct {
    bytes: []const u8,
    idx: usize = 0,

    const Data = union(enum) {
        continuation: []const u8,
        text: []const u8,
        binary: []const u8,
        close,
        ping: []const u8,
        pong,
    };

    /// First header byte. Fields are listed from the least significant bit upwards.
    const Byte1 = packed struct {
        opcode: enum(u4) {
            continuation = 0x0,
            text = 0x1,
            binary = 0x2,
            close = 0x8,
            ping = 0x9,
            pong = 0xA,
            _,
        },
        reserved: u3,
        final: bool,
    };

    /// Second header byte. Fields are listed from the least significant bit upwards.
    const Byte2 = packed struct {
        len: u7,
        mask: bool,
    };

    /// Returns the next complete frame, or null when the remaining bytes do not hold a
    /// complete frame. In the null case `idx` is left at the start of the incomplete frame.
    /// A frame with an empty payload occupying the final bytes of the buffer is returned
    /// like any other frame.
    ///
    /// Errors: `MaskNotSupported` for masked frames, `InvalidOpcode` for opcodes outside
    /// the enum. Fragmented frames (final bit clear) panic.
    fn next(self: *FrameIterator) !?Data {
        const start = self.idx;
        if (start + 2 > self.bytes.len) return null;

        const byte1: Byte1 = @bitCast(self.bytes[start]);
        const byte2: Byte2 = @bitCast(self.bytes[start + 1]);

        if (!byte1.final) @panic("TODO: continuuation frame support");
        if (byte2.mask) return error.MaskNotSupported;

        // 126 and 127 announce a 16-bit or 64-bit big-endian extended length
        const header_end = start + 2;
        const ext_len: usize = switch (byte2.len) {
            126 => 2,
            127 => 8,
            else => 0,
        };
        if (self.bytes.len - header_end < ext_len) return null;

        const len: u64 = switch (byte2.len) {
            126 => std.mem.readInt(u16, self.bytes[header_end..][0..2], .big),
            127 => std.mem.readInt(u64, self.bytes[header_end..][0..8], .big),
            else => byte2.len,
        };

        // Compare against the bytes actually available instead of computing
        // `data_start + len`, which could overflow for a hostile 64-bit length.
        const data_start = header_end + ext_len;
        if (len > self.bytes.len - data_start) return null;
        const end = data_start + @as(usize, @intCast(len));

        self.idx = end;
        const payload = self.bytes[data_start..end];
        return switch (byte1.opcode) {
            .continuation => .{ .continuation = payload },
            .text => .{ .text = payload },
            .binary => .{ .binary = payload },
            .close => .close,
            .ping => .{ .ping = payload },
            .pong => .pong,
            else => error.InvalidOpcode,
        };
    }

    /// Returns the bytes that have not been consumed yet.
    fn rest(self: *FrameIterator) []const u8 {
        if (self.idx >= self.bytes.len) return "";
        return self.bytes[self.idx..];
    }
};

test "FrameIterator: valid" {
    const byte1: FrameIterator.Byte1 = .{
        .opcode = .text,
        .reserved = 0,
        .final = true,
    };
    const byte2: FrameIterator.Byte2 = .{
        .len = 3,
        .mask = false,
    };
    var buf: [5]u8 = undefined;
    buf[0] = @bitCast(byte1);
    buf[1] = @bitCast(byte2);
    buf[2] = 'f';
    buf[3] = 'o';
    buf[4] = 'o';
    var iter: FrameIterator = .{ .bytes = &buf };
    const data = try iter.next();
    try std.testing.expectEqualStrings("foo", data.?.text);
    const last = try iter.next();
    try std.testing.expect(last == null);
}

test "FrameIterator: short read" {
    const byte1: FrameIterator.Byte1 = .{
        .opcode = .text,
        .reserved = 0,
        .final = true,
    };
    const byte2: FrameIterator.Byte2 = .{
        .len = 3,
        .mask = false,
    };
    var buf: [3]u8 = undefined;
    buf[0] = @bitCast(byte1);
    buf[1] = @bitCast(byte2);
    buf[2] = 'f';
    var iter: FrameIterator = .{ .bytes = &buf };
    const data = try iter.next();
    try std.testing.expect(data == null);
}

test "FrameIterator: valid, multiple" {
    const byte1: FrameIterator.Byte1 = .{
        .opcode = .text,
        .reserved = 0,
        .final = true,
    };
    const byte2: FrameIterator.Byte2 = .{
        .len = 3,
        .mask = false,
    };
    var buf: [10]u8 = undefined;
    buf[0] = @bitCast(byte1);
    buf[1] = @bitCast(byte2);
    buf[2] = 'f';
    buf[3] = 'o';
    buf[4] = 'o';
    buf[5] = @bitCast(byte1);
    buf[6] = @bitCast(byte2);
    buf[7] = 'f';
    buf[8] = 'o';
    buf[9] = 'o';
    var iter: FrameIterator = .{ .bytes = &buf };
    {
        const data = try iter.next();
        try std.testing.expectEqualStrings("foo", data.?.text);
    }
    {
        const data = try iter.next();
        try std.testing.expectEqualStrings("foo", data.?.text);
    }
    const last = try iter.next();
    try std.testing.expect(last == null);
}

test "FrameIterator: empty payload at end of buffer" {
    // final + ping opcode, zero-length unmasked payload
    const buf = [_]u8{ 0x89, 0x00 };
    var iter: FrameIterator = .{ .bytes = &buf };
    const data = try iter.next();
    try std.testing.expectEqual(@as(usize, 0), data.?.ping.len);
    try std.testing.expectEqual(@as(usize, 2), iter.idx);
    const last = try iter.next();
    try std.testing.expect(last == null);
}

test "FrameIterator: 16-bit extended length" {
    // final + text opcode, length marker 126 followed by a big-endian u16 length
    const complete = [_]u8{ 0x81, 126, 0x00, 0x03, 'f', 'o', 'o' };
    var iter: FrameIterator = .{ .bytes = &complete };
    const data = try iter.next();
    try std.testing.expectEqualStrings("foo", data.?.text);

    // Truncated inside the extended length: nothing is consumed
    var short: FrameIterator = .{ .bytes = complete[0..3] };
    try std.testing.expect((try short.next()) == null);
    try std.testing.expectEqual(@as(usize, 0), short.idx);
}