Zig library for atproto applications.
1const std = @import("std");
2const stda = @import("stda");
3const ourio = @import("ourio");
4
5const atproto = @import("root.zig");
6
7const Allocator = std.mem.Allocator;
8const dns = stda.net.dns;
9const json = std.json;
10const posix = std.posix;
11
/// A stream of events from the jetstream.
///
/// All network activity is driven by `onCompletion`; events and errors are delivered to the
/// `ctx` passed to `init`.
pub const Stream = struct {
    gpa: Allocator,
    fd: posix.fd_t = -1,
    bundle: std.crypto.Certificate.Bundle,
    /// index into `hosts` of the jetstream instance this stream connects to
    host: u2,
    /// owned query string for `/subscribe`, without the leading '?'
    query: []const u8,
    ctx: ourio.Context,

    /// the expected Sec-WebSocket-Accept value (owned; empty until the TLS handshake completes)
    challenge: []const u8,

    state: union(enum) {
        dns,
        connect: *stda.net.ConnectTask,
        handshake: *stda.tls.Client.HandshakeTask,
        conn: *stda.tls.Client,
    },

    /// received bytes that do not yet form a complete WebSocket frame
    buffer: std.ArrayListUnmanaged(u8) = .empty,

    const hosts = [_][]const u8{
        "jetstream1.us-east.bsky.network",
        "jetstream2.us-east.bsky.network",
        "jetstream1.us-west.bsky.network",
        "jetstream2.us-west.bsky.network",
    };

    /// RFC 6455 GUID appended to the client key when computing Sec-WebSocket-Accept
    const guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    const Msg = enum {
        dns,
        connect,
        handshake,
        upgrade,
        recv,
        write,
        close,
    };

    pub const Options = struct {
        /// array of NSIDs to filter for. Can be wildcard prefixes
        collections: []const []const u8 = &.{},

        /// array of DIDs to filter for
        dids: []const []const u8 = &.{},

        /// maximum message size this client would like to receive
        max_msg_size: usize = 0,

        /// microsecond timestamp of when to begin replay of events. Now or in the future results in
        /// a live tail of events
        cursor_us: ?i64 = null,
    };

    /// Starts connecting to a randomly chosen jetstream host.
    ///
    /// Calls ctx.cb with a .userptr pointing to an Event for every text frame, or with
    /// .userptr set to an error when the connection fails. `self` is registered as callback
    /// userdata, so it must stay at a stable address until `deinit`. If `init` returns an
    /// error, `self` must not be deinitialized.
    pub fn init(
        self: *Stream,
        gpa: Allocator,
        io: *ourio.Ring,
        bundle: std.crypto.Certificate.Bundle,
        opts: Options,
        resolver: *dns.Resolver,
        ctx: ourio.Context,
    ) !void {
        // Do every fallible allocation *before* the DNS query is queued, so a completion can
        // never observe an uninitialized `self`.
        const query = try buildQuery(gpa, opts);
        errdefer gpa.free(query);

        const idx = std.crypto.random.int(u2);

        self.* = .{
            .gpa = gpa,
            .fd = -1,
            .bundle = bundle,
            .state = .dns,
            .challenge = "",
            .host = idx,
            .query = query,
            .ctx = ctx,
        };

        const question: dns.Question = .{
            .host = hosts[idx],
            .type = .A,
        };
        try resolver.resolveQuery(io, question, .{
            .ptr = self,
            .msg = @intFromEnum(Msg.dns),
            .cb = Stream.onCompletion,
        });
    }

    /// Builds the `/subscribe` query string for `opts`. Caller owns the returned slice.
    fn buildQuery(gpa: Allocator, opts: Options) ![]u8 {
        var list: std.ArrayList(u8) = .init(gpa);
        defer list.deinit();

        for (opts.collections) |collection| {
            if (list.items.len > 0) try list.append('&');
            try list.appendSlice("wantedCollections=");

            if (std.mem.indexOfScalar(u8, collection, '*')) |_| {
                // percent-encode the wildcard suffix
                try list.appendSlice(std.mem.trimRight(u8, collection, "*"));
                try list.appendSlice("%2A");
            } else {
                try list.appendSlice(collection);
            }
        }

        for (opts.dids) |did| {
            if (list.items.len > 0) try list.append('&');
            try list.appendSlice("wantedDids=");
            try list.appendSlice(did);
        }

        if (opts.max_msg_size > 0) {
            if (list.items.len > 0) try list.append('&');
            try list.writer().print("maxMessageSizeBytes={d}", .{opts.max_msg_size});
        }

        if (opts.cursor_us) |cursor| {
            if (list.items.len > 0) try list.append('&');
            try list.writer().print("cursor={d}", .{cursor});
        }

        return list.toOwnedSlice();
    }

    /// Frees owned memory and, once connected, the TLS client.
    pub fn deinit(self: *Stream) void {
        self.gpa.free(self.query);
        self.gpa.free(self.challenge);
        self.buffer.deinit(self.gpa);
        switch (self.state) {
            .dns, .connect, .handshake => {},
            .conn => |conn| {
                conn.deinit(self.gpa);
                self.gpa.destroy(conn);
            },
        }
    }

    /// Delivers `err` to the user callback as the task's .userptr result.
    fn reportError(self: *Stream, io: *ourio.Ring, err: anyerror) !void {
        const task: ourio.Task = .{
            .callback = self.ctx.cb,
            .userdata = self.ctx.ptr,
            .msg = self.ctx.msg,
            .result = .{ .userptr = err },
        };
        try self.ctx.cb(io, task);
    }

    /// Ring callback that drives the dns -> connect -> tls handshake -> websocket upgrade ->
    /// recv state machine. Failures are reported through `reportError`.
    pub fn onCompletion(io: *ourio.Ring, task: ourio.Task) anyerror!void {
        const self = task.userdataCast(Stream);
        const result = task.result.?;

        switch (task.msgToEnum(Msg)) {
            .dns => {
                const bytes = result.userbytes catch |err| return self.reportError(io, err);
                const resp: dns.Response = .{ .bytes = bytes };
                var iter = try resp.answerIterator();
                while (iter.next()) |answer| {
                    switch (answer) {
                        .A => |ip4| {
                            // Connect to the first address only; starting one connection per
                            // A record would have several tasks report back to this stream.
                            const addr = std.net.Ip4Address.init(ip4, 443);
                            const conn_task = try stda.net.tcpConnectToAddr(io, .{ .in = addr }, .{
                                .ptr = self,
                                .msg = @intFromEnum(Msg.connect),
                                .cb = Stream.onCompletion,
                            });
                            self.state = .{ .connect = conn_task };
                            return;
                        },

                        .CNAME => {},

                        else => return error.Unexpected,
                    }
                }
                // No usable address: report it instead of silently stalling forever.
                return self.reportError(io, error.UnknownHostName);
            },

            .connect => {
                self.fd = result.userfd catch |err| return self.reportError(io, err);
                const hs_task = try stda.tls.Client.init(
                    io,
                    self.fd,
                    .{ .host = hosts[self.host], .root_ca = self.bundle },
                    .{
                        .ptr = self,
                        .msg = @intFromEnum(Msg.handshake),
                        .cb = Stream.onCompletion,
                    },
                );

                self.state = .{ .handshake = hs_task };
            },

            .handshake => {
                const ptr = result.userptr catch |err| return self.reportError(io, err);
                const conn: *stda.tls.Client = @ptrCast(@alignCast(ptr));

                conn.userdata = self;
                conn.callback = Stream.onCompletion;
                // the first read carries the HTTP upgrade response; later reads carry frames
                conn.recv_msg = @intFromEnum(Msg.upgrade);
                conn.write_msg = @intFromEnum(Msg.write);
                conn.close_msg = @intFromEnum(Msg.close);

                self.state = .{ .conn = conn };

                // RFC 6455 §4.1: the key is 16 random bytes, base64 encoded
                var key: [16]u8 = undefined;
                std.crypto.random.bytes(&key);
                var base64_buf: [24]u8 = undefined;
                const base64_key = std.base64.standard.Encoder.encode(&base64_buf, &key);

                // Precompute the Accept response for the challenge: base64(sha1(key ++ guid))
                {
                    var accept_buf: [24 + guid.len]u8 = undefined;
                    @memcpy(accept_buf[0..24], base64_key);
                    @memcpy(accept_buf[24..], guid);
                    var hash_buf: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
                    std.crypto.hash.Sha1.hash(&accept_buf, &hash_buf, .{});

                    const encoder = std.base64.standard.Encoder;
                    // Allocate before freeing the old value so an allocation failure cannot
                    // leave `challenge` dangling (deinit would free it twice).
                    const challenge = try self.gpa.alloc(u8, encoder.calcSize(hash_buf.len));
                    self.gpa.free(self.challenge);
                    self.challenge = encoder.encode(challenge, &hash_buf);
                }

                // Write an upgrade request
                const query_sep: []const u8 = if (self.query.len > 0) "?" else "";
                var writer = conn.cleartext_buf.writer(self.gpa);
                try writer.print("GET /subscribe{s}{s} HTTP/1.1\r\n", .{ query_sep, self.query });
                try writer.print("Host: {s}\r\n", .{hosts[self.host]});
                try writer.writeAll("Upgrade: websocket\r\n");
                try writer.writeAll("Connection: Upgrade\r\n");
                try writer.print("Sec-WebSocket-Key: {s}\r\n", .{base64_key});
                try writer.writeAll("Sec-WebSocket-Version: 13\r\n");
                try writer.writeAll("\r\n");

                try conn.flush(self.gpa, io);

                // Initiate receiving messages
                try conn.recv(io);
            },

            .upgrade => {
                // Assume we'll always get the header in the first read
                const n = result.recv catch |err| return self.reportError(io, err);
                const buf = self.state.conn.read_buf[0..n];

                var iter = std.mem.splitSequence(u8, buf, "\r\n");
                if (!std.mem.startsWith(u8, iter.first(), "HTTP/1.1 101")) {
                    return self.reportError(io, error.InvalidUpgrade);
                }

                var valid_accept = false;
                while (iter.next()) |line| {
                    // an empty line terminates the header block
                    if (line.len == 0) break;
                    var line_iter = std.mem.splitScalar(u8, line, ':');
                    if (std.ascii.eqlIgnoreCase("Sec-WebSocket-Accept", line_iter.first())) {
                        const value = std.mem.trim(u8, line_iter.rest(), &std.ascii.whitespace);
                        valid_accept = std.mem.eql(u8, value, self.challenge);
                    }
                }

                if (!valid_accept) {
                    return self.reportError(io, error.InvalidSecWebSocketAccept);
                }

                self.state.conn.recv_msg = @intFromEnum(Msg.recv);

                // anything after the header block is already frame data
                try self.decodeFrames(io, iter.rest());
            },

            .recv => {
                const n = result.recv catch |err| return self.reportError(io, err);
                try self.decodeFrames(io, self.state.conn.read_buf[0..n]);
            },

            .write => {
                _ = result.write catch |err| return self.reportError(io, err);
            },

            .close => {
                _ = result.close catch |err| return self.reportError(io, err);
            },
        }
    }

    /// Appends `bytes` to the pending buffer and dispatches every complete frame in it.
    /// Bytes of a trailing partial frame stay buffered for the next call.
    fn decodeFrames(self: *Stream, io: *ourio.Ring, bytes: []const u8) !void {
        try self.buffer.appendSlice(self.gpa, bytes);

        var arena: std.heap.ArenaAllocator = .init(self.gpa);
        defer arena.deinit();

        var iter: FrameIterator = .{ .bytes = self.buffer.items };
        // Drop consumed frames even on an early error return, so frames that were already
        // delivered are not delivered a second time after the next read.
        defer self.buffer.replaceRangeAssumeCapacity(0, iter.idx, "");

        while (try iter.next()) |data| {
            defer _ = arena.reset(.retain_capacity);
            switch (data) {
                .text => |s| {
                    var event: Event = .{ .raw = s, .value = try json.parseFromSliceLeaky(
                        json.Value,
                        arena.allocator(),
                        s,
                        .{ .allocate = .alloc_if_needed },
                    ) };

                    const task: ourio.Task = .{
                        .callback = self.ctx.cb,
                        .userdata = self.ctx.ptr,
                        .msg = self.ctx.msg,
                        .result = .{ .userptr = &event },
                    };
                    try self.ctx.cb(io, task);
                },

                .close => try self.reportError(io, error.ConnectionResetByPeer),

                .ping => |payload| try self.sendPong(io, payload),

                // RFC 6455 §5.5.3: unsolicited pongs are allowed and need no response
                .pong => {},

                .continuation, .binary => return error.UnsupportedOp,
            }
        }
    }

    /// Answers a ping with a pong that echoes `payload`, masked with a fresh random key as
    /// RFC 6455 §5.3 requires for every client-to-server frame.
    fn sendPong(self: *Stream, io: *ourio.Ring, payload: []const u8) !void {
        // RFC 6455 §5.5: control frame payloads are at most 125 bytes
        if (payload.len > 125) return error.ControlFrameTooLarge;

        const byte1: FrameIterator.Byte1 = .{
            .final = true,
            .reserved = 0,
            .opcode = .pong,
        };
        const byte2: FrameIterator.Byte2 = .{
            .len = @intCast(payload.len),
            .mask = true,
        };

        // header (2) + masking key (4) + payload (<= 125)
        var frame: [2 + 4 + 125]u8 = undefined;
        frame[0] = @bitCast(byte1);
        frame[1] = @bitCast(byte2);
        const mask = frame[2..6];
        std.crypto.random.bytes(mask);
        const masked = frame[6..][0..payload.len];
        for (masked, payload, 0..) |*dst, b, i| dst.* = b ^ mask[i % 4];

        const conn = self.state.conn;
        try conn.write(self.gpa, frame[0 .. 6 + payload.len]);
        try conn.flush(self.gpa, io);
    }
};
368
/// One event received from the jetstream.
///
/// `raw` is the JSON text of the WebSocket frame and `value` is its parsed form. When `Stream`
/// produces an event, both point into buffers it reuses after the callback returns, so copy
/// anything that must outlive the callback.
///
/// Accessors unwrap fields with `.?` and therefore panic (in safe builds) when a field is
/// missing.
pub const Event = struct {
    raw: []const u8,
    value: json.Value,

    pub const KindEnum = enum {
        commit,
        identity,
        account,
    };

    /// The kind-specific payload of an event.
    pub const Kind = union(KindEnum) {
        commit: Commit,
        identity: Identity,
        account: Account,
    };

    /// The "did" field of the event.
    /// NOTE(review): a DID that `atproto.Did.init` rejects hits `unreachable`.
    pub fn did(self: Event) atproto.Did {
        const value = self.value.object.get("did").?;
        return atproto.Did.init(value.string) catch unreachable;
    }

    /// The "time_us" field. Presumably a microsecond timestamp (compare
    /// `Stream.Options.cursor_us`); confirm against the jetstream docs.
    pub fn time(self: Event) i64 {
        const value = self.value.object.get("time_us").?;
        return value.integer;
    }

    /// Decodes the "kind" field and returns the matching sub-object.
    pub fn kind(self: Event) Kind {
        const value = self.value.object.get("kind").?;
        const k = std.meta.stringToEnum(KindEnum, value.string).?;
        switch (k) {
            .commit => return .{ .commit = .{ .value = self.value.object.get("commit").? } },
            .identity => return .{ .identity = .{ .value = self.value.object.get("identity").? } },
            .account => return .{ .account = .{ .value = self.value.object.get("account").? } },
        }
    }

    /// The "commit" object of a commit event.
    pub const Commit = struct {
        value: json.Value,

        pub const Operation = enum {
            create,
            update,
            delete,
        };

        /// The "rev" string.
        pub fn rev(self: Commit) []const u8 {
            // the field holds a JSON string; return its text, not the json.Value wrapper
            return self.value.object.get("rev").?.string;
        }

        pub fn operation(self: Commit) Operation {
            const op = self.value.object.get("operation").?;
            return std.meta.stringToEnum(Operation, op.string).?;
        }

        pub fn collection(self: Commit) []const u8 {
            return self.value.object.get("collection").?.string;
        }

        pub fn rkey(self: Commit) []const u8 {
            return self.value.object.get("rkey").?.string;
        }

        /// The "cid" string. NOTE(review): confirm it is present for every operation.
        pub fn cid(self: Commit) []const u8 {
            return self.value.object.get("cid").?.string;
        }

        /// The "record" field, returned as raw JSON.
        pub fn record(self: Commit) json.Value {
            // the JSON lives in `value`; Commit has no `object` field of its own
            return self.value.object.get("record").?;
        }
    };

    /// The "identity" object of an identity event.
    pub const Identity = struct {
        value: json.Value,

        /// NOTE(review): a DID that `atproto.Did.init` rejects hits `unreachable`.
        pub fn did(self: Identity) atproto.Did {
            const value = self.value.object.get("did").?;
            return atproto.Did.init(value.string) catch unreachable;
        }

        pub fn seq(self: Identity) i64 {
            return self.value.object.get("seq").?.integer;
        }

        pub fn time(self: Identity) []const u8 {
            return self.value.object.get("time").?.string;
        }

        /// The "handle" string, or null when the field is absent.
        pub fn handle(self: Identity) ?[]const u8 {
            const value = self.value.object.get("handle") orelse return null;
            return value.string;
        }
    };

    /// The "account" object of an account event.
    pub const Account = struct {
        value: json.Value,

        pub fn active(self: Account) bool {
            return self.value.object.get("active").?.bool;
        }

        /// NOTE(review): a DID that `atproto.Did.init` rejects hits `unreachable`.
        pub fn did(self: Account) atproto.Did {
            const value = self.value.object.get("did").?;
            return atproto.Did.init(value.string) catch unreachable;
        }

        pub fn seq(self: Account) i64 {
            return self.value.object.get("seq").?.integer;
        }

        pub fn time(self: Account) []const u8 {
            return self.value.object.get("time").?.string;
        }
    };
};
483
/// Incrementally decodes server-to-client WebSocket frames (RFC 6455 §5.2) from a byte
/// buffer. Payload slices returned by `next` borrow from `bytes`.
const FrameIterator = struct {
    bytes: []const u8,
    /// offset of the first byte not yet consumed by `next`
    idx: usize = 0,

    const Data = union(enum) {
        continuation: []const u8,
        text: []const u8,
        binary: []const u8,
        close,
        ping: []const u8,
        pong,
    };

    /// First header byte. Packed struct fields are laid out from the least significant bit,
    /// so the opcode is the low nibble and `final` (FIN) is the high bit.
    const Byte1 = packed struct {
        opcode: enum(u4) {
            continuation = 0x0,
            text = 0x1,
            binary = 0x2,
            close = 0x8,
            ping = 0x9,
            pong = 0xA,
            _,
        },
        reserved: u3,
        final: bool,
    };

    /// Second header byte: the 7-bit payload length, with MASK in the high bit.
    const Byte2 = packed struct {
        len: u7,
        mask: bool,
    };

    /// Decodes the frame that starts at `idx` and advances past it.
    ///
    /// Returns null and leaves `idx` untouched when the buffered bytes do not yet contain a
    /// complete frame, so the caller can retry after more data arrives. Frames with an
    /// empty payload consist of only their 2-byte header and are decoded as soon as that
    /// header is present.
    fn next(self: *FrameIterator) !?Data {
        const header = self.bytes[self.idx..];
        if (header.len < 2) return null;

        const byte1: Byte1 = @bitCast(header[0]);
        const byte2: Byte2 = @bitCast(header[1]);

        // Fragmented messages are not supported. The bytes come from the network, so fail
        // with an error instead of panicking.
        if (!byte1.final) return error.FragmentedFrameNotSupported;
        // RFC 6455 §5.1: a server must never mask the frames it sends
        if (byte2.mask) return error.MaskNotSupported;

        // payload length: 0-125 inline, 126 => 16-bit extended, 127 => 64-bit extended
        const len: usize, const header_len: usize = switch (byte2.len) {
            126 => blk: {
                if (header.len < 4) return null;
                break :blk .{ std.mem.readInt(u16, header[2..][0..2], .big), 4 };
            },
            127 => blk: {
                if (header.len < 10) return null;
                const len64 = std.mem.readInt(u64, header[2..][0..8], .big);
                break :blk .{ std.math.cast(usize, len64) orelse return error.FrameTooLarge, 10 };
            },
            else => .{ byte2.len, 2 },
        };

        // a hostile 64-bit length must not overflow the end offset
        const total = std.math.add(usize, header_len, len) catch return error.FrameTooLarge;
        if (total > header.len) return null;

        const payload = header[header_len..total];
        self.idx += total;

        switch (byte1.opcode) {
            .continuation => return .{ .continuation = payload },
            .text => return .{ .text = payload },
            .binary => return .{ .binary = payload },
            .close => return .close,
            .ping => return .{ .ping = payload },
            .pong => return .pong,
            else => return error.InvalidOpcode,
        }
    }

    /// Returns the bytes that have not been consumed yet.
    fn rest(self: *FrameIterator) []const u8 {
        if (self.idx >= self.bytes.len) return "";
        return self.bytes[self.idx..];
    }
};

