zig library for atproto applications
at main 651 lines 21 kB view raw
const std = @import("std");
const stda = @import("stda");
const ourio = @import("ourio");

const atproto = @import("root.zig");

const Allocator = std.mem.Allocator;
const dns = stda.net.dns;
const json = std.json;
const posix = std.posix;

/// A stream of events from the jetstream
pub const Stream = struct {
    gpa: Allocator,
    fd: posix.fd_t = -1,
    bundle: std.crypto.Certificate.Bundle,
    /// index into `hosts` of the server chosen in `init`
    host: u2,
    /// URL query string built from `Options`; owned by the Stream
    query: []const u8,
    /// completion context that receives events and errors
    ctx: ourio.Context,

    /// the expected Sec-WebSocket-Accept value
    challenge: []const u8,

    state: union(enum) {
        dns,
        connect: *stda.net.ConnectTask,
        handshake: *stda.tls.Client.HandshakeTask,
        conn: *stda.tls.Client,
    },

    /// bytes received but not yet decoded into complete frames
    buffer: std.ArrayListUnmanaged(u8) = .empty,

    const hosts = [_][]const u8{
        "jetstream1.us-east.bsky.network",
        "jetstream2.us-east.bsky.network",
        "jetstream1.us-west.bsky.network",
        "jetstream2.us-west.bsky.network",
    };

    /// GUID appended to the Sec-WebSocket-Key when computing the accept hash
    const guid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /// Message tags used to route completions through `onCompletion`
    const Msg = enum {
        dns,
        connect,
        handshake,
        upgrade,
        recv,
        write,
        close,
    };

    pub const Options = struct {
        /// array of NSIDs to filter for. Can be wildcard prefixes
        collections: []const []const u8 = &.{},

        /// array of DIDs to filter for
        dids: []const []const u8 = &.{},

        /// maximum message size this client would like to receive
        max_msg_size: usize = 0,

        /// microsecond timestamp of when to begin replay of events. Now or in the future results in
        /// a live tail of events
        cursor_us: ?i64 = null,
    };

    /// Appends '&' when `list` already holds at least one query parameter.
    fn appendSeparator(list: *std.ArrayList(u8)) !void {
        if (list.items.len > 0) try list.append('&');
    }

    /// Picks a random jetstream host, builds the subscription query string from `opts`,
    /// initializes `self` and queues a DNS lookup for the host.
    ///
    /// Calls ctx.cb with a .userptr pointing to an Event. Failures are delivered to the same
    /// callback as an error in the .userptr result.
    ///
    /// If this function returns an error, `self` must not be used or passed to `deinit`.
    pub fn init(
        self: *Stream,
        gpa: Allocator,
        io: *ourio.Ring,
        bundle: std.crypto.Certificate.Bundle,
        opts: Options,
        resolver: *dns.Resolver,
        ctx: ourio.Context,
    ) !void {
        const idx = std.crypto.random.int(u2);

        var list: std.ArrayList(u8) = .init(gpa);
        defer list.deinit();

        for (opts.collections) |collection| {
            try appendSeparator(&list);
            try list.appendSlice("wantedCollections=");

            if (std.mem.indexOfScalar(u8, collection, '*')) |_| {
                // Percent-encode the wildcard
                try list.appendSlice(std.mem.trimRight(u8, collection, "*"));
                try list.appendSlice("%2A");
            } else {
                try list.appendSlice(collection);
            }
        }

        for (opts.dids) |did| {
            try appendSeparator(&list);
            try list.appendSlice("wantedDids=");
            try list.appendSlice(did);
        }

        if (opts.max_msg_size > 0) {
            try appendSeparator(&list);
            try list.writer().print("maxMessageSizeBytes={d}", .{opts.max_msg_size});
        }

        if (opts.cursor_us) |cursor| {
            try appendSeparator(&list);
            try list.writer().print("cursor={d}", .{cursor});
        }

        const query = try list.toOwnedSlice();
        errdefer gpa.free(query);

        // Fully initialize the stream before queueing any work that refers to it, so a
        // completion can never observe a partially constructed Stream.
        self.* = .{
            .gpa = gpa,
            .fd = -1,
            .bundle = bundle,
            .state = .dns,
            .challenge = "",
            .host = idx,
            .query = query,
            .ctx = ctx,
        };

        const question: dns.Question = .{
            .host = hosts[idx],
            .type = .A,
        };

        try resolver.resolveQuery(io, question, .{
            .ptr = self,
            .msg = @intFromEnum(Msg.dns),
            .cb = Stream.onCompletion,
        });
    }

    /// Frees the query string, the challenge, the receive buffer and, once the TLS handshake
    /// has completed, the TLS client.
    pub fn deinit(self: *Stream) void {
        self.gpa.free(self.query);
        self.gpa.free(self.challenge);
        self.buffer.deinit(self.gpa);
        switch (self.state) {
            .dns => {},
            .connect => {},
            .handshake => {},
            .conn => |conn| {
                conn.deinit(self.gpa);
                self.gpa.destroy(conn);
            },
        }
    }

    /// Delivers `err` to the user's callback as the result of a task.
    fn reportError(self: *Stream, io: *ourio.Ring, err: anyerror) !void {
        const task: ourio.Task = .{
            .callback = self.ctx.cb,
            .userdata = self.ctx.ptr,
            .msg = self.ctx.msg,
            .result = .{ .userptr = err },
        };
        try self.ctx.cb(io, task);
    }

    /// Completion handler for every stage of the connection: DNS lookup, TCP connect, TLS
    /// handshake, websocket upgrade and the subsequent reads, writes and close. The stage is
    /// selected by the task's `Msg` tag.
    pub fn onCompletion(io: *ourio.Ring, task: ourio.Task) anyerror!void {
        const self = task.userdataCast(Stream);
        const result = task.result.?;

        switch (task.msgToEnum(Msg)) {
            .dns => {
                const bytes = result.userbytes catch |err| return self.reportError(io, err);
                const resp: dns.Response = .{ .bytes = bytes };
                var iter = try resp.answerIterator();
                while (iter.next()) |answer| {
                    switch (answer) {
                        .A => |ip4| {
                            const addr = std.net.Ip4Address.init(ip4, 443);
                            const conn_task = try stda.net.tcpConnectToAddr(io, .{ .in = addr }, .{
                                .ptr = self,
                                .msg = @intFromEnum(Msg.connect),
                                .cb = Stream.onCompletion,
                            });
                            self.state = .{ .connect = conn_task };
                            // Connect to the first address only. Starting one connection per
                            // A record would leave several tasks racing to overwrite our state.
                            return;
                        },

                        .CNAME => {},

                        else => return error.Unexpected,
                    }
                }

                // The response carried no usable address; tell the user instead of stalling.
                return self.reportError(io, error.UnknownHostName);
            },

            .connect => {
                self.fd = result.userfd catch |err| return self.reportError(io, err);
                const hs_task = try stda.tls.Client.init(
                    io,
                    self.fd,
                    .{ .host = hosts[self.host], .root_ca = self.bundle },
                    .{
                        .ptr = self,
                        .msg = @intFromEnum(Msg.handshake),
                        .cb = Stream.onCompletion,
                    },
                );

                self.state = .{ .handshake = hs_task };
            },

            .handshake => {
                const ptr = result.userptr catch |err| return self.reportError(io, err);
                const conn: *stda.tls.Client = @ptrCast(@alignCast(ptr));

                // Route the TLS client's completions back through this handler. The first
                // read is the HTTP upgrade response.
                conn.userdata = self;
                conn.callback = Stream.onCompletion;
                conn.recv_msg = @intFromEnum(Msg.upgrade);
                conn.write_msg = @intFromEnum(Msg.write);
                conn.close_msg = @intFromEnum(Msg.close);

                self.state = .{ .conn = conn };

                var key: [16]u8 = undefined;
                std.crypto.random.bytes(&key);
                var base64_buf: [24]u8 = undefined;
                const base64_key = std.base64.standard.Encoder.encode(&base64_buf, &key);

                // Precompute the Accept response for the challenge
                {
                    var accept_buf: [24 + guid.len]u8 = undefined;
                    @memcpy(accept_buf[0..24], base64_key);
                    @memcpy(accept_buf[24..], guid);
                    var hash_buf: [std.crypto.hash.Sha1.digest_length]u8 = undefined;
                    std.crypto.hash.Sha1.hash(&accept_buf, &hash_buf, .{});
                    // Allocate before releasing the old value so a failed allocation does not
                    // leave `challenge` dangling.
                    const base64_hash = try self.gpa.alloc(u8, 28);
                    self.gpa.free(self.challenge);
                    self.challenge = std.base64.standard.Encoder.encode(base64_hash, &hash_buf);
                }

                // Write an upgrade request
                var writer = conn.cleartext_buf.writer(self.gpa);
                try writer.print("GET /subscribe?{s} HTTP/1.1\r\n", .{self.query});
                try writer.print("Host: {s}\r\n", .{hosts[self.host]});
                try writer.writeAll("Upgrade: websocket\r\n");
                try writer.writeAll("Connection: Upgrade\r\n");
                try writer.print("Sec-WebSocket-Key: {s}\r\n", .{base64_key});
                try writer.writeAll("Sec-WebSocket-Version: 13\r\n");
                try writer.writeAll("\r\n");

                try conn.flush(self.gpa, io);

                // Initiate receiving messages
                try conn.recv(io);
            },

            .upgrade => {
                // Assume we'll always get the header in the first read
                const n = result.recv catch |err| return self.reportError(io, err);
                const buf = self.state.conn.read_buf[0..n];

                var iter = std.mem.splitSequence(u8, buf, "\r\n");
                if (!std.mem.startsWith(u8, iter.first(), "HTTP/1.1 101")) {
                    return error.InvalidUpgrade;
                }

                var valid_accept = false;
                while (iter.next()) |line| {
                    // An empty line terminates the HTTP headers
                    if (line.len == 0) break;
                    var line_iter = std.mem.splitScalar(u8, line, ':');
                    if (std.ascii.eqlIgnoreCase("Sec-WebSocket-Accept", line_iter.first())) {
                        const value = std.mem.trim(u8, line_iter.rest(), &std.ascii.whitespace);
                        valid_accept = std.mem.eql(u8, value, self.challenge);
                    }
                }

                if (!valid_accept) {
                    return error.InvalidSecWebSocketAccept;
                }

                // From here on, reads carry websocket frames
                self.state.conn.recv_msg = @intFromEnum(Msg.recv);

                // Anything after the headers is already frame data
                try self.decodeFrames(io, iter.rest());
            },

            .recv => {
                const n = result.recv catch |err| return self.reportError(io, err);
                const buf = self.state.conn.read_buf[0..n];
                try self.decodeFrames(io, buf);
            },

            .write => {
                _ = result.write catch |err| return self.reportError(io, err);
            },

            .close => {
                _ = result.close catch |err| return self.reportError(io, err);
            },
        }
    }

    /// Appends `bytes` to the receive buffer and dispatches every complete frame in it:
    /// text frames are parsed as JSON and handed to the user's callback as an `Event`, a close
    /// frame is reported as `error.ConnectionResetByPeer`, and a ping is answered with a pong.
    /// Bytes of a trailing incomplete frame stay buffered for the next call.
    ///
    /// The `Event` (including its parsed JSON) is only valid for the duration of the callback.
    fn decodeFrames(self: *Stream, io: *ourio.Ring, bytes: []const u8) !void {
        try self.buffer.appendSlice(self.gpa, bytes);

        var arena: std.heap.ArenaAllocator = .init(self.gpa);
        defer arena.deinit();

        var iter: FrameIterator = .{ .bytes = self.buffer.items };
        while (try iter.next()) |data| {
            // Release the parsed JSON of this frame before moving to the next one
            defer _ = arena.reset(.retain_capacity);
            switch (data) {
                .text => |s| {
                    var event: Event = .{ .raw = s, .value = try json.parseFromSliceLeaky(
                        json.Value,
                        arena.allocator(),
                        s,
                        .{ .allocate = .alloc_if_needed },
                    ) };

                    const task: ourio.Task = .{
                        .callback = self.ctx.cb,
                        .userdata = self.ctx.ptr,
                        .msg = self.ctx.msg,
                        .result = .{ .userptr = &event },
                    };
                    try self.ctx.cb(io, task);
                },

                .close => {
                    const task: ourio.Task = .{
                        .callback = self.ctx.cb,
                        .userdata = self.ctx.ptr,
                        .msg = self.ctx.msg,
                        .result = .{ .userptr = error.ConnectionResetByPeer },
                    };
                    try self.ctx.cb(io, task);
                },

                .ping => |payload| {
                    // Control frames carry at most 125 payload bytes. 126 and 127 in the 7-bit
                    // length field mean "extended length follows", so they must never be
                    // written as a literal length.
                    if (payload.len > FrameIterator.max_control_payload) {
                        return error.InvalidControlFrame;
                    }

                    const byte1: FrameIterator.Byte1 = .{
                        .final = true,
                        .reserved = 0,
                        .opcode = .pong,
                    };
                    const byte2: FrameIterator.Byte2 = .{
                        .len = @intCast(payload.len),
                        .mask = true,
                    };

                    // Client-to-server frames are masked: XOR each payload byte with the
                    // 4-byte masking key, cycling through the key.
                    var mask: [4]u8 = undefined;
                    std.crypto.random.bytes(&mask);
                    var masked: [FrameIterator.max_control_payload]u8 = undefined;
                    for (payload, 0..) |byte, i| {
                        masked[i] = byte ^ mask[i % mask.len];
                    }

                    const conn = self.state.conn;
                    try conn.write(self.gpa, &.{ @bitCast(byte1), @bitCast(byte2) });
                    try conn.write(self.gpa, &mask);
                    try conn.write(self.gpa, masked[0..payload.len]);
                    try conn.flush(self.gpa, io);
                },

                // We never send pings, so any pong is unsolicited and can be ignored
                .pong => {},

                .continuation, .binary => {
                    return error.UnsupportedOp;
                },
            }
        }

        // Drop everything that was decoded; keep the incomplete tail
        self.buffer.replaceRangeAssumeCapacity(0, iter.idx, "");
    }
};

