//! Part of ourio, an asynchronous IO runtime.
//! Source listing from the `main` branch (16 kB).
const std = @import("std");
const tls = @import("tls");
const io = @import("ourio");

const Allocator = std.mem.Allocator;
const CertBundle = tls.config.cert.Bundle;
const assert = std.debug.assert;
const mem = std.mem;
const posix = std.posix;

/// A TLS connection layered over a file descriptor and driven by an `io.Ring`.
///
/// Obtain one through `Client.init`, which runs the handshake asynchronously and delivers a
/// heap-allocated `*Client` (allocated with `rt.gpa`) to the supplied callback as a `userptr`
/// result. Whoever receives that pointer owns it: call `deinit` and then destroy it with the
/// same allocator.
pub const Client = struct {
    gpa: Allocator,
    fd: posix.fd_t,
    tls: tls.nonblock.Connection,
    /// The recv submitted by `recv`, or null when no recv is pending
    recv_task: ?*io.Task = null,

    /// Ciphertext received from the peer; `read_buf[0..read_end]` has not been decrypted yet
    read_buf: [tls.max_ciphertext_record_len]u8 = undefined,
    read_end: usize = 0,

    /// Application data queued by `write`, encrypted by `flush`
    cleartext_buf: std.ArrayListUnmanaged(u8) = .empty,
    /// Encrypted bytes awaiting transmission; the first `written` of them have been sent
    ciphertext_buf: std.ArrayListUnmanaged(u8) = .empty,
    written: usize = 0,

    /// User context for completions delivered by this client. The `*_msg` values are attached
    /// to close, write and recv completions respectively.
    userdata: ?*anyopaque = null,
    callback: *const fn (*io.Ring, io.Task) anyerror!void = io.noopCallback,
    close_msg: u16 = 0,
    write_msg: u16 = 0,
    recv_msg: u16 = 0,

    /// State of an in-progress client handshake. Created by `Client.init`; it destroys itself
    /// once the handshake has succeeded or failed and the user callback has been invoked.
    pub const HandshakeTask = struct {
        userdata: ?*anyopaque,
        callback: io.Callback,
        msg: u16,

        fd: posix.fd_t,
        /// Bytes received from the peer; `buffer[0..read_end]` has not been consumed yet
        buffer: [tls.max_ciphertext_record_len]u8 = undefined,
        read_end: usize = 0,
        /// Storage for outgoing handshake records. It is kept apart from `buffer` so that
        /// received-but-unconsumed bytes survive while a send is in flight.
        send_buf: [tls.max_ciphertext_record_len]u8 = undefined,
        /// The record currently being sent (a slice of `send_buf`) and how much of it is done
        send: []const u8 = "",
        sent: usize = 0,
        handshake: tls.nonblock.Client,
        /// The ring operation (write or recv) currently submitted for this handshake
        task: *io.Task,

        /// Completion handler for every operation the handshake submits. It advances the
        /// handshake and eventually delivers either a `*Client` or an error to the user
        /// callback as a `userptr` result.
        pub fn handleMsg(rt: *io.Ring, task: io.Task) anyerror!void {
            const self = task.userdataCast(HandshakeTask);
            // Any failure (I/O, protocol, submission) is reported to the user and frees the task
            const done = self.advance(rt, task) catch |err| return self.fail(rt, err);
            if (done) return self.finish(rt);
        }

        /// Processes one completion. Returns true when the handshake has finished; otherwise
        /// the next write or recv has been submitted.
        fn advance(self: *HandshakeTask, rt: *io.Ring, task: io.Task) anyerror!bool {
            switch (task.result.?) {
                .write => |res| {
                    self.sent += try res;
                    if (self.sent < self.send.len) {
                        // Short write: finish sending this record before reading again
                        self.task = try rt.write(self.fd, self.send[self.sent..], self.context());
                        return false;
                    }
                    self.send = "";
                    self.sent = 0;
                },

                .recv => |res| {
                    const n = try res;
                    // Zero bytes means the peer closed the connection; re-arming would spin
                    if (n == 0) return error.EndOfStream;

                    self.read_end += n;
                    const r = try self.handshake.run(self.buffer[0..self.read_end], &self.send_buf);

                    // Keep only what the handshake did not consume so nothing is fed twice
                    mem.copyForwards(u8, &self.buffer, r.unused_recv);
                    self.read_end = r.unused_recv.len;

                    if (r.send.len > 0) {
                        self.send = r.send;
                        self.sent = 0;
                        self.task = try rt.write(self.fd, self.send, self.context());
                        return false;
                    }
                },

                else => unreachable,
            }

            if (self.handshake.done()) return true;

            // Nothing to send right now: wait for more records from the peer
            self.task = try rt.recv(self.fd, self.buffer[self.read_end..], self.context());
            return false;
        }

        /// Creates the `Client`, hands it to the user callback, and frees this task.
        fn finish(self: *HandshakeTask, rt: *io.Ring) anyerror!void {
            const client = self.initClient(rt.gpa) catch |err| return self.fail(rt, err);
            defer rt.gpa.destroy(self);
            return self.callback(rt, .{
                .userdata = self.userdata,
                .msg = self.msg,
                .result = .{ .userptr = client },
                .callback = self.callback,
                .req = .userptr,
            });
        }

        /// Reports `err` to the user callback as a `userptr` result and frees this task.
        fn fail(self: *HandshakeTask, rt: *io.Ring, err: anyerror) anyerror!void {
            defer rt.gpa.destroy(self);
            return self.callback(rt, .{
                .userdata = self.userdata,
                .msg = self.msg,
                .result = .{ .userptr = err },
                .callback = self.callback,
                .req = .userptr,
            });
        }

        /// Context routing completions back to `handleMsg` for this task.
        fn context(self: *HandshakeTask) io.Context {
            return .{ .ptr = self, .cb = handleMsg };
        }

        /// Allocates the post-handshake `Client`. It carries over the negotiated cipher and
        /// any received bytes the handshake did not consume, because those belong to the
        /// connection.
        fn initClient(self: *HandshakeTask, gpa: Allocator) !*Client {
            const client = try gpa.create(Client);
            client.* = .{
                .gpa = gpa,
                .fd = self.fd,
                .tls = .{ .cipher = self.handshake.inner.cipher },
            };
            @memcpy(client.read_buf[0..self.read_end], self.buffer[0..self.read_end]);
            client.read_end = self.read_end;
            return client;
        }

        /// Tries to cancel the handshake. The callback receives `error.Canceled` if the
        /// cancelation succeeds; otherwise the handshake proceeds. Only call this while the
        /// handshake is still in progress, because the task frees itself when it finishes.
        pub fn cancel(self: *HandshakeTask, rt: *io.Ring) void {
            // Best effort by design: if the cancel cannot be submitted, the handshake continues
            self.task.cancel(rt, .{}) catch {};
        }
    };

    const Msg = enum {
        write,
        recv,
        close_notify,
    };

    /// Starts a handshake on `fd`. The result is ultimately delivered to `ctx` as a `userptr`:
    /// either a `*Client` or an error.
    pub fn init(
        rt: *io.Ring,
        fd: posix.fd_t,
        opts: tls.config.Client,
        ctx: io.Context,
    ) !*HandshakeTask {
        const hs = try rt.gpa.create(HandshakeTask);
        errdefer rt.gpa.destroy(hs);
        hs.* = .{
            .userdata = ctx.ptr,
            .callback = ctx.cb,
            .msg = ctx.msg,

            .fd = fd,
            .handshake = .init(opts),
            .task = undefined,
        };

        // The first run produces the ClientHello without any received input
        const result = try hs.handshake.run("", &hs.send_buf);
        hs.send = result.send;
        hs.task = try rt.write(hs.fd, hs.send, hs.context());
        return hs;
    }

    /// Frees the client's internal buffers (not the `Client` allocation itself).
    pub fn deinit(self: *Client, gpa: Allocator) void {
        self.ciphertext_buf.deinit(gpa);
        self.cleartext_buf.deinit(gpa);
    }

    /// Sends a TLS close_notify, then closes the fd. The close result is delivered to the
    /// user callback with `close_msg`. Any pending recv is canceled.
    ///
    /// NOTE(review): this submits its own write starting at `written`. Calling it while a
    /// `flush` write is still in flight would send overlapping bytes; confirm callers wait.
    pub fn close(self: *Client, gpa: Allocator, rt: *io.Ring) !void {
        // close notify is 2 bytes long
        const len = self.tls.encryptedLength(2);
        try self.ciphertext_buf.ensureUnusedCapacity(gpa, len);
        const buf = self.ciphertext_buf.unusedCapacitySlice();
        const msg = try self.tls.close(buf);

        self.ciphertext_buf.items.len += msg.len;
        _ = try rt.write(self.fd, self.ciphertext_buf.items[self.written..], .{
            .ptr = self,
            .cb = Client.onCompletion,
            .msg = @intFromEnum(Client.Msg.close_notify),
        });

        if (self.recv_task) |task| {
            try task.cancel(rt, .{});
            self.recv_task = null;
        }
    }

    /// Completion handler for the client's own recv/write/close_notify operations. It
    /// translates them into user-facing completions.
    fn onCompletion(rt: *io.Ring, task: io.Task) anyerror!void {
        const self = task.userdataCast(Client);
        const result = task.result.?;

        switch (task.msgToEnum(Client.Msg)) {
            .recv => {
                assert(result == .recv);
                self.recv_task = null;
                const n = result.recv catch |err| {
                    return self.callback(rt, .{
                        .userdata = self.userdata,
                        .msg = self.recv_msg,
                        .callback = self.callback,
                        .req = .{ .recv = .{ .fd = self.fd, .buffer = &self.read_buf } },
                        .result = .{ .recv = err },
                    });
                };

                if (n == 0) {
                    // Peer closed the socket: report EOF as a zero-length recv and do not
                    // re-arm, which would otherwise complete immediately forever
                    return self.callback(rt, .{
                        .userdata = self.userdata,
                        .msg = self.recv_msg,
                        .callback = self.callback,
                        .req = .{ .recv = .{ .fd = self.fd, .buffer = &self.read_buf } },
                        .result = .{ .recv = 0 },
                    });
                }

                self.read_end += n;
                const end = self.read_end;
                // Decrypt in place; the cleartext is reported as a length into read_buf
                const r = try self.tls.decrypt(self.read_buf[0..end], self.read_buf[0..end]);

                if (r.cleartext.len > 0) {
                    try self.callback(rt, .{
                        .userdata = self.userdata,
                        .msg = self.recv_msg,
                        .callback = self.callback,
                        .req = .{ .recv = .{ .fd = self.fd, .buffer = &self.read_buf } },
                        .result = .{ .recv = r.cleartext.len },
                    });
                }
                // Keep any partial record at the front for the next recv
                mem.copyForwards(u8, &self.read_buf, r.unused_ciphertext);
                self.read_end = r.unused_ciphertext.len;

                if (r.closed) {
                    _ = try rt.close(self.fd, self.closeContext());
                    return;
                }

                self.recv_task = try rt.recv(
                    self.fd,
                    self.read_buf[self.read_end..],
                    self.recvContext(),
                );
            },

            .write => {
                assert(result == .write);
                const n = result.write catch |err| {
                    // A failed write leaves the record stream unusable; drop the pending
                    // ciphertext so later flushes are not queued behind it forever
                    defer {
                        self.written = 0;
                        self.ciphertext_buf.clearRetainingCapacity();
                    }
                    return self.callback(rt, .{
                        .userdata = self.userdata,
                        .msg = self.write_msg,
                        .callback = self.callback,
                        .req = .{ .write = .{ .fd = self.fd, .buffer = self.ciphertext_buf.items } },
                        .result = .{ .write = err },
                    });
                };
                self.written += n;

                if (self.written < self.ciphertext_buf.items.len) {
                    // Short write, or more ciphertext was appended by flush: keep sending
                    _ = try rt.write(
                        self.fd,
                        self.ciphertext_buf.items[self.written..],
                        self.writeContext(),
                    );
                } else {
                    defer {
                        self.written = 0;
                        self.ciphertext_buf.clearRetainingCapacity();
                    }
                    return self.callback(rt, .{
                        .userdata = self.userdata,
                        .msg = self.write_msg,
                        .callback = self.callback,
                        .req = .{ .write = .{ .fd = self.fd, .buffer = self.ciphertext_buf.items } },
                        .result = .{ .write = self.written },
                    });
                }
            },

            .close_notify => {
                assert(result == .write);
                // NOTE(review): on failure the fd is not closed here; the callback receives
                // error.Unexpected and presumably has to close the fd itself.
                const n = result.write catch {
                    return self.callback(rt, .{
                        .userdata = self.userdata,
                        .msg = self.close_msg,
                        .callback = self.callback,
                        .req = .{ .close = self.fd },
                        .result = .{ .close = error.Unexpected },
                    });
                };

                self.written += n;

                if (self.written < self.ciphertext_buf.items.len) {
                    _ = try rt.write(self.fd, self.ciphertext_buf.items[self.written..], .{
                        .ptr = self,
                        .cb = Client.onCompletion,
                        .msg = @intFromEnum(Client.Msg.close_notify),
                    });
                } else {
                    self.written = 0;
                    self.ciphertext_buf.clearRetainingCapacity();
                    _ = try rt.close(self.fd, self.closeContext());
                }
            },
        }
    }

    /// Starts receiving. Decrypted data is reported to the user callback with `recv_msg`
    /// as a length into `read_buf`. This is a no-op if a recv is already pending.
    pub fn recv(self: *Client, rt: *io.Ring) !void {
        if (self.recv_task != null) return;
        self.recv_task = try rt.recv(
            self.fd,
            self.read_buf[self.read_end..],
            self.recvContext(),
        );
    }

    /// Queues `bytes` for sending. Nothing is transmitted until `flush`.
    pub fn write(self: *Client, gpa: Allocator, bytes: []const u8) Allocator.Error!void {
        try self.cleartext_buf.appendSlice(gpa, bytes);
    }

    /// Encrypts all queued cleartext and sends it. Completion is reported to the user
    /// callback with `write_msg` once every pending ciphertext byte has been written.
    /// Does nothing when no cleartext is queued.
    ///
    /// NOTE(review): if a write is already in flight, growing `ciphertext_buf` here may
    /// move the memory that write is reading from. Confirm callers only flush after the
    /// previous write completed.
    pub fn flush(self: *Client, gpa: Allocator, rt: *io.Ring) !void {
        if (self.cleartext_buf.items.len == 0) return;

        // Pending ciphertext means a write is in flight. Its completion handler keeps
        // writing until the whole buffer, including what is appended below, has been sent.
        const idle = self.ciphertext_buf.items.len == 0;

        const len = self.tls.encryptedLength(self.cleartext_buf.items.len);
        try self.ciphertext_buf.ensureUnusedCapacity(gpa, len);
        const result = try self.tls.encrypt(
            self.cleartext_buf.items,
            self.ciphertext_buf.unusedCapacitySlice(),
        );
        // The ciphertext was encrypted directly into the spare capacity. Copying it onto
        // itself with appendSlice would be an aliasing @memcpy, so just extend the length.
        self.ciphertext_buf.items.len += result.ciphertext.len;
        self.cleartext_buf.replaceRangeAssumeCapacity(0, result.cleartext_pos, "");

        if (!idle) return;
        _ = try rt.write(self.fd, self.ciphertext_buf.items, self.writeContext());
    }

    /// Context that delivers close completions directly to the user callback.
    fn closeContext(self: *const Client) io.Context {
        return .{ .ptr = self.userdata, .cb = self.callback, .msg = self.close_msg };
    }

    fn recvContext(self: *Client) io.Context {
        return .{
            .ptr = self,
            .cb = Client.onCompletion,
            .msg = @intFromEnum(Client.Msg.recv),
        };
    }

    fn writeContext(self: *Client) io.Context {
        return .{
            .ptr = self,
            .cb = Client.onCompletion,
            .msg = @intFromEnum(Client.Msg.write),
        };
    }
};