test "FrameIterator: frames without payload" {
    // a text frame followed by an empty ping, which is only its 2-byte header
    var iter: FrameIterator = .{ .bytes = &.{ 0x81, 0x03, 'f', 'o', 'o', 0x89, 0x00 } };
    try std.testing.expectEqualStrings("foo", (try iter.next()).?.text);
    const ping = (try iter.next()).?;
    try std.testing.expectEqual(@as(usize, 0), ping.ping.len);
    try std.testing.expect((try iter.next()) == null);
    try std.testing.expectEqual(@as(usize, 7), iter.idx);
}
570
test "FrameIterator: valid" {
    // one complete, unmasked, final text frame carrying "foo"
    const frame = [_]u8{
        @bitCast(FrameIterator.Byte1{ .opcode = .text, .reserved = 0, .final = true }),
        @bitCast(FrameIterator.Byte2{ .len = 3, .mask = false }),
        'f',
        'o',
        'o',
    };

    var it: FrameIterator = .{ .bytes = &frame };
    const first = try it.next();
    try std.testing.expectEqualStrings("foo", first.?.text);
    try std.testing.expect((try it.next()) == null);
}
595
test "FrameIterator: short read" {
    // the header announces 3 payload bytes, but only 1 has arrived so far
    const partial = [_]u8{
        @bitCast(FrameIterator.Byte1{ .opcode = .text, .reserved = 0, .final = true }),
        @bitCast(FrameIterator.Byte2{ .len = 3, .mask = false }),
        'f',
    };

    var it: FrameIterator = .{ .bytes = &partial };
    try std.testing.expect((try it.next()) == null);
}
616
test "FrameIterator: valid, multiple" {
    const head1: u8 = @bitCast(FrameIterator.Byte1{ .opcode = .text, .reserved = 0, .final = true });
    const head2: u8 = @bitCast(FrameIterator.Byte2{ .len = 3, .mask = false });
    // two back-to-back "foo" text frames
    const frames = [_]u8{ head1, head2, 'f', 'o', 'o', head1, head2, 'f', 'o', 'o' };

    var it: FrameIterator = .{ .bytes = &frames };
    for (0..2) |_| {
        const data = try it.next();
        try std.testing.expectEqualStrings("foo", data.?.text);
    }
    try std.testing.expect((try it.next()) == null);
}