(* OCaml HTTP/1, HTTP/2 and WebSocket client and server library, v0.3.3. *)
open Eio.Std

(** {1 Constants} *)

(** GUID concatenated with the client's [Sec-WebSocket-Key] before hashing to
    produce [Sec-WebSocket-Accept] (RFC 6455, section 1.3). *)
let websocket_uuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

(** {1 Types} *)

(** WebSocket frame opcode *)
module Opcode = struct
  type t =
    | Continuation  (** 0x0 *)
    | Text  (** 0x1 *)
    | Binary  (** 0x2 *)
    | Close  (** 0x8 *)
    | Ping  (** 0x9 *)
    | Pong  (** 0xA *)
    | Ctrl of int  (** Other control opcodes ([of_int] maps 0xB-0xF here) *)
    | Nonctrl of int
        (** Other non-control opcodes ([of_int] maps 0x3-0x7 here) *)

  (** Numeric opcode; [Ctrl n] and [Nonctrl n] give back [n] unchanged. *)
  let to_int = function
    | Continuation -> 0
    | Text -> 1
    | Binary -> 2
    | Close -> 8
    | Ping -> 9
    | Pong -> 10
    | Ctrl n | Nonctrl n -> n

  (** Inverse of [to_int] for the named opcodes; any other value greater than 7
      becomes [Ctrl n], everything else [Nonctrl n]. *)
  let of_int = function
    | 0 -> Continuation
    | 1 -> Text
    | 2 -> Binary
    | 8 -> Close
    | 9 -> Ping
    | 10 -> Pong
    | n when n > 7 -> Ctrl n
    | n -> Nonctrl n

  (** Decimal rendering of an integer. Kept for backward compatibility; it is
      [string_of_int], so negative values now carry a sign instead of
      rendering as the empty string. *)
  let int_to_string = string_of_int

  (** Human-readable opcode name, e.g. ["text"] or ["ctrl(11)"]. *)
  let to_string = function
    | Continuation -> "continuation"
    | Text -> "text"
    | Binary -> "binary"
    | Close -> "close"
    | Ping -> "ping"
    | Pong -> "pong"
    | Ctrl n -> "ctrl(" ^ int_to_string n ^ ")"
    | Nonctrl n -> "nonctrl(" ^ int_to_string n ^ ")"

  (** [true] for [Close], [Ping], [Pong] and the reserved [Ctrl _] opcodes. *)
  let is_control = function
    | Close | Ping | Pong | Ctrl _ -> true
    | Continuation | Text | Binary | Nonctrl _ -> false
end

type frame = {
  opcode : Opcode.t;
  extension : int;  (** The three RSV bits, as a value in 0-7. *)
  final : bool;  (** FIN bit. *)
  content : string;  (** Payload, already unmasked. *)
}
(** WebSocket frame *)

(** Print a frame summary (opcode, FIN bit, payload length) without the
    payload itself. *)
let pp_frame fmt frame =
  Format.fprintf fmt "{opcode=%s; final=%b; len=%d}"
    (Opcode.to_string frame.opcode)
    frame.final
    (String.length frame.content)

(** Build a frame; the defaults describe a final, unextended, empty text
    frame. *)
let make_frame ?(opcode = Opcode.Text) ?(extension = 0) ?(final = true)
    ?(content = "") () =
  { opcode; extension; final; content }

(** [close_frame code] is a final Close frame whose payload is the low 16 bits
    of [code] in big-endian order. A negative [code] yields an empty payload
    (no status code). *)
let close_frame code =
  let content =
    if code < 0 then ""
    else
      let buf = Bytes.create 2 in
      Bytes.set buf 0 (Char.chr ((code lsr 8) land 0xff));
      Bytes.set buf 1 (Char.chr (code land 0xff));
      Bytes.to_string buf
  in
  { opcode = Close; extension = 0; final = true; content }

(** Default per-frame payload limit: 16 MiB. *)
let default_max_payload_size = 16 * 1024 * 1024

(** A WebSocket connection. *)
type t = {
  flow : Eio.Flow.two_way_ty r;  (** Underlying (possibly TLS) stream. *)
  mutable closed : bool;
      (** Once set, [send] and [recv] return [Error Connection_closed]. *)
  is_client : bool;  (** Client endpoints mask the frames they send. *)
  read_buf : Buffer.t;  (** Not read anywhere in this module. *)
  max_payload_size : int;  (** Largest payload [recv] accepts. *)
}

