atproto relay implementation in zig zlay.waow.tech
at main 320 lines 13 kB view raw
//! admin endpoint handlers for relay management.
//!
//! all handlers require Bearer token auth against RELAY_ADMIN_PASSWORD.
//! includes host blocking/unblocking, account bans, and backfill control.

const std = @import("std");
const h = @import("http.zig");
const router = @import("router.zig");
const websocket = @import("websocket");
const broadcaster = @import("../broadcaster.zig");
const event_log_mod = @import("../event_log.zig");
const backfill_mod = @import("../backfill.zig");

const log = std.log.scoped(.relay);

const HttpContext = router.HttpContext;

/// check admin auth via headers, send error response if not authorized. returns true if authorized.
///
/// expects `Authorization: Bearer <RELAY_ADMIN_PASSWORD>`. the scheme name is
/// matched case-insensitively (RFC 7235 §2.1) and the token is compared in
/// constant time. an unset *or empty* RELAY_ADMIN_PASSWORD disables the admin
/// endpoints (403), so a blank env var can never be satisfied by an empty token.
pub fn checkAdmin(conn: *h.Conn, headers: ?*const websocket.Handshake.KeyValue) bool {
    const admin_pw = std.posix.getenv("RELAY_ADMIN_PASSWORD") orelse "";
    if (admin_pw.len == 0) {
        h.respondJson(conn, .forbidden, "{\"error\":\"admin endpoint not configured\"}");
        return false;
    }

    const kv = headers orelse {
        h.respondJson(conn, .unauthorized, "{\"error\":\"missing authorization header\"}");
        return false;
    };

    // handshake parser lowercases all header names
    const auth_value = kv.get("authorization") orelse {
        h.respondJson(conn, .unauthorized, "{\"error\":\"missing authorization header\"}");
        return false;
    };

    const bearer_prefix = "Bearer ";
    if (!std.ascii.startsWithIgnoreCase(auth_value, bearer_prefix)) {
        h.respondJson(conn, .unauthorized, "{\"error\":\"invalid authorization scheme\"}");
        return false;
    }
    const token = auth_value[bearer_prefix.len..];
    if (!timingSafeTokenEql(token, admin_pw)) {
        h.respondJson(conn, .unauthorized, "{\"error\":\"invalid token\"}");
        return false;
    }
    return true;
}

/// compare a caller-supplied token with the configured secret without an early
/// exit on the first mismatching byte. both inputs are hashed to fixed-length
/// SHA-256 digests first, so the comparison itself neither depends on how many
/// leading bytes the caller guessed nor on the two inputs having equal length.
fn timingSafeTokenEql(provided: []const u8, expected: []const u8) bool {
    const Sha256 = std.crypto.hash.sha2.Sha256;
    var provided_digest: [Sha256.digest_length]u8 = undefined;
    var expected_digest: [Sha256.digest_length]u8 = undefined;
    Sha256.hash(provided, &provided_digest, .{});
    Sha256.hash(expected, &expected_digest, .{});

    // accumulate every differing bit instead of returning at the first mismatch
    var diff: u8 = 0;
    for (provided_digest, expected_digest) |a, b| diff |= a ^ b;
    return diff == 0;
}

/// ban an account: POST {"did": "did:..."}.
///
/// resolves the DID to a UID, takes the user down, then tries to persist and
/// broadcast a `#account` (active=false, status=takendown) event. failure to
/// emit that event is logged but does not fail the request, since the
/// takedown itself already succeeded.
pub fn handleBan(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(struct { did: []const u8 }, ctx.persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"did\\\":\\\"...\\\"}\"}");
        return;
    };
    defer parsed.deinit();
    const did = parsed.value.did;

    // reject obviously malformed input before it reaches uidForDid
    const did_prefix = "did:";
    if (did.len <= did_prefix.len or !std.mem.startsWith(u8, did, did_prefix)) {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid DID, expected did:...\"}");
        return;
    }

    // resolve DID → UID and take down
    const uid = ctx.persist.uidForDid(did) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to resolve DID\"}");
        return;
    };
    ctx.persist.takeDownUser(uid) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"takedown failed\"}");
        return;
    };

    // emit #account event so downstream consumers see the takedown.
    // NOTE(review): neither `frame_bytes` nor the resequenced copy is freed here;
    // confirm whether persist()/broadcast() take ownership, otherwise each ban leaks them.
    if (buildAccountFrame(ctx.persist.allocator, did)) |frame_bytes| {
        if (ctx.persist.persist(.account, uid, frame_bytes)) |relay_seq| {
            ctx.bc.stats.relay_seq.store(relay_seq, .release);
            const broadcast_data = broadcaster.resequenceFrame(ctx.persist.allocator, frame_bytes, relay_seq) orelse frame_bytes;
            ctx.bc.broadcast(relay_seq, broadcast_data);
            log.info("admin: emitted #account takedown event for {s} (seq={d})", .{ did, relay_seq });
        } else |err| {
            log.warn("admin: failed to persist #account takedown event: {s}", .{@errorName(err)});
        }
    } else {
        log.warn("admin: failed to build #account takedown frame for {s}", .{did});
    }

    log.info("admin: banned {s} (uid={d})", .{ did, uid });
    h.respondJson(conn, .ok, "{\"success\":true}");
}