/// A single jetstream event. All accessors read from the parsed JSON in `value` and assert
/// (via `.?` and direct union field access) that the expected keys exist with the expected
/// JSON types.
pub const Event = struct {
    /// the unparsed JSON text of the event
    raw: []const u8,
    /// the parsed JSON document
    value: json.Value,

    pub const KindEnum = enum {
        commit,
        identity,
        account,
    };

    pub const Kind = union(KindEnum) {
        commit: Commit,
        identity: Identity,
        account: Account,
    };

    /// The "did" field of the event. Asserts that it is a valid DID.
    pub fn did(self: Event) atproto.Did {
        const value = self.value.object.get("did").?;
        return atproto.Did.init(value.string) catch unreachable;
    }

    /// The "time_us" field of the event
    pub fn time(self: Event) i64 {
        const value = self.value.object.get("time_us").?;
        return value.integer;
    }

    /// The payload of the event, selected by its "kind" field. Asserts that the kind is one
    /// of `KindEnum` and that the object of the same name is present.
    pub fn kind(self: Event) Kind {
        const value = self.value.object.get("kind").?;
        const k = std.meta.stringToEnum(KindEnum, value.string).?;
        switch (k) {
            .commit => return .{ .commit = .{ .value = self.value.object.get("commit").? } },
            .identity => return .{ .identity = .{ .value = self.value.object.get("identity").? } },
            .account => return .{ .account = .{ .value = self.value.object.get("account").? } },
        }
    }

    pub const Commit = struct {
        value: json.Value,

        pub const Operation = enum {
            create,
            update,
            delete,
        };

        /// The "rev" field of the commit
        pub fn rev(self: Commit) []const u8 {
            return self.value.object.get("rev").?.string;
        }

        /// The "operation" field of the commit
        pub fn operation(self: Commit) Operation {
            const op = self.value.object.get("operation").?;
            return std.meta.stringToEnum(Operation, op.string).?;
        }

        /// The "collection" field of the commit
        pub fn collection(self: Commit) []const u8 {
            return self.value.object.get("collection").?.string;
        }

        /// The "rkey" field of the commit
        pub fn rkey(self: Commit) []const u8 {
            return self.value.object.get("rkey").?.string;
        }

        /// The "cid" field of the commit. Asserts that the field is present.
        pub fn cid(self: Commit) []const u8 {
            return self.value.object.get("cid").?.string;
        }

        /// The "record" field of the commit. Asserts that the field is present.
        pub fn record(self: Commit) json.Value {
            return self.value.object.get("record").?;
        }
    };

    pub const Identity = struct {
        value: json.Value,

        /// The "did" field of the identity. Asserts that it is a valid DID.
        pub fn did(self: Identity) atproto.Did {
            const value = self.value.object.get("did").?;
            return atproto.Did.init(value.string) catch unreachable;
        }

        /// The "seq" field of the identity
        pub fn seq(self: Identity) i64 {
            return self.value.object.get("seq").?.integer;
        }

        /// The "time" field of the identity
        pub fn time(self: Identity) []const u8 {
            return self.value.object.get("time").?.string;
        }

        /// The "handle" field of the identity, or null when absent
        pub fn handle(self: Identity) ?[]const u8 {
            const value = self.value.object.get("handle") orelse return null;
            return value.string;
        }
    };

    pub const Account = struct {
        value: json.Value,

        /// The "active" field of the account
        pub fn active(self: Account) bool {
            return self.value.object.get("active").?.bool;
        }

        /// The "did" field of the account. Asserts that it is a valid DID.
        pub fn did(self: Account) atproto.Did {
            const value = self.value.object.get("did").?;
            return atproto.Did.init(value.string) catch unreachable;
        }

        /// The "seq" field of the account
        pub fn seq(self: Account) i64 {
            return self.value.object.get("seq").?.integer;
        }

        /// The "time" field of the account
        pub fn time(self: Account) []const u8 {
            return self.value.object.get("time").?.string;
        }
    };
};

