atproto relay implementation in zig
zlay.waow.tech
1//! XRPC endpoint handlers for the AT Protocol sync API.
2//!
3//! implements com.atproto.sync.* lexicon endpoints:
4//! listRepos, getRepoStatus, getLatestCommit, listReposByCollection,
5//! listHosts, getHostStatus, requestCrawl
6
7const std = @import("std");
8const h = @import("http.zig");
9const event_log_mod = @import("../event_log.zig");
10const collection_index_mod = @import("../collection_index.zig");
11const slurper_mod = @import("../slurper.zig");
12
13const log = std.log.scoped(.relay);
14
/// Handles `com.atproto.sync.listRepos`: returns one page of accounts ordered
/// by UID, each with its repo head/rev and combined active/status fields.
///
/// Query parameters:
///   - `cursor`: UID to resume after (default "0", must be >= 0)
///   - `limit`:  page size, 1..1000 (default "500")
///
/// Responds 400 on malformed parameters and 500 when the query or a row fetch
/// fails. On success responds 200 with `{"repos":[...]}` plus a `cursor` (the
/// UID of the last entry written) when more rows may follow.
///
/// The body is assembled in a fixed 64 KiB stack buffer. If the requested page
/// does not fit, the page is cut at the last entry that fits and a cursor is
/// emitted so the client can continue from there; previously an overflow made
/// the handler return without sending any response.
pub fn handleListRepos(conn: *h.Conn, query: []const u8, persist: *event_log_mod.DiskPersist) void {
    const db_error_body = "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}";
    const too_large_body = "{\"error\":\"InternalError\",\"message\":\"response too large\"}";

    const cursor_str = h.queryParam(query, "cursor") orelse "0";
    const limit_str = h.queryParam(query, "limit") orelse "500";

    const cursor_val = std.fmt.parseInt(i64, cursor_str, 10) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid cursor\"}");
        return;
    };
    if (cursor_val < 0) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"cursor must be >= 0\"}");
        return;
    }

    const limit = std.fmt.parseInt(i64, limit_str, 10) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid limit\"}");
        return;
    };
    if (limit < 1 or limit > 1000) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"limit must be 1..1000\"}");
        return;
    }

    // query accounts with repo state, paginated by UID
    // includes both local status and upstream_status for combined active check
    var result = persist.db.query(
        \\SELECT a.uid, a.did, a.status, a.upstream_status, COALESCE(r.rev, ''), COALESCE(r.commit_data_cid, '')
        \\FROM account a LEFT JOIN account_repo r ON a.uid = r.uid
        \\WHERE a.uid > $1 ORDER BY a.uid ASC LIMIT $2
    , .{ cursor_val, limit }) catch {
        h.respondJson(conn, .internal_server_error, db_error_body);
        return;
    };
    defer result.deinit();

    var buf: [65536]u8 = undefined;
    // Entries are written through a stream that stops short of the end of
    // `buf`, so the closing `],"cursor":"<i64>"}` (at most 33 bytes) always fits.
    const tail_reserve = 64;
    var fbs = std.io.fixedBufferStream(buf[0 .. buf.len - tail_reserve]);
    const w = fbs.writer();

    w.writeAll("{\"repos\":[") catch {
        h.respondJson(conn, .internal_server_error, too_large_body);
        return;
    };

    var count: i64 = 0;
    var last_uid: i64 = 0;
    var truncated = false;

    while (true) {
        // A fetch error must not be mistaken for end-of-data: a short page
        // without a cursor would tell the client the listing is complete.
        const maybe_row = result.nextUnsafe() catch |err| {
            log.err("listRepos: row fetch failed: {s}", .{@errorName(err)});
            h.respondJson(conn, .internal_server_error, db_error_body);
            return;
        };
        const row = maybe_row orelse break;

        // Once the buffer is full, keep reading rows without serializing them.
        // The success path has always iterated the result to exhaustion; keep
        // that rather than assume deinit() tolerates unread rows.
        if (truncated) continue;

        const mark = fbs.pos;
        writeListReposEntry(
            w,
            count > 0,
            row.get([]const u8, 1),
            row.get([]const u8, 5),
            row.get([]const u8, 4),
            row.get([]const u8, 2),
            row.get([]const u8, 3),
        ) catch {
            // roll back the partial entry and end the page here
            fbs.pos = mark;
            truncated = true;
            continue;
        };

        last_uid = row.get(i64, 0);
        count += 1;
    }

    if (truncated and count == 0) {
        log.err("listRepos: a single entry does not fit the response buffer", .{});
        h.respondJson(conn, .internal_server_error, too_large_body);
        return;
    }

    // Cursor on a full page, or whenever the page was cut short by the buffer.
    // NOTE(review): the `count >= 2` guard is kept from the original (a full
    // one-row page carries no cursor) — presumably Go relay parity; confirm.
    const want_cursor = truncated or (count >= limit and count >= 2);

    const body_len = fbs.pos;
    const tail_result = if (want_cursor)
        std.fmt.bufPrint(buf[body_len..], "],\"cursor\":\"{d}\"}}", .{last_uid})
    else
        std.fmt.bufPrint(buf[body_len..], "]}}", .{});
    const tail = tail_result catch {
        h.respondJson(conn, .internal_server_error, too_large_body);
        return;
    };

    h.respondJson(conn, .ok, buf[0 .. body_len + tail.len]);
}