test "tls: Client" {
    const net = @import("net.zig");
    const gpa = std.testing.allocator;

    var rt = try io.Ring.init(gpa, 16);
    defer rt.deinit();

    const Foo = struct {
        const Self = @This();
        gpa: Allocator,
        fd: ?posix.fd_t = null,
        tls: ?*Client = null,

        const Msg = enum {
            connect,
            handshake,
            close,
            write,
            recv,
        };

        fn callback(_: *io.Ring, task: io.Task) anyerror!void {
            const self = task.userdataCast(Self);
            const result = task.result.?;
            errdefer {
                if (self.tls) |client| {
                    client.deinit(self.gpa);
                    self.gpa.destroy(client);
                    self.tls = null;
                }
            }

            switch (task.msgToEnum(Msg)) {
                .connect => {
                    self.fd = try result.userfd;
                },
                .handshake => {
                    const ptr = try result.userptr;
                    self.tls = @ptrCast(@alignCast(ptr));
                    self.tls.?.userdata = self;
                    self.tls.?.close_msg = @intFromEnum(@This().Msg.close);
                    self.tls.?.write_msg = @intFromEnum(@This().Msg.write);
                    self.tls.?.recv_msg = @intFromEnum(@This().Msg.recv);
                    self.tls.?.callback = @This().callback;
                },
                .close => {
                    self.tls.?.deinit(self.gpa);
                    self.gpa.destroy(self.tls.?);
                    self.tls = null;
                    self.fd = null;
                },

                else => {},
            }
        }
    };

    var foo: Foo = .{ .gpa = gpa };
    defer {
        if (foo.tls) |client| {
            client.deinit(gpa);
            gpa.destroy(client);
        }
        if (foo.fd) |fd| posix.close(fd);
    }

    _ = try net.tcpConnectToHost(
        &rt,
        "google.com",
        443,
        .{ .ptr = &foo, .cb = Foo.callback, .msg = @intFromEnum(Foo.Msg.connect) },
    );

    try rt.run(.until_done);

    try std.testing.expect(foo.fd != null);

    var bundle: CertBundle = .{};
    try bundle.rescan(gpa);
    defer bundle.deinit(gpa);

    _ = try Client.init(
        &rt,
        foo.fd.?,
        .{ .root_ca = bundle, .host = "google.com" },
        .{ .ptr = &foo, .cb = Foo.callback, .msg = @intFromEnum(Foo.Msg.handshake) },
    );
    try rt.run(.until_done);
    try std.testing.expect(foo.tls != null);

    try foo.tls.?.recv(&rt);
    try foo.tls.?.close(gpa, &rt);
    try rt.run(.until_done);
    try std.testing.expect(foo.tls == null);
    try std.testing.expect(foo.fd == null);
}