(** Errors reported by the connection API. *)
type error =
  | Connection_closed
      (** EOF, a received Close, or the connection was already closed. *)
  | Protocol_error of string  (** RFC 6455 or handshake violation. *)
  | Io_error of string  (** An exception from the transport, as text. *)
  | Payload_too_large of int  (** Payload length above the limit. *)

(** {1 Cryptographic helpers} *)

(** Compute SHA-1 hash and base64 encode *)
let b64_encoded_sha1sum s =
  let hash = Digestif.SHA1.digest_string s in
  Base64.encode_exn (Digestif.SHA1.to_raw_string hash)

(** Compute the Sec-WebSocket-Accept value *)
let compute_accept_key key = b64_encoded_sha1sum (key ^ websocket_uuid)

module Rng = struct
  let generate n = Secure_random.generate n
end

(** {1 Frame parsing/serialization} *)

(** XOR [data] with the 4-byte [mask], cycling through the mask bytes. The
    operation is its own inverse, so it both masks and unmasks. *)
let xor_mask mask data =
  String.mapi
    (fun i c -> Char.chr (Char.code c lxor Char.code mask.[i mod 4]))
    data

(** Serialize a frame to bytes. Client frames must be masked, server frames must
    not be masked. *)
let write_frame_to_buf ~is_client buf frame =
  let mask = is_client in
  let opcode = Opcode.to_int frame.opcode in
  let fin = if frame.final then 0x80 else 0 in
  let rsv = (frame.extension land 0x7) lsl 4 in
  Buffer.add_char buf (Char.chr (fin lor rsv lor opcode));
  let len = String.length frame.content in
  let mask_bit = if mask then 0x80 else 0 in
  (* Payload length: 7-bit inline, or 126 + 16-bit, or 127 + 64-bit
     big-endian. *)
  if len < 126 then Buffer.add_char buf (Char.chr (mask_bit lor len))
  else if len < 65536 then begin
    Buffer.add_char buf (Char.chr (mask_bit lor 126));
    Buffer.add_char buf (Char.chr ((len lsr 8) land 0xff));
    Buffer.add_char buf (Char.chr (len land 0xff))
  end
  else begin
    Buffer.add_char buf (Char.chr (mask_bit lor 127));
    for i = 7 downto 0 do
      Buffer.add_char buf (Char.chr ((len lsr (i * 8)) land 0xff))
    done
  end;
  if mask then begin
    (* A fresh random key per frame, written before the masked payload. *)
    let mask_key = Rng.generate 4 in
    Buffer.add_string buf mask_key;
    Buffer.add_string buf (xor_mask mask_key frame.content)
  end
  else Buffer.add_string buf frame.content

(** Read exactly [n] bytes from [flow].
    @raise End_of_file if the stream ends first. *)
let read_exactly flow n =
  let buf = Cstruct.create n in
  let rec loop off =
    if off < n then begin
      let cs = Cstruct.sub buf off (n - off) in
      let read = Eio.Flow.single_read flow cs in
      loop (off + read)
    end
  in
  loop 0;
  Cstruct.to_string buf

(** Decode the payload length from the 7-bit length field [len0], reading
    the 16-bit or 64-bit extended length from [flow] when [len0] is 126 or
    127. *)
let read_payload_length flow len0 =
  if len0 < 126 then Ok len0
  else if len0 = 126 then
    let ext = read_exactly flow 2 in
    Ok ((Char.code ext.[0] lsl 8) lor Char.code ext.[1])
  else
    let ext = read_exactly flow 8 in
    let len = ref 0 in
    for i = 0 to 7 do
      len := (!len lsl 8) lor Char.code ext.[i]
    done;
    (* RFC 6455 5.2 requires the most significant bit to be 0. OCaml ints
       are 63 bits wide: a set bit 63 would be silently dropped (possibly
       yielding a small bogus length) and a set bit 62 wraps negative, so
       reject both. *)
    if Char.code ext.[0] land 0x80 <> 0 || !len < 0 then
      Error (Protocol_error "Invalid 64-bit payload length")
    else Ok !len

(** Read and unmask one frame. Servers ([is_client = false]) require masked
    frames and clients require unmasked ones. Payloads above
    [max_payload_size] are refused before being read. Control frames must be
    final and at most 125 bytes. EOF maps to [Connection_closed]; other
    exceptions map to [Io_error], except Eio cancellation, which is
    re-raised. *)
