atproto relay implementation in zig zlay.waow.tech
at main 462 lines 18 kB view raw
//! XRPC endpoint handlers for the AT Protocol sync API.
//!
//! implements com.atproto.sync.* lexicon endpoints:
//!   listRepos, getRepoStatus, getLatestCommit, listReposByCollection,
//!   listHosts, getHostStatus, requestCrawl
//!
//! every handler answers the request exactly once via `h.respondJson`.
//! response bodies are assembled in fixed stack buffers; when a body does not
//! fit (or the row stream fails part-way) the handler answers 500 instead of
//! returning without a response.

const std = @import("std");
const h = @import("http.zig");
const event_log_mod = @import("../event_log.zig");
const collection_index_mod = @import("../collection_index.zig");
const slurper_mod = @import("../slurper.zig");

const log = std.log.scoped(.relay);

/// body sent whenever a database query (or row iteration) fails
const db_error_body: []const u8 = "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}";

/// body sent when a response does not fit in the handler's fixed buffer
const response_too_large_body: []const u8 = "{\"error\":\"InternalError\",\"message\":\"response too large\"}";

/// Validated `cursor` / `limit` query parameters of a paginated endpoint.
const PageParams = struct {
    cursor: i64,
    limit: i64,
};

/// Parse and validate the `cursor` (default "0", must be >= 0) and `limit`
/// (default `default_limit`, must be 1..1000) query parameters.
///
/// On invalid input a 400 response is sent and `null` is returned, so the
/// caller only has to `orelse return`.
fn parsePageParams(conn: *h.Conn, query: []const u8, default_limit: []const u8) ?PageParams {
    const cursor_str = h.queryParam(query, "cursor") orelse "0";
    const limit_str = h.queryParam(query, "limit") orelse default_limit;

    const cursor = std.fmt.parseInt(i64, cursor_str, 10) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid cursor\"}");
        return null;
    };
    if (cursor < 0) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"cursor must be >= 0\"}");
        return null;
    }

    const limit = std.fmt.parseInt(i64, limit_str, 10) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid limit\"}");
        return null;
    };
    if (limit < 1 or limit > 1000) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"limit must be 1..1000\"}");
        return null;
    }

    return .{ .cursor = cursor, .limit = limit };
}

/// Log and answer 500 for a failure that happened while building a body.
/// `error.NoSpaceLeft` (fixed buffer exhausted) gets its own message; any
/// other error can only come from row iteration and is reported as a
/// database error.
fn respondBuildFailure(conn: *h.Conn, endpoint: []const u8, err: anyerror) void {
    log.warn("{s}: failed to build response: {s}", .{ endpoint, @errorName(err) });
    h.respondJson(conn, .internal_server_error, switch (err) {
        error.NoSpaceLeft => response_too_large_body,
        else => db_error_body,
    });
}

/// Combined view of an account's local and upstream status.
const AccountState = struct {
    /// true only when both local AND upstream status are "active"
    active: bool,
    /// status to report: the local one unless it is "active", else upstream
    status: []const u8,

    /// Go relay: Account.IsActive() / Account.AccountStatus() — local takes priority.
    fn combine(local_status: []const u8, upstream_status: []const u8) AccountState {
        const local_ok = std.mem.eql(u8, local_status, "active");
        const upstream_ok = std.mem.eql(u8, upstream_status, "active");
        return .{
            .active = local_ok and upstream_ok,
            .status = if (local_ok) upstream_status else local_status,
        };
    }
};

/// Write `,"active":true` or `,"active":false,"status":"<status>"`.
fn writeAccountState(w: anytype, state: AccountState) !void {
    if (state.active) {
        try w.writeAll(",\"active\":true");
        return;
    }
    try w.writeAll(",\"active\":false,\"status\":\"");
    try w.writeAll(state.status);
    try w.writeAll("\"");
}

