atproto relay implementation in zig
zlay.waow.tech
1//! admin endpoint handlers for relay management.
2//!
3//! all handlers require Bearer token auth against RELAY_ADMIN_PASSWORD.
4//! includes host blocking/unblocking, account bans, and backfill control.
5
6const std = @import("std");
7const h = @import("http.zig");
8const router = @import("router.zig");
9const websocket = @import("websocket");
10const broadcaster = @import("../broadcaster.zig");
11const event_log_mod = @import("../event_log.zig");
12const backfill_mod = @import("../backfill.zig");
13
14const log = std.log.scoped(.relay);
15
16const HttpContext = router.HttpContext;
17
/// check admin auth via headers, send error response if not authorized. returns true if authorized.
///
/// the expected credential is read from the RELAY_ADMIN_PASSWORD environment variable.
/// responses sent on failure:
///   - 403 when the variable is unset or empty (admin endpoints are not configured)
///   - 401 when `headers` is null, there is no `authorization` header, the scheme is not
///     `Bearer `, or the token does not match
pub fn checkAdmin(conn: *h.Conn, headers: ?*const websocket.Handshake.KeyValue) bool {
    // an empty password is treated the same as an unset one; otherwise a request carrying
    // "Authorization: Bearer " with an empty token would be accepted.
    const admin_pw = std.posix.getenv("RELAY_ADMIN_PASSWORD") orelse "";
    if (admin_pw.len == 0) {
        h.respondJson(conn, .forbidden, "{\"error\":\"admin endpoint not configured\"}");
        return false;
    }

    const kv = headers orelse {
        h.respondJson(conn, .unauthorized, "{\"error\":\"missing authorization header\"}");
        return false;
    };

    // handshake parser lowercases all header names
    const auth_value = kv.get("authorization") orelse {
        h.respondJson(conn, .unauthorized, "{\"error\":\"missing authorization header\"}");
        return false;
    };

    const bearer_prefix = "Bearer ";
    if (!std.mem.startsWith(u8, auth_value, bearer_prefix)) {
        h.respondJson(conn, .unauthorized, "{\"error\":\"invalid authorization scheme\"}");
        return false;
    }
    const token = auth_value[bearer_prefix.len..];

    // compare without bailing out at the first mismatching byte, so the time taken does
    // not reveal how long a matching prefix the caller guessed. a length mismatch is
    // folded into the same accumulator.
    var diff: u8 = @intFromBool(token.len != admin_pw.len);
    const common_len = @min(token.len, admin_pw.len);
    for (token[0..common_len], admin_pw[0..common_len]) |a, b| diff |= a ^ b;

    if (diff != 0) {
        h.respondJson(conn, .unauthorized, "{\"error\":\"invalid token\"}");
        return false;
    }
    return true;
}
48
/// ban (take down) an account by DID. body: {"did": "..."}.
///
/// steps: authorize, parse the body, resolve the DID to a uid, mark the user as taken
/// down, then try to persist and broadcast an #account event. a failure to build or
/// persist that event is logged (persist) or skipped (build) and does not fail the
/// request: the response is `{"success":true}` once the takedown itself succeeded.
///
/// NOTE(review): neither the frame from buildAccountFrame nor the buffer returned by
/// resequenceFrame is freed here; confirm whether persist/broadcast take ownership.
pub fn handleBan(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const BanRequest = struct { did: []const u8 };
    const request = std.json.parseFromSlice(BanRequest, ctx.persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"did\\\":\\\"...\\\"}\"}");
        return;
    };
    defer request.deinit();
    const did = request.value.did;

    // resolve DID → UID and take down
    const uid = ctx.persist.uidForDid(did) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to resolve DID\"}");
        return;
    };
    ctx.persist.takeDownUser(uid) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"takedown failed\"}");
        return;
    };

    // emit #account event so downstream consumers see the takedown.
    // any failure in here leaves the block early; the ban itself already happened.
    emit: {
        const frame = buildAccountFrame(ctx.persist.allocator, did) orelse break :emit;
        const seq = ctx.persist.persist(.account, uid, frame) catch |err| {
            log.warn("admin: failed to persist #account takedown event: {s}", .{@errorName(err)});
            break :emit;
        };
        ctx.bc.stats.relay_seq.store(seq, .release);
        // fall back to the un-resequenced frame if resequencing yields null
        const outgoing = broadcaster.resequenceFrame(ctx.persist.allocator, frame, seq) orelse frame;
        ctx.bc.broadcast(seq, outgoing);
        log.info("admin: emitted #account takedown event for {s} (seq={d})", .{ did, seq });
    }

    log.info("admin: banned {s} (uid={d})", .{ did, uid });
    h.respondJson(conn, .ok, "{\"success\":true}");
}
84
/// list all hosts known to the persistence layer as JSON, plus the slurper's worker count.
///
/// response shape:
///   {"hosts":[{"id":..,"hostname":"..","status":"..","last_seq":..,
///              "failed_attempts":..,"account_limit":<number|null>},...],
///    "active_workers":..}
///
/// hostname and status are JSON-escaped before being written. if the host query fails
/// or the response body cannot be built, a 500 is sent instead.
pub fn handleAdminListHosts(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const persist = ctx.persist;
    const hosts = persist.listAllHosts(persist.allocator) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"query failed\"}");
        return;
    };
    // the returned slice and each host's hostname/status strings are freed here
    defer {
        for (hosts) |host| {
            persist.allocator.free(host.hostname);
            persist.allocator.free(host.status);
        }
        persist.allocator.free(hosts);
    }

    var list: std.ArrayListUnmanaged(u8) = .{};
    defer list.deinit(persist.allocator);

    writeHostListJson(list.writer(persist.allocator), hosts, ctx.slurper.workerCount()) catch {
        // previously a write failure returned without sending any response at all
        h.respondJson(conn, .internal_server_error, "{\"error\":\"InternalError\",\"message\":\"failed to build response\"}");
        return;
    };
    h.respondJson(conn, .ok, list.items);
}