let read_frame ~is_client ~max_payload_size flow =
  try
    let header = read_exactly flow 2 in
    let b0 = Char.code header.[0] in
    let b1 = Char.code header.[1] in
    let final = b0 land 0x80 <> 0 in
    let extension = (b0 land 0x70) lsr 4 in
    let opcode = Opcode.of_int (b0 land 0x0f) in
    let masked = b1 land 0x80 <> 0 in
    let len0 = b1 land 0x7f in
    if (not is_client) && not masked then
      Error (Protocol_error "Client frames must be masked")
    else if is_client && masked then
      Error (Protocol_error "Server frames must not be masked")
    else (
      match read_payload_length flow len0 with
      | Error e -> Error e
      | Ok len when len > max_payload_size -> Error (Payload_too_large len)
      | Ok len when Opcode.is_control opcode && ((not final) || len > 125) ->
          Error (Protocol_error "Invalid control frame")
      | Ok len ->
          let mask_key = if masked then Some (read_exactly flow 4) else None in
          let content = if len > 0 then read_exactly flow len else "" in
          let content =
            match mask_key with
            | Some key -> xor_mask key content
            | None -> content
          in
          Ok { opcode; extension; final; content })
  with
  | End_of_file -> Error Connection_closed
  | Eio.Cancel.Cancelled _ as exn -> raise exn
  | exn -> Error (Io_error (Printexc.to_string exn))

(** {1 Connection API} *)

(** Check if connection is open *)
let is_open t = not t.closed

(** Serialize and write [frame]. Returns [Error Connection_closed] once the
    connection is closed. Write failures become [Io_error]; Eio cancellation
    is re-raised. *)
let send t frame =
  if t.closed then Error Connection_closed
  else
    try
      (* At most 14 header bytes: 2 base + 8 extended length + 4 mask. *)
      let buf = Buffer.create (String.length frame.content + 14) in
      write_frame_to_buf ~is_client:t.is_client buf frame;
      Eio.Flow.write t.flow [ Cstruct.of_string (Buffer.contents buf) ];
      Ok ()
    with
    | Eio.Cancel.Cancelled _ as exn -> raise exn
    | exn -> Error (Io_error (Printexc.to_string exn))

(** Send a text message *)
let send_text t content = send t (make_frame ~opcode:Text ~content ())

(** Send a binary message *)
let send_binary t content = send t (make_frame ~opcode:Binary ~content ())

(** Send a ping *)
let send_ping t ?(content = "") () =
  send t (make_frame ~opcode:Ping ~content ())

(** Send a pong *)
let send_pong t ?(content = "") () =
  send t (make_frame ~opcode:Pong ~content ())

(** Receive a single frame.
    - A received Close is answered with an empty Close frame (best-effort),
      and the connection is then marked closed.
    - A received Ping is answered with a Pong carrying the same payload.
    - Any read error marks the connection closed. *)
let recv t =
  if t.closed then Error Connection_closed
  else
    match
      read_frame ~is_client:t.is_client ~max_payload_size:t.max_payload_size
        t.flow
    with
    | Ok frame ->
        (match frame.opcode with
        | Close ->
            (* Reply before setting [closed]: [send] refuses to write once it
               is set, so the reverse order never emits the reply. *)
            ignore (send t (close_frame (-1)) : (unit, error) result);
            t.closed <- true
        | Ping ->
            ignore
              (send_pong t ~content:frame.content () : (unit, error) result)
        | Continuation | Text | Binary | Pong | Ctrl _ | Nonctrl _ -> ());
        Ok frame
    | Error e ->
        t.closed <- true;
        Error e

(** Receive a complete data message, reassembling fragments. Returns the
    message's opcode ([Text] or [Binary]) and its payload. Ping and Pong
    frames, including those interleaved between fragments, are skipped
    ([recv] answers Pings). A Close yields [Error Connection_closed]. A
    reassembled payload longer than the connection's [max_payload_size]
    yields [Payload_too_large]. *)