/// list every known host as `{"hosts":[...],"active_workers":N}`.
///
/// responds 500 if the host query fails or the response cannot be encoded.
pub fn handleAdminListHosts(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const persist = ctx.persist;
    const hosts = persist.listAllHosts(persist.allocator) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}");
        return;
    };
    defer {
        for (hosts) |host| {
            persist.allocator.free(host.hostname);
            persist.allocator.free(host.status);
        }
        persist.allocator.free(hosts);
    }

    var list: std.ArrayListUnmanaged(u8) = .{};
    defer list.deinit(persist.allocator);

    // an encode failure (e.g. OOM) must still produce a response rather than
    // leaving the client without an answer
    writeHostsJson(list.writer(persist.allocator), hosts, ctx.slurper.workerCount()) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"InternalError\",\"message\":\"failed to encode host list\"}");
        return;
    };
    h.respondJson(conn, .ok, list.items);
}

/// serialize the host list body into `w`. `hostname` and `status` are stored
/// strings, so they go through `writeJsonString` instead of raw `{s}`
/// interpolation, which would emit broken JSON for values containing `"` or `\`.
fn writeHostsJson(w: anytype, hosts: anytype, active_workers: anytype) !void {
    try w.writeAll("{\"hosts\":[");
    for (hosts, 0..) |host, i| {
        if (i > 0) try w.writeByte(',');
        try std.fmt.format(w, "{{\"id\":{d},\"hostname\":", .{host.id});
        try writeJsonString(w, host.hostname);
        try w.writeAll(",\"status\":");
        try writeJsonString(w, host.status);
        try std.fmt.format(w, ",\"last_seq\":{d},\"failed_attempts\":{d},\"account_limit\":", .{ host.last_seq, host.failed_attempts });
        if (host.account_limit) |limit| {
            try std.fmt.format(w, "{d}", .{limit});
        } else {
            try w.writeAll("null");
        }
        try w.writeByte('}');
    }
    try std.fmt.format(w, "],\"active_workers\":{d}}}", .{active_workers});
}

/// write `s` as a quoted JSON string, escaping `"`, `\` and all C0 control
/// characters as required by RFC 8259 §7. other bytes are copied verbatim.
fn writeJsonString(w: anytype, s: []const u8) !void {
    try w.writeByte('"');
    for (s) |c| {
        switch (c) {
            '"' => try w.writeAll("\\\""),
            '\\' => try w.writeAll("\\\\"),
            '\n' => try w.writeAll("\\n"),
            '\r' => try w.writeAll("\\r"),
            '\t' => try w.writeAll("\\t"),
            0x00...0x08, 0x0b, 0x0c, 0x0e...0x1f => try std.fmt.format(w, "\\u{x:0>4}", .{c}),
            else => try w.writeByte(c),
        }
    }
    try w.writeByte('"');
}

/// block a host: POST {"hostname": "..."}.
///
/// looks the host up via getOrCreateHost (so a hostname not yet known can be
/// blocked pre-emptively) and sets its status to "blocked".
pub fn handleAdminBlockHost(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, persist: *event_log_mod.DiskPersist) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}");
        return;
    };
    defer parsed.deinit();
    const hostname = parsed.value.hostname;

    // an empty hostname would otherwise be handed to getOrCreateHost as-is
    if (hostname.len == 0) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"hostname is required\"}");
        return;
    }

    const host_info = persist.getOrCreateHost(hostname) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"host lookup failed\"}");
        return;
    };

    persist.updateHostStatus(host_info.id, "blocked") catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"status update failed\"}");
        return;
    };

    log.info("admin: blocked host {s} (id={d})", .{ hostname, host_info.id });
    h.respondJson(conn, .ok, "{\"success\":true}");
}

