(* OCaml HTTP/1, HTTP/2 and WebSocket client and server library. *)
1open Eio.Std
2
3(** {1 Constants} *)
4
(** Fixed GUID defined by RFC 6455. It is concatenated to the client's
    [Sec-WebSocket-Key] before hashing to obtain the accept value. *)
let websocket_uuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
7
8(** {1 Types} *)
9
(** WebSocket frame opcode. *)
module Opcode = struct
  type t =
    | Continuation  (** 0x0 *)
    | Text  (** 0x1 *)
    | Binary  (** 0x2 *)
    | Close  (** 0x8 *)
    | Ping  (** 0x9 *)
    | Pong  (** 0xA *)
    | Ctrl of int  (** Other control opcodes *)
    | Nonctrl of int  (** Other non-control opcodes *)

  (** [to_int op] is the numeric value of [op]. For [Ctrl n] and [Nonctrl n]
      the carried integer is returned unchanged. *)
  let to_int = function
    | Continuation -> 0
    | Text -> 1
    | Binary -> 2
    | Close -> 8
    | Ping -> 9
    | Pong -> 10
    | Ctrl n -> n
    | Nonctrl n -> n

  (** [of_int n] decodes a numeric opcode. Values without a dedicated
      constructor map to [Ctrl n] when [n > 7] and to [Nonctrl n] otherwise. *)
  let of_int = function
    | 0 -> Continuation
    | 1 -> Text
    | 2 -> Binary
    | 8 -> Close
    | 9 -> Ping
    | 10 -> Pong
    | n when n > 7 -> Ctrl n
    | n -> Nonctrl n

  (** [int_to_string n] is the decimal representation of [n]. Negative
      numbers keep their sign (they used to render as the empty string). *)
  let int_to_string n = string_of_int n

  (** [to_string op] is a lowercase human-readable name for [op]; unknown
      opcodes are rendered as ["ctrl(N)"] or ["nonctrl(N)"]. *)
  let to_string = function
    | Continuation -> "continuation"
    | Text -> "text"
    | Binary -> "binary"
    | Close -> "close"
    | Ping -> "ping"
    | Pong -> "pong"
    | Ctrl n -> "ctrl(" ^ int_to_string n ^ ")"
    | Nonctrl n -> "nonctrl(" ^ int_to_string n ^ ")"

  (** [is_control op] is [true] for [Close], [Ping], [Pong] and [Ctrl _]. *)
  let is_control = function
    | Close | Ping | Pong | Ctrl _ -> true
    | Continuation | Text | Binary | Nonctrl _ -> false
end
73
(** WebSocket frame *)
type frame = {
  opcode : Opcode.t;  (** Kind of frame. *)
  extension : int;
      (** Reserved header bits; only the low three bits are serialized. *)
  final : bool;  (** Whether the FIN bit is set (last frame of a message). *)
  content : string;  (** Payload bytes, stored unmasked. *)
}
81
(** [pp_frame fmt frame] prints a one-line summary of [frame] on [fmt]: its
    opcode name, FIN flag and payload length (the payload itself is not
    printed). *)
let pp_frame fmt { opcode; final; content; _ } =
  let opcode_name = Opcode.to_string opcode in
  let payload_len = String.length content in
  Format.fprintf fmt "{opcode=%s; final=%b; len=%d}" opcode_name final
    payload_len
87
(** [make_frame ?opcode ?extension ?final ?content ()] builds a frame.
    Defaults: a final [Text] frame with no reserved bits and an empty
    payload. *)
let make_frame ?opcode ?extension ?final ?content () =
  {
    opcode = Option.value opcode ~default:Opcode.Text;
    extension = Option.value extension ~default:0;
    final = Option.value final ~default:true;
    content = Option.value content ~default:"";
  }
91
(** [close_frame code] builds a final [Close] frame. A non-negative [code] is
    encoded as a two-byte big-endian payload (low 16 bits only); a negative
    [code] yields an empty payload. *)
let close_frame code =
  let status_bytes () =
    String.init 2 (fun i ->
        let shift = if i = 0 then 8 else 0 in
        Char.chr ((code lsr shift) land 0xff))
  in
  let content = if code >= 0 then status_bytes () else "" in
  { opcode = Opcode.Close; extension = 0; final = true; content }