/// serialize the host list response body into `w`. any writer error is propagated.
fn writeHostListJson(w: anytype, hosts: anytype, active_workers: anytype) !void {
    try w.writeAll("{\"hosts\":[");

    for (hosts, 0..) |host, i| {
        if (i > 0) try w.writeByte(',');

        try std.fmt.format(w, "{{\"id\":{d},\"hostname\":\"", .{host.id});
        try writeJsonEscaped(w, host.hostname);
        try w.writeAll("\",\"status\":\"");
        try writeJsonEscaped(w, host.status);
        try std.fmt.format(w, "\",\"last_seq\":{d},\"failed_attempts\":{d},\"account_limit\":", .{
            host.last_seq,
            host.failed_attempts,
        });

        if (host.account_limit) |limit| {
            try std.fmt.format(w, "{d}", .{limit});
        } else {
            try w.writeAll("null");
        }
        try w.writeByte('}');
    }

    try std.fmt.format(w, "],\"active_workers\":{d}}}", .{active_workers});
}

/// write `s` to `w` with JSON string escaping applied (surrounding quotes not included).
/// quote, backslash and control characters below 0x20 are escaped; all other bytes pass through.
fn writeJsonEscaped(w: anytype, s: []const u8) !void {
    for (s) |c| {
        switch (c) {
            '"' => try w.writeAll("\\\""),
            '\\' => try w.writeAll("\\\\"),
            '\n' => try w.writeAll("\\n"),
            '\r' => try w.writeAll("\\r"),
            '\t' => try w.writeAll("\\t"),
            0...0x08, 0x0b, 0x0c, 0x0e...0x1f => try std.fmt.format(w, "\\u{x:0>4}", .{c}),
            else => try w.writeByte(c),
        }
    }
}
132
/// set a host's status to "blocked". body: {"hostname": "..."}.
///
/// the host is looked up with getOrCreateHost before its status is updated.
/// responds 400 on unparseable JSON, 500 if the lookup or the status update fails,
/// and `{"success":true}` otherwise.
pub fn handleAdminBlockHost(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, persist: *event_log_mod.DiskPersist) void {
    if (!checkAdmin(conn, headers)) return;

    const BlockRequest = struct { hostname: []const u8 };
    const request = std.json.parseFromSlice(BlockRequest, persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}");
        return;
    };
    defer request.deinit();
    const hostname = request.value.hostname;

    const host = persist.getOrCreateHost(hostname) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"host lookup failed\"}");
        return;
    };
    const host_id = host.id;

    persist.updateHostStatus(host_id, "blocked") catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"status update failed\"}");
        return;
    };

    log.info("admin: blocked host {s} (id={d})", .{ hostname, host_id });
    h.respondJson(conn, .ok, "{\"success\":true}");
}
155
/// set a host's status back to "active" and reset its failure counter. body: {"hostname": "..."}.
///
/// the host is looked up with getOrCreateHost before its status is updated.
/// responds 400 on unparseable JSON, 500 if the lookup or the status update fails,
/// and `{"success":true}` otherwise. resetting the failure counter is best-effort:
/// a failure there is logged as a warning and does not change the response.
pub fn handleAdminUnblockHost(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, persist: *event_log_mod.DiskPersist) void {
    if (!checkAdmin(conn, headers)) return;

    const parsed = std.json.parseFromSlice(struct { hostname: []const u8 }, persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"BadRequest\",\"message\":\"invalid JSON\"}");
        return;
    };
    defer parsed.deinit();
    const hostname = parsed.value.hostname;

    const host_info = persist.getOrCreateHost(hostname) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"host lookup failed\"}");
        return;
    };

    persist.updateHostStatus(host_info.id, "active") catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"DatabaseError\",\"message\":\"status update failed\"}");
        return;
    };

    // the host is already unblocked at this point, so keep going on failure, but leave a
    // trace instead of silently discarding the error.
    persist.resetHostFailures(host_info.id) catch |err| {
        log.warn("admin: failed to reset failure count for host {s} (id={d}): {s}", .{ hostname, host_info.id, @errorName(err) });
    };

    log.info("admin: unblocked host {s} (id={d})", .{ hostname, host_info.id });
    h.respondJson(conn, .ok, "{\"success\":true}");
}
179
/// set or clear the account_limit override for a host.
/// POST {"host": "...", "account_limit": 100000} — set override
/// POST {"host": "...", "account_limit": null} — clear override (revert to COUNT(*))
///
/// responds 400 on unparseable JSON, 404 when the hostname is unknown, 500 when the
/// lookup or the update fails, and `{"success":true}` otherwise. after the override is
/// stored, the slurper is told the effective limit for the host.
pub fn handleAdminChangeLimits(conn: *h.Conn, body: []const u8, headers: *const websocket.Handshake.KeyValue, ctx: *HttpContext) void {
    if (!checkAdmin(conn, headers)) return;

    const LimitsRequest = struct { host: []const u8, account_limit: ?u64 };
    const request = std.json.parseFromSlice(LimitsRequest, ctx.persist.allocator, body, .{ .ignore_unknown_fields = true }) catch {
        h.respondJson(conn, .bad_request, "{\"error\":\"invalid JSON, expected {\\\"host\\\":\\\"...\\\",\\\"account_limit\\\":...}\"}");
        return;
    };
    defer request.deinit();

    const hostname = request.value.host;
    const override = request.value.account_limit;

    // the lookup can fail (error) or succeed without a match (null); handle each in turn
    const lookup = ctx.persist.getHostIdForHostname(hostname) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"database error\"}");
        return;
    };
    const host_id = lookup orelse {
        h.respondJson(conn, .not_found, "{\"error\":\"host not found\"}");
        return;
    };

    ctx.persist.setHostAccountLimit(host_id, override) catch {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to update limit\"}");
        return;
    };

    // update running subscriber's rate limits immediately:
    // the override when present, otherwise the host's current account count
    const effective = override orelse ctx.persist.getHostAccountCount(host_id);
    ctx.slurper.updateHostLimits(host_id, effective);

    if (override) |limit| {
        log.info("admin: set account_limit for {s} (id={d}): {d}", .{ hostname, host_id, limit });
    } else {
        log.info("admin: cleared account_limit for {s} (id={d}), reverted to COUNT(*)", .{ hostname, host_id });
    }
    h.respondJson(conn, .ok, "{\"success\":true}");
}
221
/// start a backfill run. the `source` query parameter selects the source and defaults
/// to "bsky.network" when absent.
///
/// responds 409 when the backfiller reports error.AlreadyRunning, 500 on any other start
/// error, and otherwise `{"status":"started","source":"..."}`. the source is JSON-escaped
/// in the response; if the body does not fit the fixed response buffer, the shorter
/// `{"status":"started"}` is sent instead.
pub fn handleAdminBackfillTrigger(conn: *h.Conn, query: []const u8, headers: *const websocket.Handshake.KeyValue, backfiller: *backfill_mod.Backfiller) void {
    if (!checkAdmin(conn, headers)) return;

    const source = h.queryParam(query, "source") orelse "bsky.network";

    backfiller.start(source) catch |err| {
        switch (err) {
            error.AlreadyRunning => {
                h.respondJson(conn, .conflict, "{\"error\":\"backfill already in progress\"}");
            },
            else => {
                h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to start backfill\"}");
            },
        }
        return;
    };

    var buf: [256]u8 = undefined;
    const resp_body: []const u8 = formatBackfillStarted(&buf, source) orelse "{\"status\":\"started\"}";
    h.respondJson(conn, .ok, resp_body);
}