/// Map an internal host status to a lexicon hostStatus value:
/// "blocked" -> "banned", "exhausted" -> "offline"; anything else
/// (active, idle, ...) passes through unchanged.
fn lexiconHostStatus(raw_status: []const u8) []const u8 {
    if (std.mem.eql(u8, raw_status, "blocked")) return "banned";
    if (std.mem.eql(u8, raw_status, "exhausted")) return "offline";
    return raw_status;
}

/// com.atproto.sync.listRepos — page through accounts ordered by uid.
///
/// query params: `cursor` (last uid seen, default 0), `limit` (1..1000,
/// default 500). responds 400 on bad params, 500 on query/build failure.
pub fn handleListRepos(conn: *h.Conn, query: []const u8, persist: *event_log_mod.DiskPersist) void {
    const page = parsePageParams(conn, query, "500") orelse return;

    // query accounts with repo state, paginated by UID
    // includes both local status and upstream_status for combined active check
    var result = persist.db.query(
        \\SELECT a.uid, a.did, a.status, a.upstream_status, COALESCE(r.rev, ''), COALESCE(r.commit_data_cid, '')
        \\FROM account a LEFT JOIN account_repo r ON a.uid = r.uid
        \\WHERE a.uid > $1 ORDER BY a.uid ASC LIMIT $2
    , .{ page.cursor, page.limit }) catch {
        h.respondJson(conn, .internal_server_error, db_error_body);
        return;
    };
    defer result.deinit();

    var buf: [65536]u8 = undefined;
    const body = writeListReposBody(&buf, &result, page.limit) catch |err| {
        respondBuildFailure(conn, "listRepos", err);
        return;
    };
    h.respondJson(conn, .ok, body);
}

/// Serialize the listRepos rows in `result` into `buf` and return the
/// written slice. Fails with `error.NoSpaceLeft` when `buf` is too small,
/// or with whatever error row iteration reports.
fn writeListReposBody(buf: []u8, result: anytype, limit: i64) ![]const u8 {
    var fbs = std.io.fixedBufferStream(buf);
    const w = fbs.writer();

    var count: i64 = 0;
    var last_uid: i64 = 0;

    try w.writeAll("{\"repos\":[");

    // an iteration error must not look like "end of list": propagate it
    while (try result.nextUnsafe()) |row| {
        if (count > 0) try w.writeByte(',');

        const uid = row.get(i64, 0);
        const did = row.get([]const u8, 1);
        const state = AccountState.combine(row.get([]const u8, 2), row.get([]const u8, 3));
        const rev = row.get([]const u8, 4);
        const head = row.get([]const u8, 5);

        try w.writeAll("{\"did\":\"");
        try w.writeAll(did);
        try w.writeAll("\",\"head\":\"");
        try w.writeAll(head);
        try w.writeAll("\",\"rev\":\"");
        try w.writeAll(rev);
        try w.writeAll("\"");
        try writeAccountState(w, state);
        try w.writeByte('}');

        last_uid = uid;
        count += 1;
    }

    try w.writeByte(']');

    // include cursor if we got a full page (limit >= 1, so count >= 1 here;
    // this must also hold for limit=1 or single-item paging never advances)
    if (count >= limit) {
        try std.fmt.format(w, ",\"cursor\":\"{d}\"", .{last_uid});
    }

    try w.writeByte('}');
    return fbs.getWritten();
}