/// Writes one element of the listRepos `repos` array to `w`, preceded by a
/// comma when `needs_comma` is set. Propagates any writer error (for the fixed
/// buffer used by `handleListRepos` that means the entry did not fit).
fn writeListReposEntry(
    w: anytype,
    needs_comma: bool,
    did: []const u8,
    head: []const u8,
    rev: []const u8,
    local_status: []const u8,
    upstream_status: []const u8,
) !void {
    if (needs_comma) try w.writeByte(',');

    // Go relay: Account.IsActive() — both local AND upstream must be active
    const local_ok = std.mem.eql(u8, local_status, "active");
    const upstream_ok = std.mem.eql(u8, upstream_status, "active");
    // Go relay: Account.AccountStatus() — local takes priority
    const status = if (!local_ok) local_status else upstream_status;

    try w.print("{{\"did\":\"{s}\",\"head\":\"{s}\",\"rev\":\"{s}\"", .{ did, head, rev });
    if (local_ok and upstream_ok) {
        try w.writeAll(",\"active\":true");
    } else {
        try w.print(",\"active\":false,\"status\":\"{s}\"", .{status});
    }
    try w.writeByte('}');
}
110
/// Handles `com.atproto.sync.getRepoStatus` for the account named by the
/// `did` query parameter.
///
/// Responds 400 when `did` is missing or does not start with "did:", 500 when
/// the lookup query fails, and 404 (`RepoNotFound`) when no account matches.
/// On success responds 200 with `did`, `active`, a `status` when inactive, and
/// `rev` when the relay has one recorded.
pub fn handleGetRepoStatus(conn: *h.Conn, query: []const u8, persist: *event_log_mod.DiskPersist) void {
    var did_storage: [256]u8 = undefined;
    const did = h.queryParamDecoded(query, "did", &did_storage) orelse {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"did parameter required\"}");
        return;
    };

    // only the "did:" prefix is checked here
    const has_did_prefix = std.mem.startsWith(u8, did, "did:");
    if (!has_did_prefix) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid DID\"}");
        return;
    }

    // fetch local + upstream status and the repo rev in one round trip
    const lookup = persist.db.rowUnsafe(
        "SELECT a.uid, a.status, a.upstream_status, COALESCE(r.rev, '') FROM account a LEFT JOIN account_repo r ON a.uid = r.uid WHERE a.did = $1",
        .{did},
    ) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}");
        return;
    };
    var account = lookup orelse {
        h.respondJson(conn, .not_found, "{\"error\":\"RepoNotFound\",\"message\":\"account not found\"}");
        return;
    };
    defer account.deinit() catch {};

    const local_status = account.get([]const u8, 1);
    const upstream_status = account.get([]const u8, 2);
    const rev = account.get([]const u8, 3);

    // Go relay: Account.IsActive() requires both sides to be active;
    // Account.AccountStatus() reports the local status when it is not active.
    const local_active = std.mem.eql(u8, local_status, "active");
    const is_active = local_active and std.mem.eql(u8, upstream_status, "active");
    const shown_status = if (local_active) upstream_status else local_status;

    var out: [4096]u8 = undefined;
    var stream = std.io.fixedBufferStream(&out);
    const json = stream.writer();

    json.print("{{\"did\":\"{s}\"", .{did}) catch return;
    if (is_active) {
        json.writeAll(",\"active\":true") catch return;
    } else {
        json.print(",\"active\":false,\"status\":\"{s}\"", .{shown_status}) catch return;
    }
    // rev is '' (via COALESCE) when there is no account_repo row; omit it then
    if (rev.len != 0) {
        json.print(",\"rev\":\"{s}\"", .{rev}) catch return;
    }
    json.writeByte('}') catch return;

    h.respondJson(conn, .ok, stream.getWritten());
}
171
/// Handles `com.atproto.sync.getLatestCommit` for the account named by the
/// `did` query parameter.
///
/// Responds 400 when `did` is missing or lacks the "did:" prefix, 500 when the
/// lookup fails, 404 (`RepoNotFound`) for an unknown account, 403 with a
/// status-specific error when the account is not active, and 404
/// (`RepoNotSynchronized`) when no rev/commit CID is stored. On success
/// responds 200 with `{"cid":…,"rev":…}`.
pub fn handleGetLatestCommit(conn: *h.Conn, query: []const u8, persist: *event_log_mod.DiskPersist) void {
    var did_storage: [256]u8 = undefined;
    const did = h.queryParamDecoded(query, "did", &did_storage) orelse {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"did parameter required\"}");
        return;
    };

    const has_did_prefix = std.mem.startsWith(u8, did, "did:");
    if (!has_did_prefix) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid DID\"}");
        return;
    }

    // account + repo state in one query (local and upstream status included)
    const lookup = persist.db.rowUnsafe(
        "SELECT a.status, a.upstream_status, COALESCE(r.rev, ''), COALESCE(r.commit_data_cid, '') FROM account a LEFT JOIN account_repo r ON a.uid = r.uid WHERE a.did = $1",
        .{did},
    ) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}");
        return;
    };
    var account = lookup orelse {
        h.respondJson(conn, .not_found, "{\"error\":\"RepoNotFound\",\"message\":\"account not found\"}");
        return;
    };
    defer account.deinit() catch {};

    const local_status = account.get([]const u8, 0);
    const upstream_status = account.get([]const u8, 1);
    const rev = account.get([]const u8, 2);
    const cid = account.get([]const u8, 3);

    // Go relay AccountStatus(): a non-active local status wins over upstream
    const status = if (std.mem.eql(u8, local_status, "active")) upstream_status else local_status;

    // Any status other than "active" is refused with 403; the table picks the
    // error for known statuses, anything else falls back to RepoInactive.
    if (!std.mem.eql(u8, status, "active")) {
        const Refusal = struct { status: []const u8, body: []const u8 };
        const refusals = [_]Refusal{
            .{ .status = "takendown", .body = "{\"error\":\"RepoTakendown\",\"message\":\"account has been taken down\"}" },
            .{ .status = "suspended", .body = "{\"error\":\"RepoTakendown\",\"message\":\"account has been taken down\"}" },
            .{ .status = "deactivated", .body = "{\"error\":\"RepoDeactivated\",\"message\":\"account is deactivated\"}" },
            .{ .status = "deleted", .body = "{\"error\":\"RepoDeleted\",\"message\":\"account is deleted\"}" },
        };
        const refusal_body: []const u8 = for (refusals) |entry| {
            if (std.mem.eql(u8, status, entry.status)) break entry.body;
        } else "{\"error\":\"RepoInactive\",\"message\":\"account is not active\"}";

        h.respondJson(conn, .forbidden, refusal_body);
        return;
    }

    // both values are '' (via COALESCE) when no repo row exists
    const has_commit = rev.len != 0 and cid.len != 0;
    if (!has_commit) {
        h.respondJson(conn, .not_found, "{\"error\":\"RepoNotSynchronized\",\"message\":\"relay has no repo data for this account\"}");
        return;
    }

    var out: [4096]u8 = undefined;
    const body = std.fmt.bufPrint(&out, "{{\"cid\":\"{s}\",\"rev\":\"{s}\"}}", .{ cid, rev }) catch return;

    h.respondJson(conn, .ok, body);
}
237
/// Handles `listReposByCollection`: lists DIDs of repos that the collection
/// index reports for the `collection` query parameter.
///
/// Query parameters:
///   - `collection`: required; must be non-empty and contain at least one "."
///   - `limit`:      page size, 1..2000 (default "500")
///   - `cursor`:     optional DID passed through to the index scan
///
/// Responds 400 on invalid parameters and 500 when the index scan fails or the
/// response cannot be built. On success responds 200 with `{"repos":[...]}`
/// plus a `cursor` when the index returned a last DID and the page is full.
///
/// The output buffer used to be 64 KiB, the same size as the DID scratch
/// buffer, so a large page could overflow it once the per-entry JSON was added
/// and the handler then returned without sending any response. The buffer is
/// now sized from the limits, and any remaining failure produces a 500.
pub fn handleListReposByCollection(conn: *h.Conn, query: []const u8, ci: *collection_index_mod.CollectionIndex) void {
    const max_limit = 2000;
    const did_buf_len = 65536;
    // JSON wrapped around each DID: `{"did":"` + `"}` + separating comma
    const per_entry_overhead = "{\"did\":\"\"},".len;

    const collection = h.queryParam(query, "collection") orelse {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"collection parameter required\"}");
        return;
    };

    if (collection.len == 0 or !std.mem.containsAtLeast(u8, collection, 1, ".")) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid collection NSID\"}");
        return;
    }

    const limit_str = h.queryParam(query, "limit") orelse "500";
    const limit = std.fmt.parseInt(usize, limit_str, 10) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid limit\"}");
        return;
    };
    if (limit < 1 or limit > max_limit) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"limit must be 1..2000\"}");
        return;
    }

    var cursor_buf: [256]u8 = undefined;
    const cursor_did = h.queryParamDecoded(query, "cursor", &cursor_buf);

    // scan collection index
    var did_buf: [did_buf_len]u8 = undefined;
    const ci_result = ci.listReposByCollection(collection, limit, cursor_did, &did_buf) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"InternalError\",\"message\":\"index scan failed\"}");
        return;
    };

    // Worst case: every byte of DID scratch space is returned, each of up to
    // `max_limit` entries adds its JSON wrapper, plus slack for the envelope
    // and cursor.
    // NOTE(review): assumes the returned DIDs are backed by `did_buf` — confirm
    // against CollectionIndex; if not, the 500 fallback below still applies.
    var buf: [did_buf_len + max_limit * per_entry_overhead + 512]u8 = undefined;
    var fbs = std.io.fixedBufferStream(&buf);

    writeReposByCollectionJson(fbs.writer(), &ci_result, limit) catch |err| {
        log.err("listReposByCollection: building response failed: {s}", .{@errorName(err)});
        h.respondJson(conn, .internal_server_error, "{\"error\":\"InternalError\",\"message\":\"response too large\"}");
        return;
    };

    h.respondJson(conn, .ok, fbs.getWritten());
}