/// build `{"status":"started","source":"<escaped source>"}` inside `buf`.
/// returns the written slice, or null when it does not fit.
fn formatBackfillStarted(buf: []u8, source: []const u8) ?[]const u8 {
    const prefix = "{\"status\":\"started\",\"source\":\"";
    const suffix = "\"}";

    if (prefix.len > buf.len) return null;
    @memcpy(buf[0..prefix.len], prefix);
    var len: usize = prefix.len;

    for (source) |c| {
        // scratch space for a single raw byte or a six-byte \u00XX escape
        var scratch: [6]u8 = undefined;
        const piece: []const u8 = switch (c) {
            '"' => "\\\"",
            '\\' => "\\\\",
            '\n' => "\\n",
            '\r' => "\\r",
            '\t' => "\\t",
            0...0x08, 0x0b, 0x0c, 0x0e...0x1f => std.fmt.bufPrint(&scratch, "\\u{x:0>4}", .{c}) catch return null,
            else => raw: {
                scratch[0] = c;
                break :raw scratch[0..1];
            },
        };
        if (piece.len > buf.len - len) return null;
        @memcpy(buf[len..][0..piece.len], piece);
        len += piece.len;
    }

    if (suffix.len > buf.len - len) return null;
    @memcpy(buf[len..][0..suffix.len], suffix);
    len += suffix.len;

    return buf[0..len];
}
246
/// report backfill status. the body returned by `backfiller.getStatus` is sent as the
/// response and then freed with the backfiller's allocator; a 500 is sent if the
/// status cannot be obtained.
pub fn handleAdminBackfillStatus(conn: *h.Conn, headers: *const websocket.Handshake.KeyValue, backfiller: *backfill_mod.Backfiller) void {
    const authorized = checkAdmin(conn, headers);
    if (!authorized) return;

    if (backfiller.getStatus(backfiller.allocator)) |status_json| {
        defer backfiller.allocator.free(status_json);
        h.respondJson(conn, .ok, status_json);
    } else |_| {
        h.respondJson(conn, .internal_server_error, "{\"error\":\"failed to query backfill status\"}");
    }
}
258
259// --- protocol helpers (used only by handleBan) ---
260
/// build a CBOR #account frame for a takedown event.
/// header: {op: 1, t: "#account"}, payload: {seq: 0, did: "...", time: "...", active: false, status: "takendown"}
///
/// the frame is the encoded header immediately followed by the encoded payload.
/// returns null if encoding or allocation fails. on success the returned slice was
/// allocated with `allocator`.
fn buildAccountFrame(allocator: std.mem.Allocator, did: []const u8) ?[]const u8 {
    const cbor = @import("zat").cbor;

    // timestamp text lives in this stack buffer; it is only needed until the payload is encoded
    var ts_storage: [24]u8 = undefined;
    const now_text = formatTimestamp(&ts_storage);

    const frame_header: cbor.Value = .{ .map = &.{
        .{ .key = "op", .value = .{ .unsigned = 1 } },
        .{ .key = "t", .value = .{ .text = "#account" } },
    } };

    const frame_payload: cbor.Value = .{ .map = &.{
        .{ .key = "seq", .value = .{ .unsigned = 0 } },
        .{ .key = "did", .value = .{ .text = did } },
        .{ .key = "time", .value = .{ .text = now_text } },
        .{ .key = "active", .value = .{ .boolean = false } },
        .{ .key = "status", .value = .{ .text = "takendown" } },
    } };

    // both intermediate encodings are released on every exit path
    const encoded_header = cbor.encodeAlloc(allocator, frame_header) catch return null;
    defer allocator.free(encoded_header);

    const encoded_payload = cbor.encodeAlloc(allocator, frame_payload) catch return null;
    defer allocator.free(encoded_payload);

    const out = allocator.alloc(u8, encoded_header.len + encoded_payload.len) catch return null;
    @memcpy(out[0..encoded_header.len], encoded_header);
    @memcpy(out[encoded_header.len..], encoded_payload);
    return out;
}
302
/// format current UTC time as ISO 8601 (YYYY-MM-DDTHH:MM:SSZ).
///
/// the text is written into `buf` and the returned slice points into it. if formatting
/// fails, the constant "1970-01-01T00:00:00Z" is returned instead. a negative system
/// timestamp is clamped to the epoch.
fn formatTimestamp(buf: *[24]u8) []const u8 {
    const now = std.time.timestamp();
    // EpochSeconds is unsigned; clamp rather than @intCast a negative value
    const secs: u64 = if (now < 0) 0 else @intCast(now);

    const es = std.time.epoch.EpochSeconds{ .secs = secs };
    const year_day = es.getEpochDay().calculateYearDay();
    const month_day = year_day.calculateMonthDay();
    const day_secs = es.getDaySeconds();

    return std.fmt.bufPrint(buf, "{d:0>4}-{d:0>2}-{d:0>2}T{d:0>2}:{d:0>2}:{d:0>2}Z", .{
        year_day.year,
        // Month is already 1-based (jan = 1); only day_index is 0-based
        @as(u32, month_day.month.numeric()),
        @as(u32, month_day.day_index) + 1,
        day_secs.getHoursIntoDay(),
        day_secs.getMinutesIntoHour(),
        day_secs.getSecondsIntoMinute(),
    }) catch "1970-01-01T00:00:00Z";
}