102
(** Default upper bound on the payload length of a single received frame:
    16 MiB. *)
let default_max_payload_size = 1 lsl 24
104
(** State of one WebSocket connection. *)
type t = {
  flow : Eio.Flow.two_way_ty r;
      (** Byte stream that frames are read from and written to. *)
  mutable closed : bool;
      (** Set once the connection has been closed locally, a Close frame has
          been received, or reading a frame failed. *)
  is_client : bool;
      (** [true] for the client side: outgoing frames are masked and incoming
          frames must be unmasked; the reverse holds for the server side. *)
  read_buf : Buffer.t;
      (** Buffer allocated when the connection is created.
          NOTE(review): not read by any function visible in this file. *)
  max_payload_size : int;
      (** Largest payload length accepted for a single incoming frame. *)
}
112
(** Failures reported by the connection API. *)
type error =
  | Connection_closed
      (** The connection is already closed, or the peer closed it. *)
  | Protocol_error of string
      (** The peer violated the protocol; the string describes how. *)
  | Io_error of string
      (** An exception was raised during I/O; the string is its printed
          form. *)
  | Payload_too_large of int
      (** A frame announced this payload length, exceeding the limit. *)
118
119(** {1 Cryptographic helpers} *)
120
(** [b64_encoded_sha1sum s] is the Base64 encoding of the raw SHA-1 digest of
    [s]. *)
let b64_encoded_sha1sum s =
  s
  |> Digestif.SHA1.digest_string
  |> Digestif.SHA1.to_raw_string
  |> Base64.encode_exn
125
(** [compute_accept_key key] is the [Sec-WebSocket-Accept] value for the
    client-supplied [key]: the Base64-encoded SHA-1 of [key] followed by
    {!websocket_uuid}. *)
let compute_accept_key key =
  let salted = String.concat "" [ key; websocket_uuid ] in
  b64_encoded_sha1sum salted
128
(** Source of random bytes used for masking keys and handshake keys. *)
module Rng = struct
  (** [generate len] delegates to [Secure_random.generate len]. *)
  let generate len = Secure_random.generate len
end
132
133(** {1 Frame parsing/serialization} *)
134
(** [xor_mask mask data] returns a fresh string in which byte [i] of [data]
    has been XORed with byte [i mod 4] of [mask]. Applying it twice with the
    same [mask] restores the original data.
    @raise Invalid_argument
      if [data] is long enough to index past the end of a [mask] shorter
      than 4 bytes. *)
let xor_mask mask data =
  String.mapi
    (fun idx byte ->
      (* idx is never negative, so [idx land 3] equals [idx mod 4] *)
      Char.chr (Char.code byte lxor Char.code mask.[idx land 3]))
    data
145
(** [write_frame_to_buf ~is_client buf frame] appends the wire encoding of
    [frame] to [buf].

    When [is_client] is [true] the mask bit is set, a fresh 4-byte key from
    {!Rng.generate} is written and the payload is XOR-masked with it; when it
    is [false] the payload is written unchanged. The length uses the 7-bit,
    16-bit or 64-bit big-endian form depending on the payload size. *)
let write_frame_to_buf ~is_client buf frame =
  let add_byte b = Buffer.add_char buf (Char.chr b) in
  let payload = frame.content in
  let payload_len = String.length payload in
  let fin_bit = if frame.final then 0x80 else 0 in
  let rsv_bits = (frame.extension land 0x7) lsl 4 in
  let mask_bit = if is_client then 0x80 else 0 in

  (* First header byte: FIN, reserved bits, opcode. *)
  add_byte (fin_bit lor rsv_bits lor Opcode.to_int frame.opcode);

  (* Second header byte, plus the extended length when needed. *)
  if payload_len < 126 then add_byte (mask_bit lor payload_len)
  else if payload_len < 65536 then begin
    add_byte (mask_bit lor 126);
    add_byte ((payload_len lsr 8) land 0xff);
    add_byte (payload_len land 0xff)
  end
  else begin
    add_byte (mask_bit lor 127);
    (* 64-bit big-endian length, most significant byte first *)
    List.iter
      (fun shift -> add_byte ((payload_len lsr shift) land 0xff))
      [ 56; 48; 40; 32; 24; 16; 8; 0 ]
  end;

  (* Masking key (client only) followed by the payload. *)
  if is_client then begin
    let mask_key = Rng.generate 4 in
    Buffer.add_string buf mask_key;
    Buffer.add_string buf (xor_mask mask_key payload)
  end
  else Buffer.add_string buf payload