/// Iterates over the unmasked, unfragmented websocket frames contained in `bytes`.
const FrameIterator = struct {
    bytes: []const u8,
    /// offset of the first byte that has not been consumed as part of a complete frame
    idx: usize = 0,

    /// largest payload a control frame (close, ping, pong) may carry
    const max_control_payload = 125;

    const Data = union(enum) {
        continuation: []const u8,
        text: []const u8,
        binary: []const u8,
        close,
        ping: []const u8,
        pong,
    };

    /// First header byte. Packed fields are listed from the least significant bit.
    const Byte1 = packed struct {
        opcode: enum(u4) {
            continuation = 0x0,
            text = 0x1,
            binary = 0x2,
            close = 0x8,
            ping = 0x9,
            pong = 0xA,
            _,
        },
        reserved: u3,
        final: bool,
    };

    /// Second header byte. Packed fields are listed from the least significant bit.
    const Byte2 = packed struct {
        len: u7,
        mask: bool,
    };

    /// Returns the next complete frame, or null when the remaining bytes do not hold one.
    /// `idx` only advances when a frame is returned, so after a null everything from `idx`
    /// onwards is the incomplete tail (see `rest`).
    ///
    /// Errors: `FragmentationNotSupported` for a non-final frame, `MaskNotSupported` for a
    /// masked frame, `FrameTooLarge` when the length does not fit in `usize`, and
    /// `InvalidOpcode` for an unknown opcode.
    fn next(self: *FrameIterator) !?Data {
        const start = self.idx;

        // Need both header bytes. A frame with an empty payload is exactly two bytes long,
        // so two remaining bytes are enough.
        if (start + 2 > self.bytes.len) return null;

        const byte1: Byte1 = @bitCast(self.bytes[start]);
        const byte2: Byte2 = @bitCast(self.bytes[start + 1]);

        if (!byte1.final) return error.FragmentationNotSupported;
        if (byte2.mask) return error.MaskNotSupported;

        // 126 and 127 announce a 16-bit or 64-bit big-endian length after the header
        const ext_len: usize = switch (byte2.len) {
            126 => 2,
            127 => 8,
            else => 0,
        };
        const data_start = start + 2 + ext_len;
        if (data_start > self.bytes.len) return null;

        const len: usize = switch (byte2.len) {
            126 => std.mem.readInt(u16, self.bytes[start + 2 ..][0..2], .big),
            127 => std.math.cast(
                usize,
                std.mem.readInt(u64, self.bytes[start + 2 ..][0..8], .big),
            ) orelse return error.FrameTooLarge,
            else => byte2.len,
        };

        // Compare against the remaining space rather than computing `data_start + len`
        // first: the length comes from the peer and the sum could overflow.
        if (len > self.bytes.len - data_start) return null;

        const end = data_start + len;
        const payload = self.bytes[data_start..end];
        self.idx = end;

        switch (byte1.opcode) {
            .continuation => return .{ .continuation = payload },
            .text => return .{ .text = payload },
            .binary => return .{ .binary = payload },
            .close => return .close,
            .ping => return .{ .ping = payload },
            .pong => return .pong,
            else => return error.InvalidOpcode,
        }
    }

    /// The bytes that have not been consumed as part of a complete frame
    fn rest(self: *FrameIterator) []const u8 {
        if (self.idx >= self.bytes.len) return "";
        return self.bytes[self.idx..];
    }
};