let recv_message t =
  let rec collect_fragments first_opcode buf =
    match recv t with
    | Error e -> Error e
    | Ok frame -> (
        match frame.opcode with
        | Continuation ->
            Buffer.add_string buf frame.content;
            let total = Buffer.length buf in
            if total > t.max_payload_size then Error (Payload_too_large total)
            else if frame.final then Ok (first_opcode, Buffer.contents buf)
            else collect_fragments first_opcode buf
        | Ping | Pong ->
            (* RFC 6455 5.4: control frames may be interleaved with the
               fragments of a message; their payload is not message data. *)
            collect_fragments first_opcode buf
        | Close -> Error Connection_closed
        | Text | Binary -> Error (Protocol_error "Expected continuation frame")
        | Ctrl _ | Nonctrl _ -> Error (Protocol_error "Unexpected opcode"))
  in
  let rec loop () =
    match recv t with
    | Error e -> Error e
    | Ok frame -> (
        match frame.opcode with
        | (Text | Binary) when frame.final -> Ok (frame.opcode, frame.content)
        | Text | Binary ->
            let buf = Buffer.create 256 in
            Buffer.add_string buf frame.content;
            collect_fragments frame.opcode buf
        | Close -> Error Connection_closed
        | Ping | Pong ->
            (* Control frames handled in recv, try again *)
            loop ()
        | Continuation -> Error (Protocol_error "Unexpected continuation")
        | Ctrl _ | Nonctrl _ -> Error (Protocol_error "Unexpected opcode"))
  in
  loop ()

(** Send a Close frame carrying [code] (default 1000), best-effort, then mark
    the connection closed. Does nothing if it is already closed. After this,
    [recv] returns [Error Connection_closed]. *)
let close ?(code = 1000) t =
  if not t.closed then begin
    (* Send first: [send] refuses to write once [closed] is set. *)
    ignore (send t (close_frame code) : (unit, error) result);
    t.closed <- true
  end

(** {1 Handshake helpers} *)

(** Split a comma-separated header value into trimmed, lowercased tokens. *)
let header_tokens value =
  String.split_on_char ',' value
  |> List.map (fun token -> String.lowercase_ascii (String.trim token))

(** Check if request headers indicate a WebSocket upgrade: [Upgrade] lists
    "websocket", [Connection] lists "upgrade" (both case-insensitive,
    comma-separated) and [Sec-WebSocket-Key] is present. *)
let is_upgrade_request headers =
  match
    ( H1.Headers.get headers "upgrade",
      H1.Headers.get headers "connection",
      H1.Headers.get headers "sec-websocket-key" )
  with
  | Some u, Some c, Some _ ->
      List.exists (String.equal "websocket") (header_tokens u)
      && List.exists (String.equal "upgrade") (header_tokens c)
  | _ -> false

(** Get the Sec-WebSocket-Key from request headers *)
let get_websocket_key headers = H1.Headers.get headers "sec-websocket-key"

(** The only protocol version this module accepts and sends. *)
let supported_websocket_version = "13"

(** Get the Sec-WebSocket-Version from request headers *)
let get_websocket_version headers =
  H1.Headers.get headers "sec-websocket-version"

(** [Ok ()] when [Sec-WebSocket-Version] equals
    [supported_websocket_version]; otherwise [Error] with a reason. *)
let validate_websocket_version headers =
  match get_websocket_version headers with
  | Some v when v = supported_websocket_version -> Ok ()
  | Some v -> Error ("Unsupported WebSocket version: " ^ v)
  | None -> Error "Missing Sec-WebSocket-Version header"

(** Get the Origin header from request headers *)
let get_origin headers = H1.Headers.get headers "origin"

(** Origin policy for WebSocket connections.
    - [Allow_all] accepts connections from any origin (NOT RECOMMENDED for
      production).
    - [Allow_list origins] only accepts connections whose Origin header equals
      one of [origins], compared case-insensitively. A missing Origin is
      rejected.
    - [Allow_same_origin] only accepts connections where the Origin's host
      (plus its explicit port, if any) matches the Host header. A missing
      Origin is accepted. *)
type origin_policy = Allow_all | Allow_list of string list | Allow_same_origin

(** Validate Origin header against policy.
    @param policy The origin validation policy
    @param headers The request headers (must contain Origin, may contain Host)
    @return [Ok ()] if origin is allowed, [Error reason] if rejected *)