181
(** [read_exactly flow n] reads exactly [n] bytes from [flow] and returns
    them as a string, blocking until all of them are available.
    @raise End_of_file if the flow ends before [n] bytes were read. *)
let read_exactly flow n =
  let dst = Cstruct.create n in
  Eio.Flow.read_exact flow dst;
  Cstruct.to_string dst
194
(** [read_frame ~is_client ~max_payload_size flow] reads one frame from
    [flow] and returns it with its payload unmasked.

    Errors:
    - [Protocol_error] when the mask bit does not match the direction (a
      server, [is_client = false], requires masked frames; a client requires
      unmasked ones), when the 64-bit length is invalid or does not fit in an
      OCaml [int], or when a control frame is fragmented or longer than 125
      bytes;
    - [Payload_too_large] when the announced length exceeds
      [max_payload_size];
    - [Connection_closed] when the flow ends mid-frame;
    - [Io_error] for any other exception raised while reading. *)
let read_frame ~is_client ~max_payload_size flow =
  try
    let header = read_exactly flow 2 in
    let b0 = Char.code header.[0] in
    let b1 = Char.code header.[1] in

    let final = b0 land 0x80 <> 0 in
    let extension = (b0 land 0x70) lsr 4 in
    let opcode = Opcode.of_int (b0 land 0x0f) in
    let masked = b1 land 0x80 <> 0 in

    if (not is_client) && not masked then
      Error (Protocol_error "Client frames must be masked")
    else if is_client && masked then
      Error (Protocol_error "Server frames must not be masked")
    else begin
      (* [None] means the announced length is invalid or unrepresentable. *)
      let announced_len =
        match b1 land 0x7f with
        | 126 ->
            let ext = read_exactly flow 2 in
            Some ((Char.code ext.[0] lsl 8) lor Char.code ext.[1])
        | 127 ->
            let ext = read_exactly flow 8 in
            (* The most significant bit of the 64-bit length must be 0. *)
            if Char.code ext.[0] land 0x80 <> 0 then None
            else begin
              let acc = ref 0 in
              String.iter (fun c -> acc := (!acc lsl 8) lor Char.code c) ext;
              (* A negative value means the length overflowed the native
                 int; accepting it would bypass the size check below. *)
              if !acc < 0 then None else Some !acc
            end
        | short -> Some short
      in
      match announced_len with
      | None -> Error (Protocol_error "Invalid payload length")
      | Some len ->
          if len > max_payload_size then Error (Payload_too_large len)
          else if Opcode.is_control opcode && ((not final) || len > 125) then
            Error (Protocol_error "Invalid control frame")
          else begin
            let mask_key =
              if masked then Some (read_exactly flow 4) else None
            in
            let raw = if len > 0 then read_exactly flow len else "" in
            let content =
              match mask_key with
              | Some key -> xor_mask key raw
              | None -> raw
            in
            Ok { opcode; extension; final; content }
          end
    end
  with
  | End_of_file -> Error Connection_closed
  | exn -> Error (Io_error (Printexc.to_string exn))
245
246(** {1 Connection API} *)
247
(** [is_open t] is [true] until the connection has been marked closed. *)
let is_open { closed; _ } = not closed
250
(** [send t frame] serializes [frame] and writes it to the connection.
    Returns [Error Connection_closed] without writing anything when [t] is
    already closed, and [Error (Io_error _)] when serialization or the write
    raises. *)
let send t frame =
  if t.closed then Error Connection_closed
  else begin
    let write_out () =
      let encoded = Buffer.create 128 in
      write_frame_to_buf ~is_client:t.is_client encoded frame;
      Eio.Flow.write t.flow [ Cstruct.of_string (Buffer.contents encoded) ]
    in
    match write_out () with
    | () -> Ok ()
    | exception exn -> Error (Io_error (Printexc.to_string exn))
  end
261
(** [send_text t content] sends [content] as a single final [Text] frame. *)
let send_text t content =
  make_frame ~opcode:Opcode.Text ~content () |> send t
264
(** [send_binary t content] sends [content] as a single final [Binary]
    frame. *)