/// com.atproto.sync.getRepoStatus — report active/status/rev for one DID.
///
/// responds 400 when `did` is missing or does not start with "did:",
/// 404 when the account is unknown, 500 on query/build failure.
pub fn handleGetRepoStatus(conn: *h.Conn, query: []const u8, persist: *event_log_mod.DiskPersist) void {
    var did_buf: [256]u8 = undefined;
    const did = h.queryParamDecoded(query, "did", &did_buf) orelse {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"did parameter required\"}");
        return;
    };

    // basic DID syntax check
    if (!std.mem.startsWith(u8, did, "did:")) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid DID\"}");
        return;
    }

    // look up account (includes both local and upstream status)
    var row = (persist.db.rowUnsafe(
        "SELECT a.uid, a.status, a.upstream_status, COALESCE(r.rev, '') FROM account a LEFT JOIN account_repo r ON a.uid = r.uid WHERE a.did = $1",
        .{did},
    ) catch {
        h.respondJson(conn, .internal_server_error, db_error_body);
        return;
    }) orelse {
        h.respondJson(conn, .not_found, "{\"error\":\"RepoNotFound\",\"message\":\"account not found\"}");
        return;
    };
    defer row.deinit() catch {};

    const state = AccountState.combine(row.get([]const u8, 1), row.get([]const u8, 2));
    const rev = row.get([]const u8, 3);

    var buf: [4096]u8 = undefined;
    const body = writeRepoStatusBody(&buf, did, state, rev) catch |err| {
        respondBuildFailure(conn, "getRepoStatus", err);
        return;
    };
    h.respondJson(conn, .ok, body);
}

/// Serialize a getRepoStatus response into `buf`; `rev` is omitted when empty.
fn writeRepoStatusBody(buf: []u8, did: []const u8, state: AccountState, rev: []const u8) ![]const u8 {
    var fbs = std.io.fixedBufferStream(buf);
    const w = fbs.writer();

    try w.writeAll("{\"did\":\"");
    try w.writeAll(did);
    try w.writeAll("\"");
    try writeAccountState(w, state);

    if (rev.len > 0) {
        try w.writeAll(",\"rev\":\"");
        try w.writeAll(rev);
        try w.writeAll("\"");
    }

    try w.writeByte('}');
    return fbs.getWritten();
}

/// com.atproto.sync.getLatestCommit — return the stored cid and rev for a DID.
///
/// responds 400 on a missing/invalid `did`, 404 when the account is unknown
/// or has no repo data, 403 when the combined account status is not
/// "active", 500 on query/build failure.
pub fn handleGetLatestCommit(conn: *h.Conn, query: []const u8, persist: *event_log_mod.DiskPersist) void {
    var did_buf: [256]u8 = undefined;
    const did = h.queryParamDecoded(query, "did", &did_buf) orelse {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"did parameter required\"}");
        return;
    };

    if (!std.mem.startsWith(u8, did, "did:")) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid DID\"}");
        return;
    }

    // look up account + repo state (includes both local and upstream status)
    var row = (persist.db.rowUnsafe(
        "SELECT a.status, a.upstream_status, COALESCE(r.rev, ''), COALESCE(r.commit_data_cid, '') FROM account a LEFT JOIN account_repo r ON a.uid = r.uid WHERE a.did = $1",
        .{did},
    ) catch {
        h.respondJson(conn, .internal_server_error, db_error_body);
        return;
    }) orelse {
        h.respondJson(conn, .not_found, "{\"error\":\"RepoNotFound\",\"message\":\"account not found\"}");
        return;
    };
    defer row.deinit() catch {};

    // combined status: local takes priority (Go relay: AccountStatus())
    const status = AccountState.combine(row.get([]const u8, 0), row.get([]const u8, 1)).status;
    const rev = row.get([]const u8, 2);
    const cid = row.get([]const u8, 3);

    // check account status (match Go relay behavior)
    if (!std.mem.eql(u8, status, "active")) {
        const refusal: []const u8 = if (std.mem.eql(u8, status, "takendown") or std.mem.eql(u8, status, "suspended"))
            "{\"error\":\"RepoTakendown\",\"message\":\"account has been taken down\"}"
        else if (std.mem.eql(u8, status, "deactivated"))
            "{\"error\":\"RepoDeactivated\",\"message\":\"account is deactivated\"}"
        else if (std.mem.eql(u8, status, "deleted"))
            "{\"error\":\"RepoDeleted\",\"message\":\"account is deleted\"}"
        else
            "{\"error\":\"RepoInactive\",\"message\":\"account is not active\"}";
        h.respondJson(conn, .forbidden, refusal);
        return;
    }

    if (rev.len == 0 or cid.len == 0) {
        h.respondJson(conn, .not_found, "{\"error\":\"RepoNotSynchronized\",\"message\":\"relay has no repo data for this account\"}");
        return;
    }

    var buf: [4096]u8 = undefined;
    const body = writeLatestCommitBody(&buf, cid, rev) catch |err| {
        respondBuildFailure(conn, "getLatestCommit", err);
        return;
    };
    h.respondJson(conn, .ok, body);
}