let validate_origin ~policy headers =
  match policy with
  | Allow_all -> Ok ()
  | Allow_list allowed -> (
      match get_origin headers with
      | None ->
          (* Missing Origin header - could be same-origin request or
             non-browser client. For security, we require Origin for
             Allow_list policy. *)
          Error "Missing Origin header"
      | Some origin ->
          let origin_lower = String.lowercase_ascii origin in
          if
            List.exists
              (fun candidate -> String.lowercase_ascii candidate = origin_lower)
              allowed
          then Ok ()
          else Error ("Origin not allowed: " ^ origin))
  | Allow_same_origin -> (
      match (get_origin headers, H1.Headers.get headers "host") with
      | None, _ ->
          (* No Origin header - likely same-origin or non-browser, allow it *)
          Ok ()
      | Some origin, Some host ->
          (* "https://example.com:8443" -> "example.com:8443"; an origin that
             does not parse to a host is compared verbatim. *)
          let origin_host =
            let uri = Uri.of_string origin in
            match Uri.host uri with
            | Some h -> (
                match Uri.port uri with
                | Some p -> h ^ ":" ^ string_of_int p
                | None -> h)
            | None -> origin
          in
          if String.lowercase_ascii origin_host = String.lowercase_ascii host
          then Ok ()
          else Error ("Cross-origin request: " ^ origin ^ " vs " ^ host)
      | Some origin, None ->
          Error ("Missing Host header for origin check: " ^ origin))

(** Generate random base64-encoded key for client handshake *)
let generate_key () = Base64.encode_exn (Rng.generate 16)

(** {1 Client API} *)

(** Upper bound on the bytes read for the server's HTTP response head during
    the client handshake, so a misbehaving peer cannot make us buffer
    without limit. *)
let max_handshake_response_size = 64 * 1024

(** Serialize a ["GET <path> HTTP/1.1"] request line followed by [headers] and
    the terminating blank line. *)
let build_upgrade_request ~path headers =
  let buf = Buffer.create 256 in
  Buffer.add_string buf "GET ";
  Buffer.add_string buf path;
  Buffer.add_string buf " HTTP/1.1\r\n";
  List.iter
    (fun (k, v) ->
      Buffer.add_string buf k;
      Buffer.add_string buf ": ";
      Buffer.add_string buf v;
      Buffer.add_string buf "\r\n")
    headers;
  Buffer.add_string buf "\r\n";
  Buffer.contents buf

(** Read the HTTP response head, up to and including the terminating
    ["\r\n\r\n"]. It reads one byte at a time so that no WebSocket frame
    bytes following the head are consumed. Returns [None] if the head reaches
    [max_handshake_response_size] bytes without terminating.
    @raise End_of_file if the stream ends first. *)
let read_response_head flow =
  let buf = Buffer.create 1024 in
  let ends_with_crlf_crlf () =
    let len = Buffer.length buf in
    len >= 4
    && Buffer.nth buf (len - 4) = '\r'
    && Buffer.nth buf (len - 3) = '\n'
    && Buffer.nth buf (len - 2) = '\r'
    && Buffer.nth buf (len - 1) = '\n'
  in
  let rec loop () =
    if Buffer.length buf >= max_handshake_response_size then None
    else begin
      Buffer.add_string buf (read_exactly flow 1);
      if ends_with_crlf_crlf () then Some (Buffer.contents buf) else loop ()
    end
  in
  loop ()

(** Parse ["Name: value"] lines into [(lowercased name, trimmed value)]
    pairs. Blank lines and lines without a colon are skipped. *)
let parse_header_lines lines =
  List.filter_map
    (fun line ->
      let line = String.trim line in
      if line = "" then None
      else
        match String.index_opt line ':' with
        | Some i ->
            let key = String.lowercase_ascii (String.sub line 0 i) in
            let value =
              String.trim (String.sub line (i + 1) (String.length line - i - 1))
            in
            Some (key, value)
        | None -> None)
    lines

(** Check the server's response head. It must have a ["101"] status and a
    [Sec-WebSocket-Accept] header equal to [expected_accept]. *)
let check_handshake_response ~expected_accept response =
  match String.split_on_char '\n' response with
  | [] -> Error (Protocol_error "Empty response")
  | status_line :: header_lines -> (
      let status_line = String.trim status_line in
      if
        not
          (String.length status_line >= 12
          && String.sub status_line 9 3 = "101")
      then Error (Protocol_error ("Bad status: " ^ status_line))
      else
        match
          List.assoc_opt "sec-websocket-accept" (parse_header_lines header_lines)
        with
        | Some a when String.equal a expected_accept -> Ok ()
        | Some a ->
            Error
              (Protocol_error
                 (Printf.sprintf "Bad accept key: %s (expected %s)" a
                    expected_accept))
        | None -> Error (Protocol_error "Missing accept key"))