let send_binary t content =
  make_frame ~opcode:Opcode.Binary ~content () |> send t
267
(** [send_ping t ?content ()] sends a [Ping] frame; [content] defaults to the
    empty string. *)
let send_ping t ?(content = "") () =
  make_frame ~opcode:Opcode.Ping ~content () |> send t
271
(** [send_pong t ?content ()] sends a [Pong] frame; [content] defaults to the
    empty string. *)
let send_pong t ?(content = "") () =
  make_frame ~opcode:Opcode.Pong ~content () |> send t
275
(** [recv t] reads the next frame from the connection and returns it.

    Control frames are answered here before the frame is handed back:
    - a [Close] frame is answered with an empty [Close] frame and the
      connection is then marked closed;
    - a [Ping] frame is answered with a [Pong] carrying the same payload.

    Failures while sending those replies are ignored. Any read error marks
    the connection closed and is returned. Returns [Error Connection_closed]
    immediately when [t] is already closed. *)
let recv t =
  if t.closed then Error Connection_closed
  else
    match
      read_frame ~is_client:t.is_client ~max_payload_size:t.max_payload_size
        t.flow
    with
    | Ok frame ->
        (match frame.opcode with
        | Opcode.Close ->
            (* Reply BEFORE flipping the flag: [send] refuses to write on a
               closed connection, so the reverse order drops the reply. *)
            ignore (send t (close_frame (-1)) : (unit, error) result);
            t.closed <- true
        | Opcode.Ping ->
            ignore (send_pong t ~content:frame.content () : (unit, error) result)
        | Opcode.Continuation | Opcode.Text | Opcode.Binary | Opcode.Pong
        | Opcode.Ctrl _ | Opcode.Nonctrl _ ->
            ());
        Ok frame
    | Error e ->
        t.closed <- true;
        Error e
294
(** [recv_message t] receives one complete data message, reassembling
    fragmented messages, and returns its opcode ([Text] or [Binary], taken
    from the first frame) together with the concatenated payload.

    [Ping] and [Pong] frames are skipped, including when they arrive between
    the fragments of a message (they are already answered by {!recv}).
    A [Close] frame yields [Error Connection_closed]. A [Continuation] frame
    with no message in progress, a non-continuation data frame in the middle
    of a fragmented message, or an unknown opcode yields [Protocol_error]. *)
let recv_message t =
  let rec collect_fragments first_opcode buf =
    match recv t with
    | Error e -> Error e
    | Ok frame -> (
        match frame.opcode with
        | Opcode.Continuation ->
            Buffer.add_string buf frame.content;
            if frame.final then Ok (first_opcode, Buffer.contents buf)
            else collect_fragments first_opcode buf
        | Opcode.Ping | Opcode.Pong ->
            (* Control frames may be interleaved with fragments; their
               payload is not part of the message. *)
            collect_fragments first_opcode buf
        | Opcode.Close -> Error Connection_closed
        | Opcode.Text | Opcode.Binary | Opcode.Ctrl _ | Opcode.Nonctrl _ ->
            Error (Protocol_error "Expected continuation frame"))
  in
  let rec loop () =
    match recv t with
    | Error e -> Error e
    | Ok frame -> (
        match frame.opcode with
        | Opcode.Text | Opcode.Binary ->
            if frame.final then Ok (frame.opcode, frame.content)
            else begin
              let buf = Buffer.create 256 in
              Buffer.add_string buf frame.content;
              collect_fragments frame.opcode buf
            end
        | Opcode.Close -> Error Connection_closed
        | Opcode.Ping | Opcode.Pong ->
            (* Control frames handled in recv, try again *)
            loop ()
        | Opcode.Continuation ->
            Error (Protocol_error "Unexpected continuation")
        | Opcode.Ctrl _ | Opcode.Nonctrl _ ->
            Error (Protocol_error "Unexpected opcode"))
  in
  loop ()
328
(** [close ?code t] sends a [Close] frame carrying status [code] (default
    [1000]) and marks the connection closed. Does nothing when [t] is already
    closed. A failure to send the frame is ignored; the connection is marked
    closed regardless. *)
let close ?(code = 1000) t =
  if not t.closed then begin
    (* Send first: [send] refuses to write once [closed] is set, so setting
       the flag beforehand would silently drop the Close frame. *)
    ignore (send t (close_frame code) : (unit, error) result);
    t.closed <- true
  end