/// Serialize `{"cid":"...","rev":"..."}` into `buf`.
fn writeLatestCommitBody(buf: []u8, cid: []const u8, rev: []const u8) ![]const u8 {
    var fbs = std.io.fixedBufferStream(buf);
    const w = fbs.writer();

    try w.writeAll("{\"cid\":\"");
    try w.writeAll(cid);
    try w.writeAll("\",\"rev\":\"");
    try w.writeAll(rev);
    try w.writeAll("\"}");
    return fbs.getWritten();
}

/// listReposByCollection — list DIDs from the collection index for one NSID.
///
/// query params: `collection` (required, must contain a '.'), `limit`
/// (1..2000, default 500), `cursor` (optional, a DID). responds 400 on bad
/// params, 500 when the index scan fails or the body does not fit.
pub fn handleListReposByCollection(conn: *h.Conn, query: []const u8, ci: *collection_index_mod.CollectionIndex) void {
    const collection = h.queryParam(query, "collection") orelse {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"collection parameter required\"}");
        return;
    };

    if (collection.len == 0 or !std.mem.containsAtLeast(u8, collection, 1, ".")) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid collection NSID\"}");
        return;
    }

    const limit_str = h.queryParam(query, "limit") orelse "500";
    const limit = std.fmt.parseInt(usize, limit_str, 10) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid limit\"}");
        return;
    };
    if (limit < 1 or limit > 2000) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"limit must be 1..2000\"}");
        return;
    }

    var cursor_buf: [256]u8 = undefined;
    const cursor_did = h.queryParamDecoded(query, "cursor", &cursor_buf);

    // scan collection index
    var did_buf: [65536]u8 = undefined;
    const ci_result = ci.listReposByCollection(collection, limit, cursor_did, &did_buf) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"InternalError\",\"message\":\"index scan failed\"}");
        return;
    };

    // build JSON response
    var buf: [65536]u8 = undefined;
    const body = writeReposByCollectionBody(&buf, ci_result, limit) catch |err| {
        respondBuildFailure(conn, "listReposByCollection", err);
        return;
    };
    h.respondJson(conn, .ok, body);
}

/// Serialize an index scan result into `buf`. The cursor (the scan's
/// `last_did`) is only emitted when a full page of `limit` entries came back.
fn writeReposByCollectionBody(buf: []u8, ci_result: anytype, limit: usize) ![]const u8 {
    var fbs = std.io.fixedBufferStream(buf);
    const w = fbs.writer();

    try w.writeAll("{\"repos\":[");
    for (0..ci_result.count) |i| {
        if (i > 0) try w.writeByte(',');
        try w.writeAll("{\"did\":\"");
        try w.writeAll(ci_result.getDid(i));
        try w.writeAll("\"}");
    }
    try w.writeByte(']');

    if (ci_result.last_did) |last| {
        if (ci_result.count >= limit) {
            try w.writeAll(",\"cursor\":\"");
            try w.writeAll(last);
            try w.writeAll("\"");
        }
    }

    try w.writeByte('}');
    return fbs.getWritten();
}

/// com.atproto.sync.listHosts — page through hosts with `last_seq > 0`,
/// ordered by id.
///
/// query params: `cursor` (last host id seen, default 0), `limit` (1..1000,
/// default 200). responds 400 on bad params, 500 on query/build failure.
pub fn handleListHosts(conn: *h.Conn, query: []const u8, persist: *event_log_mod.DiskPersist) void {
    const page = parsePageParams(conn, query, "200") orelse return;

    var result = persist.db.query(
        "SELECT id, hostname, status, last_seq FROM host WHERE id > $1 AND last_seq > 0 ORDER BY id ASC LIMIT $2",
        .{ page.cursor, page.limit },
    ) catch {
        h.respondJson(conn, .internal_server_error, db_error_body);
        return;
    };
    defer result.deinit();

    var buf: [65536]u8 = undefined;
    const body = writeListHostsBody(&buf, &result, page.limit) catch |err| {
        respondBuildFailure(conn, "listHosts", err);
        return;
    };
    h.respondJson(conn, .ok, body);
}