test "FrameIterator: valid" {
    const byte1: FrameIterator.Byte1 = .{
        .opcode = .text,
        .reserved = 0,
        .final = true,
    };
    const byte2: FrameIterator.Byte2 = .{
        .len = 3,
        .mask = false,
    };

    var buf: [5]u8 = undefined;
    buf[0] = @bitCast(byte1);
    buf[1] = @bitCast(byte2);
    buf[2] = 'f';
    buf[3] = 'o';
    buf[4] = 'o';

    var iter: FrameIterator = .{ .bytes = &buf };
    const data = try iter.next();
    try std.testing.expectEqualStrings("foo", data.?.text);
    const last = try iter.next();
    try std.testing.expect(last == null);
}

test "FrameIterator: short read" {
    const byte1: FrameIterator.Byte1 = .{
        .opcode = .text,
        .reserved = 0,
        .final = true,
    };
    const byte2: FrameIterator.Byte2 = .{
        .len = 3,
        .mask = false,
    };

    var buf: [3]u8 = undefined;
    buf[0] = @bitCast(byte1);
    buf[1] = @bitCast(byte2);
    buf[2] = 'f';

    var iter: FrameIterator = .{ .bytes = &buf };
    const data = try iter.next();
    try std.testing.expect(data == null);
    // Nothing was consumed, so the whole input is still pending
    try std.testing.expectEqual(@as(usize, 0), iter.idx);
}