335
336(** {1 Handshake helpers} *)
337
(** [is_upgrade_request headers] is [true] when [headers] describe a
    WebSocket upgrade: an [Upgrade] header equal to [websocket], a
    [Connection] header whose comma-separated token list contains [upgrade],
    and a [Sec-WebSocket-Key] header. Comparisons ignore ASCII case and
    surrounding whitespace. Never raises on short or unusual header values. *)
let is_upgrade_request headers =
  let is_token expected raw =
    String.lowercase_ascii (String.trim raw) = expected
  in
  let upgrade = H1.Headers.get headers "upgrade" in
  let connection = H1.Headers.get headers "connection" in
  let key = H1.Headers.get headers "sec-websocket-key" in
  match (upgrade, connection, key) with
  | Some u, Some c, Some _ ->
      (* Connection is a token list, e.g. "keep-alive, Upgrade". *)
      is_token "websocket" u
      && List.exists (is_token "upgrade") (String.split_on_char ',' c)
  | None, _, _ | _, None, _ | _, _, None -> false
349
(** [get_websocket_key headers] is the value of the [Sec-WebSocket-Key]
    request header, if present. *)
let get_websocket_key headers =
  let header_name = "sec-websocket-key" in
  H1.Headers.get headers header_name
352
(** The only [Sec-WebSocket-Version] value accepted by
    [validate_websocket_version]. *)
let supported_websocket_version = "13"
354
(** [get_websocket_version headers] is the value of the
    [Sec-WebSocket-Version] request header, if present. *)
let get_websocket_version headers =
  let header_name = "sec-websocket-version" in
  H1.Headers.get headers header_name
357
(** [validate_websocket_version headers] checks the [Sec-WebSocket-Version]
    header. Returns [Ok ()] when it equals {!supported_websocket_version},
    and [Error msg] when it is missing or names another version. *)
let validate_websocket_version headers =
  match get_websocket_version headers with
  | None -> Error "Missing Sec-WebSocket-Version header"
  | Some requested ->
      if String.equal requested supported_websocket_version then Ok ()
      else Error ("Unsupported WebSocket version: " ^ requested)
363
(** [get_origin headers] is the value of the [Origin] request header, if
    present. *)
let get_origin headers =
  let header_name = "origin" in
  H1.Headers.get headers header_name
366
(** Origin policy for WebSocket connections. *)
type origin_policy =
  | Allow_all
      (** Accept connections from any origin (NOT RECOMMENDED for
          production). *)
  | Allow_list of string list
      (** Accept only connections whose Origin is one of the listed
          origins. *)
  | Allow_same_origin
      (** Accept only connections whose Origin matches the Host header. *)
374
(** [validate_origin ~policy headers] checks the [Origin] request header
    against [policy]. All comparisons ignore ASCII case.

    - [Allow_all]: always accepted.
    - [Allow_list origins]: the Origin header is required and must equal one
      of [origins].
    - [Allow_same_origin]: a request without Origin is accepted (likely
      same-origin or a non-browser client). Otherwise the host of the Origin
      URL, with [":port"] appended when the URL has an explicit port, must
      equal the Host header. An Origin with no parsable host is compared
      as-is.

    @param policy The origin validation policy
    @param headers The request headers
    @return [Ok ()] if origin is allowed, [Error reason] if rejected *)
let validate_origin ~policy headers =
  let equal_ci a b = String.lowercase_ascii a = String.lowercase_ascii b in
  match policy with
  | Allow_all -> Ok ()
  | Allow_list allowed -> (
      match get_origin headers with
      | None ->
          (* For security, Origin is mandatory under an allow-list. *)
          Error "Missing Origin header"
      | Some origin ->
          if List.exists (fun candidate -> equal_ci candidate origin) allowed
          then Ok ()
          else Error ("Origin not allowed: " ^ origin))
  | Allow_same_origin -> (
      match get_origin headers with
      | None -> Ok ()
      | Some origin -> (
          match H1.Headers.get headers "host" with
          | None -> Error ("Missing Host header for origin check: " ^ origin)
          | Some host ->
              (* e.g. "https://example.com" -> "example.com" *)
              let uri = Uri.of_string origin in
              let origin_host =
                match (Uri.host uri, Uri.port uri) with
                | None, _ -> origin
                | Some h, None -> h
                | Some h, Some p -> h ^ ":" ^ string_of_int p
              in
              if equal_ci origin_host host then Ok ()
              else Error ("Cross-origin request: " ^ origin ^ " vs " ^ host)))