/// Serialize the listHosts rows in `result` into `buf` and return the
/// written slice. Fails with `error.NoSpaceLeft` when `buf` is too small,
/// or with whatever error row iteration reports.
fn writeListHostsBody(buf: []u8, result: anytype, limit: i64) ![]const u8 {
    var fbs = std.io.fixedBufferStream(buf);
    const w = fbs.writer();

    var count: i64 = 0;
    var last_id: i64 = 0;

    try w.writeAll("{\"hosts\":[");

    // an iteration error must not look like "end of list": propagate it
    while (try result.nextUnsafe()) |row| {
        if (count > 0) try w.writeByte(',');

        const id = row.get(i64, 0);
        const hostname = row.get([]const u8, 1);
        // same internal -> lexicon mapping getHostStatus applies
        const status = lexiconHostStatus(row.get([]const u8, 2));
        const seq = row.get(i64, 3);

        try w.writeAll("{\"hostname\":\"");
        try w.writeAll(hostname);
        try w.writeAll("\"");
        try std.fmt.format(w, ",\"seq\":{d}", .{seq});
        try w.writeAll(",\"status\":\"");
        try w.writeAll(status);
        try w.writeAll("\"}");

        last_id = id;
        count += 1;
    }

    try w.writeByte(']');

    // include cursor if we got a full page (must also hold for limit=1)
    if (count >= limit) {
        try std.fmt.format(w, ",\"cursor\":\"{d}\"", .{last_id});
    }

    try w.writeByte('}');
    return fbs.getWritten();
}

/// com.atproto.sync.getHostStatus — report seq, account count and status for
/// one hostname.
///
/// responds 400 when `hostname` is missing, 404 when the host is unknown,
/// 500 on query/build failure. the account count is best-effort: if that
/// query fails, the failure is logged and 0 is reported.
pub fn handleGetHostStatus(conn: *h.Conn, query: []const u8, persist: *event_log_mod.DiskPersist) void {
    var hostname_buf: [256]u8 = undefined;
    const hostname = h.queryParamDecoded(query, "hostname", &hostname_buf) orelse {
        h.respondJson(conn, .bad_request, "{\"error\":\"InvalidRequest\",\"message\":\"hostname parameter required\"}");
        return;
    };

    // look up host
    var row = (persist.db.rowUnsafe(
        "SELECT id, hostname, status, last_seq FROM host WHERE hostname = $1",
        .{hostname},
    ) catch {
        h.respondJson(conn, .internal_server_error, db_error_body);
        return;
    }) orelse {
        h.respondJson(conn, .not_found, "{\"error\":\"HostNotFound\",\"message\":\"host not found\"}");
        return;
    };
    defer row.deinit() catch {};

    const host_id = row.get(i64, 0);
    const host_name = row.get([]const u8, 1);
    // map internal status to lexicon hostStatus values
    const status = lexiconHostStatus(row.get([]const u8, 2));
    const seq = row.get(i64, 3);

    // count accounts on this host
    const count_row = persist.db.rowUnsafe(
        "SELECT COUNT(*) FROM account WHERE host_id = $1",
        .{host_id},
    ) catch |err| blk: {
        log.warn("getHostStatus: account count query failed: {s}", .{@errorName(err)});
        break :blk null;
    };
    const account_count: i64 = if (count_row) |cnt_row| blk: {
        var r = cnt_row;
        defer r.deinit() catch {};
        break :blk r.get(i64, 0);
    } else 0;

    var buf: [4096]u8 = undefined;
    const body = writeHostStatusBody(&buf, host_name, seq, account_count, status) catch |err| {
        respondBuildFailure(conn, "getHostStatus", err);
        return;
    };
    h.respondJson(conn, .ok, body);
}