/// unblock a host: POST {"hostname": "..."}.
///
/// sets the host's status back to "active" and then tries to reset its failure
/// counter. the reset is non-fatal: the status change has already succeeded,
/// so a reset error is logged and the request still returns success.
pub fn handleAdminUnblockHost(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, persist: *event_log_mod.DiskPersist) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}");
        return;
    };
    defer parsed.deinit();
    const hostname = parsed.value.hostname;

    if (hostname.len == 0) {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"hostname is required\"}");
        return;
    }

    // NOTE(review): getOrCreateHost means "unblocking" a hostname the relay has
    // never seen presumably creates it with status "active" — confirm intended.
    const host_info = persist.getOrCreateHost(hostname) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"host lookup failed\"}");
        return;
    };

    persist.updateHostStatus(host_info.id, "active") catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"status update failed\"}");
        return;
    };

    // best-effort, but never silent
    persist.resetHostFailures(host_info.id) catch |err| {
        log.warn("admin: failed to reset failure count for host {s} (id={d}): {s}", .{ hostname, host_info.id, @errorName(err) });
    };

    log.info("admin: unblocked host {s} (id={d})", .{ hostname, host_info.id });
    h.respondJson(conn, .ok, "{\"success\":true}");
}

/// set or clear the account_limit override for a host.
/// POST {"host": "...", "account_limit": 100000} — set override
/// POST {"host": "...", "account_limit": null} — clear override (revert to COUNT(*))
///
/// responds 400 on malformed JSON, 404 when the host is unknown and 500 on a
/// database error. on success the effective limit (the override, or the
/// host's account count when cleared) is passed to `ctx.slurper.updateHostLimits`.
pub fn handleAdminChangeLimits(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(
        struct { host: []const u8, account_limit: ?u64 },
        ctx.persist.allocator,
        body,
        .{ .ignore_unknown_fields = true },
    ) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"host\\\":\\\"...\\\",\\\"account_limit\\\":...}\"}");
        return;
    };
    defer parsed.deinit();
    const req = parsed.value;

    const host_id = ctx.persist.getHostIdForHostname(req.host) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"database error\"}");
        return;
    } orelse {
        h.respondJson(conn, .not_found, "{\"error\":\"host not found\"}");
        return;
    };

    ctx.persist.setHostAccountLimit(host_id, req.account_limit) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to update limit\"}");
        return;
    };

    // update running subscriber's rate limits immediately
    const effective = if (req.account_limit) |l| l else ctx.persist.getHostAccountCount(host_id);
    ctx.slurper.updateHostLimits(host_id, effective);

    if (req.account_limit) |limit| {
        log.info("admin: set account_limit for {s} (id={d}): {d}", .{ req.host, host_id, limit });
    } else {
        log.info("admin: cleared account_limit for {s} (id={d}), reverted to COUNT(*)", .{ req.host, host_id });
    }
    h.respondJson(conn, .ok, "{\"success\":true}");
}

/// start a backfill from `?source=<host>` (default "bsky.network").
///
/// responds 400 for an empty source or one that cannot be echoed verbatim in
/// JSON, 409 if a backfill is already running, 500 on any other start error.
pub fn handleAdminBackfillTrigger(conn: *h.Conn, query: []const u8, headers: *const websocket.Handshake.KeyValue, backfiller: *backfill_mod.Backfiller) void {
    if (!checkAdmin(conn, headers)) return;

    const source = h.queryParam(query, "source") orelse "bsky.network";

    // `source` comes from the request and is echoed back inside a JSON string
    // below; validate it before starting anything so the response can't be
    // malformed (or injected into) and no backfill runs with a bogus source.
    if (!isPlainSource(source)) {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid source\"}");
        return;
    }

    backfiller.start(source) catch |err| {
        switch (err) {
            error.AlreadyRunning => {
                h.respondJson(conn, .conflict, "{\"error\":\"backfill already in progress\"}");
            },
            else => {
                h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to start backfill\"}");
            },
        }
        return;
    };

    var buf: [256]u8 = undefined;
    const resp_body = std.fmt.bufPrint(&buf, "{{\"status\":\"started\",\"source\":\"{s}\"}}", .{source}) catch {
        // source too long for the stack buffer; omit it from the response
        h.respondJson(conn, .ok, "{\"status\":\"started\"}");
        return;
    };
    h.respondJson(conn, .ok, resp_body);
}