417
(** [generate_key ()] is a fresh [Sec-WebSocket-Key] for a client handshake:
    16 bytes from {!Rng.generate}, Base64-encoded. *)
let generate_key () = Rng.generate 16 |> Base64.encode_exn
420
421(** {1 Client API} *)
422
(** [connect ~sw ~net ?tls_config ?protocols url] opens a client WebSocket
    connection to [url].

    The scheme defaults to [ws] and the host to [localhost]; the port
    defaults to 443 for [wss] and 80 otherwise. Only the first address
    returned by name resolution is tried. For [wss] the TCP flow is wrapped
    in TLS using [tls_config]. The HTTP upgrade request is then sent, with
    [protocols] (if given) advertised in [Sec-WebSocket-Protocol], and the
    response is checked for status 101 and a matching
    [Sec-WebSocket-Accept].

    The [Host] header includes the port whenever it differs from the scheme
    default. Response headers larger than 16 KiB are rejected so that a
    misbehaving server cannot make the client buffer without bound.

    @return
      [Ok t] on success; [Error (Io_error _)] on resolution, TLS or I/O
      failures during the handshake; [Error (Protocol_error _)] when the
      server's response is not a valid upgrade.
    @raise Failure if the resolved address is a Unix-domain socket.

    Exceptions raised by [Eio.Net.connect] itself are not caught. *)