(** Connect to a WebSocket server at [url] ([ws://] or [wss://]) and perform
    the opening handshake. [protocols], if given, is sent as
    [Sec-WebSocket-Protocol]. Resolution, connection, TLS and transport
    failures are returned as [Error (Io_error _)]; Eio cancellation is
    re-raised. *)
let connect ~sw ~net ?(tls_config = Tls_config.Client.default) ?protocols url =
  let uri = Uri.of_string url in
  let scheme = Uri.scheme uri |> Option.value ~default:"ws" in
  let is_secure = scheme = "wss" in
  let host = Uri.host uri |> Option.value ~default:"localhost" in
  let default_port = if is_secure then 443 else 80 in
  let port = Uri.port uri |> Option.value ~default:default_port in
  let path = match Uri.path_and_query uri with "" -> "/" | p -> p in
  (* RFC 6455 4.1: Host carries ":port" whenever the port is not the
     scheme's default. *)
  let host_header =
    if port = default_port then host else host ^ ":" ^ string_of_int port
  in
  (* Send the upgrade request on [flow] and validate the server's reply. *)
  let upgrade flow =
    let key = generate_key () in
    let expected_accept = compute_accept_key key in
    let headers =
      [
        ("Host", host_header);
        ("Upgrade", "websocket");
        ("Connection", "Upgrade");
        ("Sec-WebSocket-Key", key);
        ("Sec-WebSocket-Version", supported_websocket_version);
      ]
    in
    let headers =
      match protocols with
      | Some ps -> ("Sec-WebSocket-Protocol", String.concat ", " ps) :: headers
      | None -> headers
    in
    Eio.Flow.write flow
      [ Cstruct.of_string (build_upgrade_request ~path headers) ];
    match read_response_head flow with
    | None -> Error (Protocol_error "Handshake response too large")
    | Some response ->
        check_handshake_response ~expected_accept response
        |> Result.map (fun () ->
               {
                 flow;
                 closed = false;
                 is_client = true;
                 read_buf = Buffer.create 4096;
                 max_payload_size = default_max_payload_size;
               })
  in
  try
    match Eio.Net.getaddrinfo_stream net host with
    | [] -> Error (Io_error ("Cannot resolve host: " ^ host))
    | `Unix _ :: _ -> Error (Io_error "Unix sockets not supported")
    | `Tcp (ip, _) :: _ -> (
        let tcp_flow = Eio.Net.connect ~sw net (`Tcp (ip, port)) in
        let flow_result =
          if not is_secure then Ok (tcp_flow :> Eio.Flow.two_way_ty r)
          else
            match Tls_config.Client.to_tls_config tls_config ~host with
            | Error msg -> Error (Io_error ("TLS error: " ^ msg))
            | Ok tls_cfg ->
                (* SNI / certificate name; omitted when [host] is not a valid
                   DNS host name (e.g. an IP literal). *)
                let host_domain =
                  match Domain_name.of_string host with
                  | Ok dn -> Result.to_option (Domain_name.host dn)
                  | Error _ -> None
                in
                let tls_flow =
                  Tls_eio.client_of_flow tls_cfg ?host:host_domain tcp_flow
                in
                Ok (tls_flow :> Eio.Flow.two_way_ty r)
        in
        match flow_result with Error e -> Error e | Ok flow -> upgrade flow)
  with
  | Eio.Cancel.Cancelled _ as exn -> raise exn
  | exn -> Error (Io_error (Printexc.to_string exn))

(** {1 Server API} *)

(** Complete the server side of the handshake on [flow] by writing a
    [101 Switching Protocols] response whose [Sec-WebSocket-Accept] is
    derived from the client's [key]. The key is not validated here. Write
    failures become [Io_error]; Eio cancellation is re-raised. *)
let accept ?(max_payload_size = default_max_payload_size) ~flow ~key () =
  let accept = compute_accept_key key in
  let buf = Buffer.create (String.length accept + 128) in
  Buffer.add_string buf "HTTP/1.1 101 Switching Protocols\r\n";
  Buffer.add_string buf "Upgrade: websocket\r\n";
  Buffer.add_string buf "Connection: Upgrade\r\n";
  Buffer.add_string buf "Sec-WebSocket-Accept: ";
  Buffer.add_string buf accept;
  Buffer.add_string buf "\r\n\r\n";
  let response = Buffer.contents buf in
  try
    Eio.Flow.write flow [ Cstruct.of_string response ];
    Ok
      {
        flow = (flow :> Eio.Flow.two_way_ty r);
        closed = false;
        is_client = false;
        read_buf = Buffer.create 4096;
        max_payload_size;
      }
  with
  | Eio.Cancel.Cancelled _ as exn -> raise exn
  | exn -> Error (Io_error (Printexc.to_string exn))