/// true when `s` is non-empty and every byte is printable ASCII other than
/// `"` and `\`, i.e. it can be placed inside a JSON string without escaping.
fn isPlainSource(s: []const u8) bool {
    if (s.len == 0) return false;
    for (s) |c| {
        if (c < 0x20 or c > 0x7e or c == '"' or c == '\\') return false;
    }
    return true;
}

/// return the backfiller's status document, produced by `getStatus`, as the
/// response body. the buffer is allocated with `backfiller.allocator` and freed here.
pub fn handleAdminBackfillStatus(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, backfiller: *backfill_mod.Backfiller) void {
    if (!checkAdmin(conn, headers)) return;

    const body = backfiller.getStatus(backfiller.allocator) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to query backfill status\"}");
        return;
    };
    defer backfiller.allocator.free(body);

    h.respondJson(conn, .ok, body);
}

// --- protocol helpers (used only by handleBan) ---

/// build a CBOR #account frame for a takedown event.
/// header: {op: 1, t: "#account"}, payload: {seq: 0, did: "...", time: "...", active: false, status: "takendown"}
///
/// the frame is the CBOR header immediately followed by the CBOR payload.
/// `seq` is written as 0 here; the caller is expected to resequence it.
/// the returned slice is allocated with `allocator` and owned by the caller;
/// returns null if encoding or allocation fails.
fn buildAccountFrame(allocator: std.mem.Allocator, did: []const u8) ?[]const u8 {
    const zat = @import("zat");
    const cbor = zat.cbor;

    const header: cbor.Value = .{ .map = &.{
        .{ .key = "op", .value = .{ .unsigned = 1 } },
        .{ .key = "t", .value = .{ .text = "#account" } },
    } };

    var time_buf: [24]u8 = undefined;
    const time_str = formatTimestamp(&time_buf);

    const payload: cbor.Value = .{ .map = &.{
        .{ .key = "seq", .value = .{ .unsigned = 0 } },
        .{ .key = "did", .value = .{ .text = did } },
        .{ .key = "time", .value = .{ .text = time_str } },
        .{ .key = "active", .value = .{ .boolean = false } },
        .{ .key = "status", .value = .{ .text = "takendown" } },
    } };

    // the intermediate encodings are always released, on success and on failure
    const header_bytes = cbor.encodeAlloc(allocator, header) catch return null;
    defer allocator.free(header_bytes);
    const payload_bytes = cbor.encodeAlloc(allocator, payload) catch return null;
    defer allocator.free(payload_bytes);

    return std.mem.concat(allocator, u8, &.{ header_bytes, payload_bytes }) catch null;
}

/// format current UTC time as ISO 8601 (YYYY-MM-DDTHH:MM:SSZ) into `buf` and
/// return the written prefix. falls back to the epoch string if formatting fails.
fn formatTimestamp(buf: *[24]u8) []const u8 {
    // std.time.timestamp() is signed: a clock set before 1970 would make a bare
    // @intCast panic (safe builds) or be UB (ReleaseFast), so clamp to the epoch.
    const ts: u64 = std.math.cast(u64, std.time.timestamp()) orelse 0;
    const es = std.time.epoch.EpochSeconds{ .secs = ts };
    const day = es.getEpochDay();
    const yd = day.calculateYearDay();
    const md = yd.calculateMonthDay();
    const ds = es.getDaySeconds();

    // month and day_index are zero-based in std.time.epoch
    return std.fmt.bufPrint(buf, "{d:0>4}-{d:0>2}-{d:0>2}T{d:0>2}:{d:0>2}:{d:0>2}Z", .{
        yd.year,
        @as(u32, @intFromEnum(md.month)) + 1,
        @as(u32, md.day_index) + 1,
        ds.getHoursIntoDay(),
        ds.getMinutesIntoHour(),
        ds.getSecondsIntoMinute(),
    }) catch "1970-01-01T00:00:00Z";
}