let connect ~sw ~net ?(tls_config = Tls_config.Client.default) ?protocols url =
  let uri = Uri.of_string url in
  let scheme = Uri.scheme uri |> Option.value ~default:"ws" in
  let is_secure = scheme = "wss" in
  let host = Uri.host uri |> Option.value ~default:"localhost" in
  let default_port = if is_secure then 443 else 80 in
  let port = Uri.port uri |> Option.value ~default:default_port in
  let path =
    let p = Uri.path_and_query uri in
    if p = "" then "/" else p
  in
  (* The Host header must carry the port unless it is the scheme default. *)
  let host_header =
    if port = default_port then host else host ^ ":" ^ string_of_int port
  in
  (* Upper bound on the size of the HTTP response head we will buffer. *)
  let max_response_head = 16 * 1024 in

  (* Resolve and connect *)
  match Eio.Net.getaddrinfo_stream net host with
  | [] -> Error (Io_error ("Cannot resolve host: " ^ host))
  | addr_info :: _ -> (
      let addr =
        match addr_info with
        | `Tcp (ip, _) -> `Tcp (ip, port)
        | `Unix _ -> failwith "Unix sockets not supported"
      in
      let tcp_flow = Eio.Net.connect ~sw net addr in

      (* Wrap with TLS if secure *)
      let flow_result =
        if is_secure then
          match Tls_config.Client.to_tls_config tls_config ~host with
          | Error msg -> Error (Io_error ("TLS error: " ^ msg))
          | Ok tls_cfg -> (
              try
                let host_domain =
                  match Domain_name.of_string host with
                  | Ok dn -> (
                      match Domain_name.host dn with
                      | Ok h -> Some h
                      | Error _ -> None)
                  | Error _ -> None
                in
                let tls_flow =
                  Tls_eio.client_of_flow tls_cfg ?host:host_domain tcp_flow
                in
                Ok (tls_flow :> Eio.Flow.two_way_ty r)
              with exn -> Error (Io_error (Printexc.to_string exn)))
        else Ok (tcp_flow :> Eio.Flow.two_way_ty r)
      in

      match flow_result with
      | Error e -> Error e
      | Ok flow -> (
          (* Generate key and build upgrade request *)
          let key = generate_key () in
          let expected_accept = compute_accept_key key in

          let headers =
            [
              ("Host", host_header);
              ("Upgrade", "websocket");
              ("Connection", "Upgrade");
              ("Sec-WebSocket-Key", key);
              ("Sec-WebSocket-Version", "13");
            ]
          in
          let headers =
            match protocols with
            | Some ps ->
                ("Sec-WebSocket-Protocol", String.concat ", " ps) :: headers
            | None -> headers
          in

          let request = Buffer.create 256 in
          Buffer.add_string request "GET ";
          Buffer.add_string request path;
          Buffer.add_string request " HTTP/1.1\r\n";
          List.iter
            (fun (k, v) ->
              Buffer.add_string request k;
              Buffer.add_string request ": ";
              Buffer.add_string request v;
              Buffer.add_string request "\r\n")
            headers;
          Buffer.add_string request "\r\n";

          (* Reads byte by byte so that nothing past the blank line is
             consumed. Returns [false] if the size cap is hit first. *)
          let response_buf = Buffer.create 1024 in
          let rec read_response_head () =
            Buffer.add_string response_buf (read_exactly flow 1);
            let len = Buffer.length response_buf in
            if len >= 4 && Buffer.sub response_buf (len - 4) 4 = "\r\n\r\n"
            then true
            else if len >= max_response_head then false
            else read_response_head ()
          in

          (* Splits "Name: value" lines into lowercase-name/value pairs;
             blank lines and lines without a colon are dropped. *)
          let parse_header_line line =
            let line = String.trim line in
            if line = "" then None
            else
              match String.index_opt line ':' with
              | Some i ->
                  let name = String.lowercase_ascii (String.sub line 0 i) in
                  let value =
                    String.trim
                      (String.sub line (i + 1) (String.length line - i - 1))
                  in
                  Some (name, value)
              | None -> None
          in

          try
            (* Send HTTP upgrade request *)
            Eio.Flow.write flow
              [ Cstruct.of_string (Buffer.contents request) ];

            if not (read_response_head ()) then
              Error (Protocol_error "Response headers too large")
            else
              match
                String.split_on_char '\n' (Buffer.contents response_buf)
              with
              | [] -> Error (Protocol_error "Empty response")
              | status_line :: header_lines -> (
                  let status_line = String.trim status_line in
                  (* "HTTP/1.x 101 ...": the status code sits at offset 9. *)
                  if
                    not
                      (String.length status_line >= 12
                      && String.sub status_line 9 3 = "101")
                  then Error (Protocol_error ("Bad status: " ^ status_line))
                  else
                    let response_headers =
                      List.filter_map parse_header_line header_lines
                    in
                    match
                      List.assoc_opt "sec-websocket-accept" response_headers
                    with
                    | Some a when a = expected_accept ->
                        Ok
                          {
                            flow;
                            closed = false;
                            is_client = true;
                            read_buf = Buffer.create 4096;
                            max_payload_size = default_max_payload_size;
                          }
                    | Some a ->
                        Error
                          (Protocol_error
                             ("Bad accept key: " ^ a ^ " (expected "
                            ^ expected_accept ^ ")"))
                    | None -> Error (Protocol_error "Missing accept key"))
          with exn -> Error (Io_error (Printexc.to_string exn))))
589
590(** {1 Server API} *)
591
(** [accept ?max_payload_size ~flow ~key ()] completes the server side of the
    handshake: it writes the [101 Switching Protocols] response, whose
    [Sec-WebSocket-Accept] is derived from the client's [key], to [flow] and
    returns a server-side connection.

    [max_payload_size] defaults to {!default_max_payload_size}. Returns
    [Error (Io_error _)] if writing the response raises. *)
let accept ?max_payload_size ~flow ~key () =
  let max_payload_size =
    match max_payload_size with
    | Some limit -> limit
    | None -> default_max_payload_size
  in
  (* The two trailing empty items produce the terminating blank line. *)
  let response =
    String.concat "\r\n"
      [
        "HTTP/1.1 101 Switching Protocols";
        "Upgrade: websocket";
        "Connection: Upgrade";
        "Sec-WebSocket-Accept: " ^ compute_accept_key key;
        "";
        "";
      ]
  in
  match Eio.Flow.write flow [ Cstruct.of_string response ] with
  | () ->
      Ok
        {
          flow = (flow :> Eio.Flow.two_way_ty r);
          closed = false;
          is_client = false;
          read_buf = Buffer.create 4096;
          max_payload_size;
        }
  | exception exn -> Error (Io_error (Printexc.to_string exn))
615 with exn -> Error (Io_error (Printexc.to_string exn))