/// Serialize a getHostStatus response into `buf`.
fn writeHostStatusBody(buf: []u8, host_name: []const u8, seq: i64, account_count: i64, status: []const u8) ![]const u8 {
    var fbs = std.io.fixedBufferStream(buf);
    const w = fbs.writer();

    try w.writeAll("{\"hostname\":\"");
    try w.writeAll(host_name);
    try w.writeAll("\"");
    try std.fmt.format(w, ",\"seq\":{d},\"accountCount\":{d}", .{ seq, account_count });
    try w.writeAll(",\"status\":\"");
    try w.writeAll(status);
    try w.writeAll("\"}");
    return fbs.getWritten();
}

/// com.atproto.sync.requestCrawl — validate a hostname and enqueue it for
/// crawling.
///
/// `body` must be JSON of the form {"hostname":"..."} (unknown fields are
/// ignored). responds 400 when the JSON is invalid, the hostname fails
/// validation, or the domain is banned; 500 when the request cannot be
/// enqueued; 200 {"success":true} otherwise.
pub fn handleRequestCrawl(conn: *h.Conn, body: []const u8, slurper: *slurper_mod.Slurper) void {
    const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, slurper.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"InvalidRequest\",\"message\":\"invalid JSON, expected {\\\"hostname\\\":\\\"...\\\"}\"}");
        return;
    };
    defer parsed.deinit();

    // fast validation: hostname format (Go relay does this synchronously in handler)
    const hostname = slurper_mod.validateHostname(slurper.allocator, parsed.value.hostname) catch |err| {
        log.warn("requestCrawl rejected '{s}': {s}", .{ parsed.value.hostname, @errorName(err) });
        h.respondJson(conn, .bad_request, switch (err) {
            error.EmptyHostname => "{\"error\":\"InvalidRequest\",\"message\":\"empty hostname\"}",
            error.InvalidCharacter => "{\"error\":\"InvalidRequest\",\"message\":\"hostname contains invalid characters\"}",
            error.InvalidLabel => "{\"error\":\"InvalidRequest\",\"message\":\"hostname has invalid label\"}",
            error.TooFewLabels => "{\"error\":\"InvalidRequest\",\"message\":\"hostname must have at least two labels (e.g. pds.example.com)\"}",
            error.LooksLikeIpAddress => "{\"error\":\"InvalidRequest\",\"message\":\"IP addresses not allowed, use a hostname\"}",
            error.PortNotAllowed => "{\"error\":\"InvalidRequest\",\"message\":\"port numbers not allowed\"}",
            error.LocalhostNotAllowed => "{\"error\":\"InvalidRequest\",\"message\":\"localhost not allowed\"}",
            else => "{\"error\":\"InvalidRequest\",\"message\":\"invalid hostname\"}",
        });
        return;
    };
    // the validated hostname is freed here with the slurper's allocator
    defer slurper.allocator.free(hostname);

    // fast validation: domain ban check
    if (slurper.persist.isDomainBanned(hostname)) {
        log.warn("requestCrawl rejected '{s}': domain banned", .{hostname});
        h.respondJson(conn, .bad_request, "{\"error\":\"InvalidRequest\",\"message\":\"domain is banned\"}");
        return;
    }

    // enqueue for async processing (describeServer check happens in crawl processor)
    slurper.addCrawlRequest(hostname) catch |err| {
        log.err("requestCrawl: failed to store crawl request for '{s}': {s}", .{ hostname, @errorName(err) });
        h.respondJson(conn, .internal_server_error, "{\"error\":\"InternalError\",\"message\":\"failed to store crawl request\"}");
        return;
    };

    log.info("crawl requested: {s}", .{hostname});
    h.respondJson(conn, .ok, "{\"success\":true}");
}