test "FrameIterator: valid, multiple" {
    const byte1: FrameIterator.Byte1 = .{
        .opcode = .text,
        .reserved = 0,
        .final = true,
    };
    const byte2: FrameIterator.Byte2 = .{
        .len = 3,
        .mask = false,
    };

    var buf: [10]u8 = undefined;
    buf[0] = @bitCast(byte1);
    buf[1] = @bitCast(byte2);
    buf[2] = 'f';
    buf[3] = 'o';
    buf[4] = 'o';
    buf[5] = @bitCast(byte1);
    buf[6] = @bitCast(byte2);
    buf[7] = 'f';
    buf[8] = 'o';
    buf[9] = 'o';

    var iter: FrameIterator = .{ .bytes = &buf };
    {
        const data = try iter.next();
        try std.testing.expectEqualStrings("foo", data.?.text);
    }
    {
        const data = try iter.next();
        try std.testing.expectEqualStrings("foo", data.?.text);
    }
    const last = try iter.next();
    try std.testing.expect(last == null);
}

test "FrameIterator: empty payload at end of buffer" {
    const byte1: FrameIterator.Byte1 = .{
        .opcode = .ping,
        .reserved = 0,
        .final = true,
    };
    const byte2: FrameIterator.Byte2 = .{
        .len = 0,
        .mask = false,
    };

    var buf: [2]u8 = undefined;
    buf[0] = @bitCast(byte1);
    buf[1] = @bitCast(byte2);

    var iter: FrameIterator = .{ .bytes = &buf };
    const data = try iter.next();
    try std.testing.expectEqual(@as(usize, 0), data.?.ping.len);
    try std.testing.expectEqual(@as(usize, 2), iter.idx);
    const last = try iter.next();
    try std.testing.expect(last == null);
}

test "FrameIterator: 16-bit extended length" {
    const byte1: FrameIterator.Byte1 = .{
        .opcode = .text,
        .reserved = 0,
        .final = true,
    };
    const byte2: FrameIterator.Byte2 = .{
        .len = 126,
        .mask = false,
    };

    var buf: [4 + 200]u8 = undefined;
    buf[0] = @bitCast(byte1);
    buf[1] = @bitCast(byte2);
    std.mem.writeInt(u16, buf[2..4], 200, .big);
    @memset(buf[4..], 'x');

    // The full frame decodes
    var iter: FrameIterator = .{ .bytes = &buf };
    const data = try iter.next();
    try std.testing.expectEqual(@as(usize, 200), data.?.text.len);

    // A truncated frame is left untouched
    var short: FrameIterator = .{ .bytes = buf[0..100] };
    try std.testing.expect((try short.next()) == null);
    try std.testing.expectEqual(@as(usize, 0), short.idx);
}