/// Serializes a collection-index page as `{"repos":[{"did":…},…]}` into `w`,
/// appending `"cursor"` when `page.last_did` is set and the page holds at
/// least `limit` entries. `page` is a pointer to the value returned by
/// `CollectionIndex.listReposByCollection`. Propagates any writer error.
fn writeReposByCollectionJson(w: anytype, page: anytype, limit: usize) !void {
    try w.writeAll("{\"repos\":[");
    for (0..page.count) |i| {
        if (i > 0) try w.writeByte(',');
        try w.writeAll("{\"did\":\"");
        try w.writeAll(page.getDid(i));
        try w.writeAll("\"}");
    }
    try w.writeByte(']');

    if (page.last_did) |last| {
        if (page.count >= limit) {
            try w.writeAll(",\"cursor\":\"");
            try w.writeAll(last);
            try w.writeAll("\"");
        }
    }

    try w.writeByte('}');
}
294
/// Handles `com.atproto.sync.listHosts`: returns one page of hosts with
/// `last_seq > 0`, ordered by host id.
///
/// Query parameters:
///   - `cursor`: host id to resume after (default "0", must be >= 0)
///   - `limit`:  page size, 1..1000 (default "200")
///
/// Responds 400 on malformed parameters and 500 when the query or a row fetch
/// fails. On success responds 200 with `{"hosts":[...]}` plus a `cursor` (the
/// id of the last host written) when more rows may follow.
///
/// Changes from the previous version:
///   - internal statuses are mapped to lexicon hostStatus values ("blocked" →
///     "banned", "exhausted" → "offline"), as `handleGetHostStatus` does;
///   - a page that does not fit the 64 KiB buffer is cut at the last entry that
///     fits and given a cursor, instead of returning without a response;
///   - a row fetch error produces a 500 instead of a silently shortened page.
pub fn handleListHosts(conn: *h.Conn, query: []const u8, persist: *event_log_mod.DiskPersist) void {
    const db_error_body = "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}";
    const too_large_body = "{\"error\":\"InternalError\",\"message\":\"response too large\"}";

    const cursor_str = h.queryParam(query, "cursor") orelse "0";
    const limit_str = h.queryParam(query, "limit") orelse "200";

    const cursor_val = std.fmt.parseInt(i64, cursor_str, 10) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid cursor\"}");
        return;
    };
    if (cursor_val < 0) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"cursor must be >= 0\"}");
        return;
    }

    const limit = std.fmt.parseInt(i64, limit_str, 10) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid limit\"}");
        return;
    };
    if (limit < 1 or limit > 1000) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"limit must be 1..1000\"}");
        return;
    }

    var result = persist.db.query(
        "SELECT id, hostname, status, last_seq FROM host WHERE id > $1 AND last_seq > 0 ORDER BY id ASC LIMIT $2",
        .{ cursor_val, limit },
    ) catch {
        h.respondJson(conn, .internal_server_error, db_error_body);
        return;
    };
    defer result.deinit();

    var buf: [65536]u8 = undefined;
    // Entries go through a stream that stops short of the end of `buf`, so the
    // closing `],"cursor":"<i64>"}` (at most 33 bytes) always fits.
    const tail_reserve = 64;
    var fbs = std.io.fixedBufferStream(buf[0 .. buf.len - tail_reserve]);
    const w = fbs.writer();

    w.writeAll("{\"hosts\":[") catch {
        h.respondJson(conn, .internal_server_error, too_large_body);
        return;
    };

    var count: i64 = 0;
    var last_id: i64 = 0;
    var truncated = false;

    while (true) {
        // A fetch error must not look like end-of-data: a short page without a
        // cursor would tell the client the listing is complete.
        const maybe_row = result.nextUnsafe() catch |err| {
            log.err("listHosts: row fetch failed: {s}", .{@errorName(err)});
            h.respondJson(conn, .internal_server_error, db_error_body);
            return;
        };
        const row = maybe_row orelse break;

        // After truncation keep reading rows without serializing them; the
        // success path has always iterated the result to exhaustion.
        if (truncated) continue;

        // map internal status to lexicon hostStatus values
        const raw_status = row.get([]const u8, 2);
        const status: []const u8 = if (std.mem.eql(u8, raw_status, "blocked"))
            "banned"
        else if (std.mem.eql(u8, raw_status, "exhausted"))
            "offline"
        else
            raw_status; // active, idle pass through

        const mark = fbs.pos;
        writeListHostsEntry(w, count > 0, row.get([]const u8, 1), row.get(i64, 3), status) catch {
            // roll back the partial entry and end the page here
            fbs.pos = mark;
            truncated = true;
            continue;
        };

        last_id = row.get(i64, 0);
        count += 1;
    }

    if (truncated and count == 0) {
        log.err("listHosts: a single entry does not fit the response buffer", .{});
        h.respondJson(conn, .internal_server_error, too_large_body);
        return;
    }

    // Cursor on a full page (the original `count > 1` guard is kept), or
    // whenever the page was cut short by the buffer.
    const want_cursor = truncated or (count >= limit and count > 1);

    const body_len = fbs.pos;
    const tail_result = if (want_cursor)
        std.fmt.bufPrint(buf[body_len..], "],\"cursor\":\"{d}\"}}", .{last_id})
    else
        std.fmt.bufPrint(buf[body_len..], "]}}", .{});
    const tail = tail_result catch {
        h.respondJson(conn, .internal_server_error, too_large_body);
        return;
    };

    h.respondJson(conn, .ok, buf[0 .. body_len + tail.len]);
}

