An asynchronous IO runtime
1const std = @import("std");
2const tls = @import("tls");
3const io = @import("ourio");
4
5const Allocator = std.mem.Allocator;
6const CertBundle = tls.config.cert.Bundle;
7const assert = std.debug.assert;
8const mem = std.mem;
9const posix = std.posix;
10
/// A TLS client connection driven by an `io.Ring`.
///
/// A `Client` is not constructed directly: `Client.init` starts a handshake and, when it
/// succeeds, a heap-allocated `Client` is delivered to the caller's callback as a `userptr`
/// result. The receiver then fills in `userdata`, `callback` and the `*_msg` fields, and is
/// responsible for calling `deinit` and destroying the client when it is finished with it.
pub const Client = struct {
    /// Allocator the client was created with
    gpa: Allocator,
    /// Socket carrying the TLS session
    fd: posix.fd_t,
    /// Record-layer state taken over from the completed handshake
    tls: tls.nonblock.Connection,
    /// The recv currently in flight, if any
    recv_task: ?*io.Task = null,

    /// Ciphertext received from the peer; records are decrypted in place
    read_buf: [tls.max_ciphertext_record_len]u8 = undefined,
    /// Number of valid bytes at the start of `read_buf`
    read_end: usize = 0,

    /// Application data queued by `write` that has not been encrypted yet
    cleartext_buf: std.ArrayListUnmanaged(u8) = .empty,
    /// Encrypted records that are being written to the socket
    ciphertext_buf: std.ArrayListUnmanaged(u8) = .empty,
    /// How many bytes of `ciphertext_buf` have been written so far
    written: usize = 0,
    /// Set once the close of `fd` has been submitted so that it is submitted only once
    close_queued: bool = false,

    /// Opaque pointer handed back to `callback` with every result
    userdata: ?*anyopaque = null,
    /// Receives recv, write and close results for this client
    callback: *const fn (*io.Ring, io.Task) anyerror!void = io.noopCallback,
    /// Message ids attached to the results delivered to `callback`
    close_msg: u16 = 0,
    write_msg: u16 = 0,
    recv_msg: u16 = 0,

    /// State of an in-progress handshake. Allocated by `Client.init` and freed by itself as
    /// soon as the outcome (a `*Client` or an error) has been delivered to the callback.
    pub const HandshakeTask = struct {
        userdata: ?*anyopaque,
        callback: io.Callback,
        msg: u16,

        fd: posix.fd_t,
        /// Bytes received from the peer that the handshake has not consumed yet
        buffer: [tls.max_ciphertext_record_len]u8 = undefined,
        /// Number of valid bytes at the start of `buffer`
        read_end: usize = 0,
        /// Handshake bytes being written to the peer. Kept apart from `buffer` so that
        /// queueing a write never clobbers input that is still waiting to be processed.
        send_buf: [tls.max_ciphertext_record_len]u8 = undefined,
        handshake: tls.nonblock.Client,
        /// The most recently submitted recv or write
        task: *io.Task,

        /// Completion handler for every recv and write issued on behalf of the handshake.
        /// Any failure along the way is reported to the user callback as the `userptr`
        /// result instead of being returned from here; only an error returned by the user
        /// callback itself propagates.
        pub fn handleMsg(rt: *io.Ring, task: io.Task) anyerror!void {
            const self = task.userdataCast(HandshakeTask);
            const result = task.result.?;

            switch (result) {
                .write => |res| {
                    _ = res catch |err| return self.finish(rt, err);

                    if (self.handshake.done()) {
                        return self.finish(rt, self.initClient(rt.gpa));
                    }

                    // Our flight is out; wait for the peer's answer
                    self.armRecv(rt) catch |err| return self.finish(rt, err);
                },

                .recv => |res| {
                    const n = res catch |err| return self.finish(rt, err);
                    // A zero-length read means the peer went away mid-handshake. Re-arming
                    // would only complete with zero again, forever.
                    if (n == 0) return self.finish(rt, error.EndOfStream);

                    self.read_end += n;
                    const r = self.handshake.run(
                        self.buffer[0..self.read_end],
                        &self.send_buf,
                    ) catch |err| return self.finish(rt, err);

                    // Keep only what the handshake did not consume, moved to the front, so
                    // the next run never sees bytes that were already processed.
                    mem.copyForwards(u8, &self.buffer, r.unused_recv);
                    self.read_end = r.unused_recv.len;

                    if (r.send.len > 0) {
                        // r.send lives in self.send_buf, which outlives the write
                        self.task = rt.write(
                            self.fd,
                            r.send,
                            .{ .ptr = self, .cb = HandshakeTask.handleMsg },
                        ) catch |err| return self.finish(rt, err);
                        return;
                    }

                    if (self.handshake.done()) {
                        return self.finish(rt, self.initClient(rt.gpa));
                    }

                    // Nothing to send and not done: the handshake needs more input
                    self.armRecv(rt) catch |err| return self.finish(rt, err);
                },

                else => unreachable,
            }
        }

        /// Submits a recv that appends to the unconsumed bytes already held in `buffer`.
        fn armRecv(self: *HandshakeTask, rt: *io.Ring) !void {
            self.task = try rt.recv(self.fd, self.buffer[self.read_end..], .{
                .ptr = self,
                .cb = HandshakeTask.handleMsg,
            });
        }

        /// Reports the outcome of the handshake to the user callback as a `userptr` result
        /// and frees the handshake state. `self` must not be used after this returns.
        fn finish(self: *HandshakeTask, rt: *io.Ring, outcome: anyerror!*Client) anyerror!void {
            defer rt.gpa.destroy(self);

            if (outcome) |client| {
                return self.callback(rt, .{
                    .userdata = self.userdata,
                    .msg = self.msg,
                    .result = .{ .userptr = client },
                    .callback = self.callback,
                    .req = .userptr,
                });
            } else |err| {
                return self.callback(rt, .{
                    .userdata = self.userdata,
                    .msg = self.msg,
                    .result = .{ .userptr = err },
                    .callback = self.callback,
                    .req = .userptr,
                });
            }
        }

        /// Allocates the `Client` that takes over the connection once the handshake is done.
        fn initClient(self: *HandshakeTask, gpa: Allocator) !*Client {
            const client = try gpa.create(Client);
            client.* = .{
                .gpa = gpa,
                .fd = self.fd,
                .tls = .{ .cipher = self.handshake.inner.cipher },
            };
            // Bytes that arrived behind the last handshake record already belong to the
            // session. Both buffers have the same size, so they always fit. They are
            // decrypted together with whatever the client's next recv delivers.
            @memcpy(client.read_buf[0..self.read_end], self.buffer[0..self.read_end]);
            client.read_end = self.read_end;
            return client;
        }

        /// Tries to cancel the handshake. Callback will receive an error.Canceled if
        /// cancelation was successful, otherwise the handshake will proceed.
        pub fn cancel(self: *HandshakeTask, rt: *io.Ring) void {
            // Best effort by contract: if the cancelation cannot even be submitted the
            // handshake simply carries on, so the error is deliberately dropped.
            self.task.cancel(rt, .{}) catch {};
        }
    };

    /// Tags the ring operations issued by a `Client` so `onCompletion` can dispatch on them
    const Msg = enum {
        write,
        recv,
        close_notify,
    };

    /// Initializes a handshake, which will ultimately deliver a Client to the callback via a
    /// userptr result. If the handshake fails, the same callback receives the error instead.
    /// The returned task frees itself once the callback has been invoked; the pointer is only
    /// useful for `HandshakeTask.cancel` while the handshake is still running.
    pub fn init(
        rt: *io.Ring,
        fd: posix.fd_t,
        opts: tls.config.Client,
        ctx: io.Context,
    ) !*HandshakeTask {
        const hs = try rt.gpa.create(HandshakeTask);
        // Nothing else owns the task until the first write has been submitted
        errdefer rt.gpa.destroy(hs);
        hs.* = .{
            .userdata = ctx.ptr,
            .callback = ctx.cb,
            .msg = ctx.msg,

            .fd = fd,
            .handshake = .init(opts),
            .task = undefined,
        };

        // With no input yet, the first run produces the opening flight
        const result = try hs.handshake.run("", &hs.send_buf);
        const hs_ctx: io.Context = .{ .ptr = hs, .cb = HandshakeTask.handleMsg };
        hs.task = try rt.write(hs.fd, result.send, hs_ctx);
        return hs;
    }

    /// Frees the cleartext and ciphertext buffers. Does not close `fd` and does not free the
    /// client itself.
    pub fn deinit(self: *Client, gpa: Allocator) void {
        self.ciphertext_buf.deinit(gpa);
        self.cleartext_buf.deinit(gpa);
    }

    /// Queues a TLS close_notify, cancels any recv in flight and, once the alert has been
    /// written, closes `fd`. The close result is delivered to `callback` with `close_msg`.
    pub fn close(self: *Client, gpa: Allocator, rt: *io.Ring) !void {
        // close notify is 2 bytes long
        const len = self.tls.encryptedLength(2);
        try self.ciphertext_buf.ensureUnusedCapacity(gpa, len);
        const buf = self.ciphertext_buf.unusedCapacitySlice();
        const msg = try self.tls.close(buf);

        // The alert was encrypted straight into the spare capacity; just account for it
        self.ciphertext_buf.items.len += msg.len;
        _ = try rt.write(self.fd, self.ciphertext_buf.items[self.written..], .{
            .ptr = self,
            .cb = Client.onCompletion,
            .msg = @intFromEnum(Client.Msg.close_notify),
        });

        if (self.recv_task) |task| {
            try task.cancel(rt, .{});
            self.recv_task = null;
        }
    }

    /// Submits the close of `fd` at most once. The result goes directly to the user
    /// callback, see `closeContext`.
    fn closeFd(self: *Client, rt: *io.Ring) !void {
        if (self.close_queued) return;
        _ = try rt.close(self.fd, self.closeContext());
        self.close_queued = true;
    }

    /// Completion handler for every recv and write the client submits.
    fn onCompletion(rt: *io.Ring, task: io.Task) anyerror!void {
        const self = task.userdataCast(Client);
        const result = task.result.?;

        switch (task.msgToEnum(Client.Msg)) {
            .recv => {
                assert(result == .recv);
                self.recv_task = null;
                const n = result.recv catch |err| {
                    return self.callback(rt, .{
                        .userdata = self.userdata,
                        .msg = self.recv_msg,
                        .callback = self.callback,
                        .req = .{ .recv = .{ .fd = self.fd, .buffer = &self.read_buf } },
                        .result = .{ .recv = err },
                    });
                };

                if (n == 0) {
                    // The peer closed the socket without a close_notify. Re-arming the
                    // recv would complete with zero again immediately, so treat it like a
                    // closed session.
                    return self.closeFd(rt);
                }

                self.read_end += n;
                const end = self.read_end;
                const r = try self.tls.decrypt(self.read_buf[0..end], self.read_buf[0..end]);

                if (r.cleartext.len > 0) {
                    // The callback must consume the data before returning: the buffer is
                    // compacted right after.
                    try self.callback(rt, .{
                        .userdata = self.userdata,
                        .msg = self.recv_msg,
                        .callback = self.callback,
                        .req = .{ .recv = .{ .fd = self.fd, .buffer = &self.read_buf } },
                        .result = .{ .recv = r.cleartext.len },
                    });
                }
                // Move the incomplete trailing record to the front for the next recv
                mem.copyForwards(u8, &self.read_buf, r.unused_ciphertext);
                self.read_end = r.unused_ciphertext.len;

                if (r.closed) {
                    return self.closeFd(rt);
                }

                self.recv_task = try rt.recv(
                    self.fd,
                    self.read_buf[self.read_end..],
                    self.recvContext(),
                );
            },

            .write => {
                assert(result == .write);
                const n = result.write catch |err| {
                    // Pass the real cause on, the same way the recv path does
                    return self.callback(rt, .{
                        .userdata = self.userdata,
                        .msg = self.write_msg,
                        .callback = self.callback,
                        .req = .{ .write = .{ .fd = self.fd, .buffer = self.ciphertext_buf.items } },
                        .result = .{ .write = err },
                    });
                };
                self.written += n;

                if (self.written < self.ciphertext_buf.items.len) {
                    // Short write: send the remainder
                    _ = try rt.write(
                        self.fd,
                        self.ciphertext_buf.items[self.written..],
                        self.writeContext(),
                    );
                    return;
                }

                // Reset the bookkeeping before notifying, so the callback may queue and
                // flush more data without it being wiped when it returns.
                const sent = self.ciphertext_buf.items;
                const total = self.written;
                self.written = 0;
                self.ciphertext_buf.clearRetainingCapacity();
                return self.callback(rt, .{
                    .userdata = self.userdata,
                    .msg = self.write_msg,
                    .callback = self.callback,
                    .req = .{ .write = .{ .fd = self.fd, .buffer = sent } },
                    .result = .{ .write = total },
                });
            },

            .close_notify => {
                assert(result == .write);
                const n = result.write catch {
                    // Reported as a close failure; the write error cannot be carried in a
                    // close result, hence the generic error.
                    return self.callback(rt, .{
                        .userdata = self.userdata,
                        .msg = self.close_msg,
                        .callback = self.callback,
                        .req = .{ .close = self.fd },
                        .result = .{ .close = error.Unexpected },
                    });
                };

                self.written += n;

                if (self.written < self.ciphertext_buf.items.len) {
                    _ = try rt.write(self.fd, self.ciphertext_buf.items[self.written..], .{
                        .ptr = self,
                        .cb = Client.onCompletion,
                        .msg = @intFromEnum(Client.Msg.close_notify),
                    });
                } else {
                    self.written = 0;
                    self.ciphertext_buf.clearRetainingCapacity();
                    try self.closeFd(rt);
                }
            },
        }
    }

    /// Starts receiving. Does nothing if a recv is already in flight. After each completed
    /// recv the client re-arms itself, so this only needs to be called once.
    pub fn recv(self: *Client, rt: *io.Ring) !void {
        if (self.recv_task != null) return;
        self.recv_task = try rt.recv(
            self.fd,
            self.read_buf[self.read_end..],
            self.recvContext(),
        );
    }

    /// Queues cleartext to be sent. Nothing reaches the socket until `flush` is called.
    pub fn write(self: *Client, gpa: Allocator, bytes: []const u8) Allocator.Error!void {
        try self.cleartext_buf.appendSlice(gpa, bytes);
    }

    /// Encrypts the queued cleartext and submits it for writing. `callback` is invoked with
    /// `write_msg` once everything has been written, or with the error if the write fails.
    ///
    /// NOTE(review): growing `ciphertext_buf` can move it, so this looks unsafe to call while
    /// an earlier flush is still being written. Wait for the write callback first.
    pub fn flush(self: *Client, gpa: Allocator, rt: *io.Ring) !void {
        const len = self.tls.encryptedLength(self.cleartext_buf.items.len);
        try self.ciphertext_buf.ensureUnusedCapacity(gpa, len);
        const slice = self.ciphertext_buf.unusedCapacitySlice();
        const result = try self.tls.encrypt(self.cleartext_buf.items, slice);
        // The records were encrypted straight into the spare capacity, so extend the list
        // over them (as `close` does). Appending the slice would copy it onto itself.
        self.ciphertext_buf.items.len += result.ciphertext.len;
        // Drop the cleartext that has been consumed
        self.cleartext_buf.replaceRangeAssumeCapacity(0, result.cleartext_pos, "");

        _ = try rt.write(
            self.fd,
            self.ciphertext_buf.items[self.written..],
            self.writeContext(),
        );
    }

    /// Context that routes the close result directly to the user callback.
    fn closeContext(self: *const Client) io.Context {
        return .{ .ptr = self.userdata, .cb = self.callback, .msg = self.close_msg };
    }

    /// Context that routes a recv completion to `onCompletion`.
    fn recvContext(self: *Client) io.Context {
        return .{
            .ptr = self,
            .cb = Client.onCompletion,
            .msg = @intFromEnum(Client.Msg.recv),
        };
    }

    /// Context that routes a write completion to `onCompletion`.
    fn writeContext(self: *Client) io.Context {
        return .{
            .ptr = self,
            .cb = Client.onCompletion,
            .msg = @intFromEnum(Client.Msg.write),
        };
    }
};
370
// Integration test: connects to google.com:443, performs a TLS handshake, then closes the
// session and checks that both the client and the socket were released.
test "tls: Client" {
    const net = @import("net.zig");
    const gpa = std.testing.allocator;

    var rt = try io.Ring.init(gpa, 16);
    defer rt.deinit();

    // Callback target that tracks the socket and the TLS client as they come and go
    const Harness = struct {
        gpa: Allocator,
        fd: ?posix.fd_t = null,
        tls: ?*Client = null,

        const Self = @This();

        const Msg = enum {
            connect,
            handshake,
            close,
            write,
            recv,
        };

        // Frees the TLS client, if there is one
        fn dropClient(self: *Self) void {
            if (self.tls) |client| {
                client.deinit(self.gpa);
                self.gpa.destroy(client);
                self.tls = null;
            }
        }

        fn callback(_: *io.Ring, task: io.Task) anyerror!void {
            const self = task.userdataCast(Self);
            const result = task.result.?;
            // Any failure below must not leave a client behind
            errdefer self.dropClient();

            switch (task.msgToEnum(Msg)) {
                .connect => {
                    self.fd = try result.userfd;
                },
                .handshake => {
                    // The handshake hands over ownership of the client; wire it up so
                    // that its results come back to this callback
                    const ptr = try result.userptr;
                    self.tls = @ptrCast(@alignCast(ptr));
                    const client = self.tls.?;
                    client.userdata = self;
                    client.close_msg = @intFromEnum(Msg.close);
                    client.write_msg = @intFromEnum(Msg.write);
                    client.recv_msg = @intFromEnum(Msg.recv);
                    client.callback = Self.callback;
                },
                .close => {
                    const client = self.tls.?;
                    client.deinit(self.gpa);
                    self.gpa.destroy(client);
                    self.tls = null;
                    self.fd = null;
                },

                .write, .recv => {},
            }
        }
    };

    var harness: Harness = .{ .gpa = gpa };
    defer {
        // Only reached with live resources if the test bailed out early
        if (harness.tls) |client| {
            client.deinit(gpa);
            gpa.destroy(client);
        }
        if (harness.fd) |fd| posix.close(fd);
    }

    const connect_ctx: io.Context = .{
        .ptr = &harness,
        .cb = Harness.callback,
        .msg = @intFromEnum(Harness.Msg.connect),
    };
    _ = try net.tcpConnectToHost(&rt, "google.com", 443, connect_ctx);

    try rt.run(.until_done);

    try std.testing.expect(harness.fd != null);

    var bundle: CertBundle = .{};
    try bundle.rescan(gpa);
    defer bundle.deinit(gpa);

    const handshake_ctx: io.Context = .{
        .ptr = &harness,
        .cb = Harness.callback,
        .msg = @intFromEnum(Harness.Msg.handshake),
    };
    _ = try Client.init(
        &rt,
        harness.fd.?,
        .{ .root_ca = bundle, .host = "google.com" },
        handshake_ctx,
    );
    try rt.run(.until_done);
    try std.testing.expect(harness.tls != null);

    // Closing with a recv in flight exercises the cancelation path as well
    const client = harness.tls.?;
    try client.recv(&rt);
    try client.close(gpa, &rt);
    try rt.run(.until_done);
    try std.testing.expect(harness.tls == null);
    try std.testing.expect(harness.fd == null);
}