/// Writes one element of the listHosts `hosts` array to `w`, preceded by a
/// comma when `needs_comma` is set. Propagates any writer error (for the fixed
/// buffer used by `handleListHosts` that means the entry did not fit).
fn writeListHostsEntry(
    w: anytype,
    needs_comma: bool,
    hostname: []const u8,
    seq: i64,
    status: []const u8,
) !void {
    if (needs_comma) try w.writeByte(',');
    try w.print("{{\"hostname\":\"{s}\",\"seq\":{d},\"status\":\"{s}\"}}", .{ hostname, seq, status });
}
364
/// Handles `com.atproto.sync.getHostStatus` for the host named by the
/// `hostname` query parameter.
///
/// Responds 400 when `hostname` is missing, 500 when the host lookup fails and
/// 404 (`HostNotFound`) when no host matches. On success responds 200 with
/// `hostname`, `seq`, `accountCount` and `status`. The account count falls
/// back to 0 if its query fails or returns no row.
pub fn handleGetHostStatus(conn: *h.Conn, query: []const u8, persist: *event_log_mod.DiskPersist) void {
    var hostname_storage: [256]u8 = undefined;
    const hostname = h.queryParamDecoded(query, "hostname", &hostname_storage) orelse {
        h.respondJson(conn, .bad_request, "{\"error\":\"InvalidRequest\",\"message\":\"hostname parameter required\"}");
        return;
    };

    // fetch the host record
    const lookup = persist.db.rowUnsafe(
        "SELECT id, hostname, status, last_seq FROM host WHERE hostname = $1",
        .{hostname},
    ) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}");
        return;
    };
    var host = lookup orelse {
        h.respondJson(conn, .not_found, "{\"error\":\"HostNotFound\",\"message\":\"host not found\"}");
        return;
    };
    defer host.deinit() catch {};

    const host_id = host.get(i64, 0);
    const stored_name = host.get([]const u8, 1);
    const stored_status = host.get([]const u8, 2);
    const last_seq = host.get(i64, 3);

    // translate internal statuses to lexicon hostStatus values; anything not
    // in the table (active, idle, ...) passes through unchanged
    const StatusAlias = struct { internal: []const u8, lexicon: []const u8 };
    const aliases = [_]StatusAlias{
        .{ .internal = "blocked", .lexicon = "banned" },
        .{ .internal = "exhausted", .lexicon = "offline" },
    };
    const status: []const u8 = for (aliases) |alias| {
        if (std.mem.eql(u8, stored_status, alias.internal)) break alias.lexicon;
    } else stored_status;

    // number of accounts attached to this host; best effort, 0 on failure
    const account_count: i64 = blk: {
        var count_row = (persist.db.rowUnsafe(
            "SELECT COUNT(*) FROM account WHERE host_id = $1",
            .{host_id},
        ) catch null) orelse break :blk 0;
        defer count_row.deinit() catch {};
        break :blk count_row.get(i64, 0);
    };

    var out: [4096]u8 = undefined;
    const body = std.fmt.bufPrint(
        &out,
        "{{\"hostname\":\"{s}\",\"seq\":{d},\"accountCount\":{d},\"status\":\"{s}\"}}",
        .{ stored_name, last_seq, account_count, status },
    ) catch return;

    h.respondJson(conn, .ok, body);
}
422
/// Handles `com.atproto.sync.requestCrawl`: validates the hostname in the JSON
/// request body and enqueues it on the slurper for asynchronous crawling.
///
/// `body` must be a JSON object with a string `hostname` field; unknown fields
/// are ignored. Responds 400 (`InvalidRequest`) when the JSON is malformed,
/// the hostname fails `validateHostname`, or the domain is banned; 500 when
/// the request cannot be enqueued; 200 `{"success":true}` otherwise.
///
/// Changes from the previous version: the enqueue error is now logged instead
/// of discarded, and the 500 body uses the same `error` name + `message` shape
/// as every other error response in this file.
pub fn handleRequestCrawl(conn: *h.Conn, body: []const u8, slurper: *slurper_mod.Slurper) void {
    const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, slurper.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"InvalidRequest\",\"message\":\"invalid JSON, expected {\\\"hostname\\\":\\\"...\\\"}\"}");
        return;
    };
    defer parsed.deinit();

    // fast validation: hostname format (Go relay does this synchronously in handler)
    const hostname = slurper_mod.validateHostname(slurper.allocator, parsed.value.hostname) catch |err| {
        log.warn("requestCrawl rejected '{s}': {s}", .{ parsed.value.hostname, @errorName(err) });
        h.respondJson(conn, .bad_request, switch (err) {
            error.EmptyHostname => "{\"error\":\"InvalidRequest\",\"message\":\"empty hostname\"}",
            error.InvalidCharacter => "{\"error\":\"InvalidRequest\",\"message\":\"hostname contains invalid characters\"}",
            error.InvalidLabel => "{\"error\":\"InvalidRequest\",\"message\":\"hostname has invalid label\"}",
            error.TooFewLabels => "{\"error\":\"InvalidRequest\",\"message\":\"hostname must have at least two labels (e.g. pds.example.com)\"}",
            error.LooksLikeIpAddress => "{\"error\":\"InvalidRequest\",\"message\":\"IP addresses not allowed, use a hostname\"}",
            error.PortNotAllowed => "{\"error\":\"InvalidRequest\",\"message\":\"port numbers not allowed\"}",
            error.LocalhostNotAllowed => "{\"error\":\"InvalidRequest\",\"message\":\"localhost not allowed\"}",
            else => "{\"error\":\"InvalidRequest\",\"message\":\"invalid hostname\"}",
        });
        return;
    };
    // the validated hostname is allocated with slurper.allocator; release it on every exit
    defer slurper.allocator.free(hostname);

    // fast validation: domain ban check
    if (slurper.persist.isDomainBanned(hostname)) {
        log.warn("requestCrawl rejected '{s}': domain banned", .{hostname});
        h.respondJson(conn, .bad_request, "{\"error\":\"InvalidRequest\",\"message\":\"domain is banned\"}");
        return;
    }

    // enqueue for async processing (describeServer check happens in crawl processor)
    slurper.addCrawlRequest(hostname) catch |err| {
        log.err("requestCrawl failed to enqueue '{s}': {s}", .{ hostname, @errorName(err) });
        h.respondJson(conn, .internal_server_error, "{\"error\":\"InternalError\",\"message\":\"failed to store crawl request\"}");
        return;
    };

    log.info("crawl requested: {s}", .{hostname});
    h.respondJson(conn, .ok, "{\"success\":true}");
}