open Eio.Std

(** {1 Constants} *)

(** UUID used in WebSocket handshake per RFC 6455 *)
let websocket_uuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"

(** {1 Types} *)

(** WebSocket frame opcode *)
module Opcode = struct
  type t =
    | Continuation  (** 0x0 *)
    | Text  (** 0x1 *)
    | Binary  (** 0x2 *)
    | Close  (** 0x8 *)
    | Ping  (** 0x9 *)
    | Pong  (** 0xA *)
    | Ctrl of int  (** Other control opcodes *)
    | Nonctrl of int  (** Other non-control opcodes *)

  (** [to_int op] is the numeric wire value of [op]. For [Ctrl n] and
      [Nonctrl n] the carried integer is returned unchanged. *)
  let to_int = function
    | Continuation -> 0
    | Text -> 1
    | Binary -> 2
    | Close -> 8
    | Ping -> 9
    | Pong -> 10
    | Ctrl n -> n
    | Nonctrl n -> n

  (** [of_int n] maps a numeric opcode to its constructor. Unknown values
      above 7 become [Ctrl n]; every other unknown value becomes
      [Nonctrl n]. *)
  let of_int = function
    | 0 -> Continuation
    | 1 -> Text
    | 2 -> Binary
    | 8 -> Close
    | 9 -> Ping
    | 10 -> Pong
    | n when n > 7 -> Ctrl n
    | n -> Nonctrl n

  (** [int_to_string n] is the decimal representation of [n]. *)
  let int_to_string n = string_of_int n

  (** [to_string op] is a lowercase human-readable name for [op]. *)
  let to_string = function
    | Continuation -> "continuation"
    | Text -> "text"
    | Binary -> "binary"
    | Close -> "close"
    | Ping -> "ping"
    | Pong -> "pong"
    | Ctrl n -> "ctrl(" ^ int_to_string n ^ ")"
    | Nonctrl n -> "nonctrl(" ^ int_to_string n ^ ")"

  (** [is_control op] is [true] for [Close], [Ping], [Pong] and [Ctrl _]. *)
  let is_control = function
    | Close | Ping | Pong | Ctrl _ -> true
    | Continuation | Text | Binary | Nonctrl _ -> false
end

type frame = {
  opcode : Opcode.t;
  extension : int;
  final : bool;
  content : string;
}
(** WebSocket frame. [extension] holds the three RSV bits, [final] the FIN
    bit and [content] the unmasked payload. *)

(** Print a short summary of a frame (opcode, FIN flag, payload length). *)
let pp_frame fmt frame =
  Format.fprintf fmt "{opcode=%s; final=%b; len=%d}"
    (Opcode.to_string frame.opcode)
    frame.final
    (String.length frame.content)

(** Build a frame. Defaults: text opcode, no extension bits, final, empty
    payload. *)
let make_frame ?(opcode = Opcode.Text) ?(extension = 0) ?(final = true)
    ?(content = "") () =
  { opcode; extension; final; content }

(** [close_frame code] is a Close frame whose payload is [code] encoded as a
    big-endian 16-bit integer. A negative [code] yields an empty payload. *)
let close_frame code =
  let content =
    if code < 0 then ""
    else begin
      let buf = Bytes.create 2 in
      Bytes.set buf 0 (Char.chr ((code lsr 8) land 0xff));
      Bytes.set buf 1 (Char.chr (code land 0xff));
      Bytes.to_string buf
    end
  in
  { opcode = Opcode.Close; extension = 0; final = true; content }

(** Default upper bound on a single frame payload (16 MiB). *)
let default_max_payload_size = 16 * 1024 * 1024

(** Connection state. [closed] is set once a close has been sent or
    received, or after a read error. [is_client] selects the masking rules. *)
type t = {
  flow : Eio.Flow.two_way_ty r;
  mutable closed : bool;
  is_client : bool;
  read_buf : Buffer.t;
  max_payload_size : int;
}

(** Errors reported by the connection API. *)
type error =
  | Connection_closed
  | Protocol_error of string
  | Io_error of string
  | Payload_too_large of int

(** {1 Cryptographic helpers} *)

(** Compute SHA-1 hash and base64 encode *)
let b64_encoded_sha1sum s =
  let hash = Digestif.SHA1.digest_string s in
  Base64.encode_exn (Digestif.SHA1.to_raw_string hash)

(** Compute the Sec-WebSocket-Accept value *)
let compute_accept_key key = b64_encoded_sha1sum (key ^ websocket_uuid)

module Rng = struct
  (** [generate n] delegates to [Secure_random.generate n]. *)
  let generate n = Secure_random.generate n
end

(** {1 Frame parsing/serialization} *)

(** Apply XOR mask to data. [mask] is indexed modulo 4, so it must hold at
    least 4 bytes whenever [data] is non-empty. *)
let xor_mask mask data =
  String.mapi
    (fun i c -> Char.chr (Char.code c lxor Char.code mask.[i land 3]))
    data

(** Serialize a frame to bytes.
    Client frames must be masked, server frames must not be masked. *)
let write_frame_to_buf ~is_client buf frame =
  let mask = is_client in
  let opcode = Opcode.to_int frame.opcode in
  let fin = if frame.final then 0x80 else 0 in
  let rsv = (frame.extension land 0x7) lsl 4 in
  Buffer.add_char buf (Char.chr (fin lor rsv lor opcode));
  let len = String.length frame.content in
  let mask_bit = if mask then 0x80 else 0 in
  (* Encode payload length: 7-bit, 16-bit or 64-bit form *)
  if len < 126 then Buffer.add_char buf (Char.chr (mask_bit lor len))
  else if len < 65536 then begin
    Buffer.add_char buf (Char.chr (mask_bit lor 126));
    Buffer.add_char buf (Char.chr ((len lsr 8) land 0xff));
    Buffer.add_char buf (Char.chr (len land 0xff))
  end
  else begin
    Buffer.add_char buf (Char.chr (mask_bit lor 127));
    (* 64-bit length, big-endian *)
    for i = 7 downto 0 do
      Buffer.add_char buf (Char.chr ((len lsr (i * 8)) land 0xff))
    done
  end;
  (* Add mask and payload *)
  if mask then begin
    let mask_key = Rng.generate 4 in
    Buffer.add_string buf mask_key;
    Buffer.add_string buf (xor_mask mask_key frame.content)
  end
  else Buffer.add_string buf frame.content

(** Read exactly n bytes from flow. Exceptions raised by
    [Eio.Flow.single_read] propagate to the caller. *)
let read_exactly flow n =
  let buf = Cstruct.create n in
  let rec fill off =
    if off < n then begin
      let got = Eio.Flow.single_read flow (Cstruct.sub buf off (n - off)) in
      fill (off + got)
    end
  in
  fill 0;
  Cstruct.to_string buf

(** Read and decode one frame from [flow].

    Enforces the masking rule for the given role, rejects payloads larger
    than [max_payload_size], rejects malformed 64-bit lengths, and rejects
    control frames that are fragmented or longer than 125 bytes.

    @return [Error Connection_closed] on [End_of_file], [Error (Io_error _)]
    on any other exception. *)
let read_frame ~is_client ~max_payload_size flow =
  (* Big-endian decoding of a short byte string *)
  let be_int s =
    String.fold_left (fun acc c -> (acc lsl 8) lor Char.code c) 0 s
  in
  try
    let header = read_exactly flow 2 in
    let b0 = Char.code header.[0] in
    let b1 = Char.code header.[1] in
    let final = b0 land 0x80 <> 0 in
    let extension = (b0 land 0x70) lsr 4 in
    let opcode = Opcode.of_int (b0 land 0x0f) in
    let masked = b1 land 0x80 <> 0 in
    let len0 = b1 land 0x7f in
    if (not is_client) && not masked then
      Error (Protocol_error "Client frames must be masked")
    else if is_client && masked then
      Error (Protocol_error "Server frames must not be masked")
    else begin
      (* A negative result marks an invalid 64-bit length: either the most
         significant bit is set (forbidden by RFC 6455) or the value does
         not fit in a native int and wrapped around. *)
      let len =
        if len0 < 126 then len0
        else if len0 = 126 then be_int (read_exactly flow 2)
        else begin
          let ext = read_exactly flow 8 in
          if Char.code ext.[0] land 0x80 <> 0 then -1 else be_int ext
        end
      in
      if len < 0 then Error (Protocol_error "Invalid payload length")
      else if len > max_payload_size then Error (Payload_too_large len)
      else if Opcode.is_control opcode && ((not final) || len > 125) then
        Error (Protocol_error "Invalid control frame")
      else begin
        let mask_key = if masked then Some (read_exactly flow 4) else None in
        let raw = if len > 0 then read_exactly flow len else "" in
        let content =
          match mask_key with
          | Some key -> xor_mask key raw
          | None -> raw
        in
        Ok { opcode; extension; final; content }
      end
    end
  with
  | End_of_file -> Error Connection_closed
  | exn -> Error (Io_error (Printexc.to_string exn))

(** {1 Connection API} *)

(** Check if connection is open *)
let is_open t = not t.closed

(** Send a frame. Returns [Error Connection_closed] without writing when the
    connection is already marked closed, and [Error (Io_error _)] when the
    write raises. *)
let send t frame =
  if t.closed then Error Connection_closed
  else
    try
      let buf = Buffer.create 128 in
      write_frame_to_buf ~is_client:t.is_client buf frame;
      Eio.Flow.write t.flow [ Cstruct.of_string (Buffer.contents buf) ];
      Ok ()
    with exn -> Error (Io_error (Printexc.to_string exn))

(** Send a text message *)
let send_text t content = send t (make_frame ~opcode:Opcode.Text ~content ())

(** Send a binary message *)
let send_binary t content =
  send t (make_frame ~opcode:Opcode.Binary ~content ())

(** Send a ping *)
let send_ping t ?(content = "") () =
  send t (make_frame ~opcode:Opcode.Ping ~content ())

(** Send a pong *)
let send_pong t ?(content = "") () =
  send t (make_frame ~opcode:Opcode.Pong ~content ())

(** Receive a single frame.

    A Close frame is answered with an empty Close frame and marks the
    connection closed; a Ping is answered with a Pong carrying the same
    payload. In both cases the received frame is still returned. Any read
    error marks the connection closed. *)
let recv t =
  if t.closed then Error Connection_closed
  else
    match
      read_frame ~is_client:t.is_client ~max_payload_size:t.max_payload_size
        t.flow
    with
    | Ok frame ->
        (match frame.opcode with
        | Opcode.Close ->
            (* Reply first: [send] refuses to write once [closed] is set. *)
            ignore (send t (close_frame (-1)));
            t.closed <- true
        | Opcode.Ping -> ignore (send_pong t ~content:frame.content ())
        | Opcode.Continuation | Opcode.Text | Opcode.Binary | Opcode.Pong
        | Opcode.Ctrl _ | Opcode.Nonctrl _ ->
            ());
        Ok frame
    | Error e ->
        t.closed <- true;
        Error e

(** Receive a complete message (handles fragmentation).

    Returns the opcode of the first frame ([Text] or [Binary]) together with
    the concatenated payload. Ping and Pong frames are skipped, including
    when they arrive between the fragments of a message. A Close frame
    yields [Error Connection_closed]. *)
let recv_message t =
  let rec collect_fragments first_opcode buf =
    match recv t with
    | Error e -> Error e
    | Ok frame -> (
        match frame.opcode with
        | Opcode.Continuation ->
            Buffer.add_string buf frame.content;
            if frame.final then Ok (first_opcode, Buffer.contents buf)
            else collect_fragments first_opcode buf
        | Opcode.Ping | Opcode.Pong ->
            (* Control frames may be interleaved with fragments; they were
               already handled in recv and are not part of the message. *)
            collect_fragments first_opcode buf
        | Opcode.Close -> Error Connection_closed
        | Opcode.Text | Opcode.Binary | Opcode.Ctrl _ | Opcode.Nonctrl _ ->
            Error (Protocol_error "Expected continuation frame"))
  in
  let rec loop () =
    match recv t with
    | Error e -> Error e
    | Ok frame -> (
        match frame.opcode with
        | Opcode.Text | Opcode.Binary ->
            if frame.final then Ok (frame.opcode, frame.content)
            else begin
              let buf = Buffer.create 256 in
              Buffer.add_string buf frame.content;
              collect_fragments frame.opcode buf
            end
        | Opcode.Close -> Error Connection_closed
        | Opcode.Ping | Opcode.Pong ->
            (* Control frames handled in recv, try again *)
            loop ()
        | Opcode.Continuation ->
            Error (Protocol_error "Unexpected continuation")
        | Opcode.Ctrl _ | Opcode.Nonctrl _ ->
            Error (Protocol_error "Unexpected opcode"))
  in
  loop ()

(** Close the connection: send a Close frame carrying [code] (default 1000)
    and mark the connection closed. Does nothing when already closed; a
    failure to send is ignored. *)
let close ?(code = 1000) t =
  if not t.closed then begin
    (* Send before flipping the flag: [send] refuses to write once [closed]
       is set. *)
    ignore (send t (close_frame code));
    t.closed <- true
  end

(** {1 Handshake helpers} *)

(** Check if request headers indicate a WebSocket upgrade: an [Upgrade]
    header equal to [websocket] (case-insensitive), a [Connection] header
    whose comma-separated tokens include [upgrade] (case-insensitive), and a
    [Sec-WebSocket-Key] header. *)
let is_upgrade_request headers =
  let upgrade = H1.Headers.get headers "upgrade" in
  let connection = H1.Headers.get headers "connection" in
  let key = H1.Headers.get headers "sec-websocket-key" in
  match (upgrade, connection, key) with
  | Some u, Some c, Some _ ->
      let is_upgrade_token tok =
        String.lowercase_ascii (String.trim tok) = "upgrade"
      in
      String.lowercase_ascii u = "websocket"
      && List.exists is_upgrade_token (String.split_on_char ',' c)
  | _ -> false

(** Get the Sec-WebSocket-Key from request headers *)
let get_websocket_key headers = H1.Headers.get headers "sec-websocket-key"

(** The only protocol version this module accepts. *)
let supported_websocket_version = "13"

(** Get the Sec-WebSocket-Version from request headers *)
let get_websocket_version headers =
  H1.Headers.get headers "sec-websocket-version"

(** Check that the request asks for {!supported_websocket_version}.
    @return [Ok ()] on a match, [Error reason] when the header is missing or
    names another version. *)
let validate_websocket_version headers =
  match get_websocket_version headers with
  | Some v when v = supported_websocket_version -> Ok ()
  | Some v -> Error ("Unsupported WebSocket version: " ^ v)
  | None -> Error "Missing Sec-WebSocket-Version header"

(** Get the Origin header from request headers *)
let get_origin headers = H1.Headers.get headers "origin"

(** Origin policy for WebSocket connections.
    - [Allow_all] accepts connections from any origin (NOT RECOMMENDED for
      production)
    - [Allow_list origins] only accepts connections from the specified
      origins
    - [Allow_same_origin] only accepts connections where Origin matches the
      Host header *)
type origin_policy = Allow_all | Allow_list of string list | Allow_same_origin

(** Validate Origin header against policy.
    @param policy The origin validation policy
    @param headers The request headers (Origin and Host are consulted)
    @return [Ok ()] if origin is allowed, [Error reason] if rejected *)
let validate_origin ~policy headers =
  match policy with
  | Allow_all -> Ok ()
  | Allow_list allowed -> (
      match get_origin headers with
      | None ->
          (* Missing Origin header - could be same-origin request or
             non-browser client. For security, we require Origin for
             Allow_list policy. *)
          Error "Missing Origin header"
      | Some origin ->
          let origin_lower = String.lowercase_ascii origin in
          let matches candidate =
            String.lowercase_ascii candidate = origin_lower
          in
          if List.exists matches allowed then Ok ()
          else Error ("Origin not allowed: " ^ origin))
  | Allow_same_origin -> (
      match (get_origin headers, H1.Headers.get headers "host") with
      | None, _ ->
          (* No Origin header - likely same-origin or non-browser, allow it *)
          Ok ()
      | Some origin, Some host ->
          (* Extract host from origin URL
             (e.g., "https://example.com" -> "example.com") *)
          let origin_host =
            let uri = Uri.of_string origin in
            match Uri.host uri with
            | Some h -> (
                match Uri.port uri with
                | Some p -> h ^ ":" ^ string_of_int p
                | None -> h)
            | None -> origin
          in
          if String.lowercase_ascii origin_host = String.lowercase_ascii host
          then Ok ()
          else Error ("Cross-origin request: " ^ origin ^ " vs " ^ host)
      | Some origin, None ->
          Error ("Missing Host header for origin check: " ^ origin))

(** Generate random base64-encoded key for client handshake *)
let generate_key () = Base64.encode_exn (Rng.generate 16)

(** {1 Client API} *)

(** Connect to a WebSocket server.

    [url] is parsed with [Uri]; the scheme defaults to [ws], the host to
    [localhost] and the port to 443 for [wss] or 80 otherwise. The first
    address returned by name resolution is used. For [wss] the TCP flow is
    wrapped in TLS using [tls_config]. [protocols], when given, is sent as
    [Sec-WebSocket-Protocol].

    @return the connection on a [101] response with a matching
    [Sec-WebSocket-Accept], otherwise an [error].
    @raise Failure if name resolution yields a Unix socket address.
    Exceptions from [Eio.Net.getaddrinfo_stream] and [Eio.Net.connect] are
    not caught. *)
let connect ~sw ~net ?(tls_config = Tls_config.Client.default) ?protocols url =
  let uri = Uri.of_string url in
  let scheme = Uri.scheme uri |> Option.value ~default:"ws" in
  let is_secure = scheme = "wss" in
  let host = Uri.host uri |> Option.value ~default:"localhost" in
  let default_port = if is_secure then 443 else 80 in
  let port = Uri.port uri |> Option.value ~default:default_port in
  let path =
    let p = Uri.path_and_query uri in
    if p = "" then "/" else p
  in
  (* Resolve and connect *)
  let addrs = Eio.Net.getaddrinfo_stream net host in
  match addrs with
  | [] -> Error (Io_error ("Cannot resolve host: " ^ host))
  | addr_info :: _ -> (
      let addr =
        match addr_info with
        | `Tcp (ip, _) -> `Tcp (ip, port)
        | `Unix _ -> failwith "Unix sockets not supported"
      in
      let tcp_flow = Eio.Net.connect ~sw net addr in
      (* Wrap with TLS if secure *)
      let flow_result =
        if is_secure then begin
          match Tls_config.Client.to_tls_config tls_config ~host with
          | Error msg -> Error (Io_error ("TLS error: " ^ msg))
          | Ok tls_cfg -> (
              try
                let host_domain =
                  match Domain_name.of_string host with
                  | Ok dn -> (
                      match Domain_name.host dn with
                      | Ok h -> Some h
                      | Error _ -> None)
                  | Error _ -> None
                in
                let tls_flow =
                  Tls_eio.client_of_flow tls_cfg ?host:host_domain tcp_flow
                in
                Ok (tls_flow :> Eio.Flow.two_way_ty r)
              with exn -> Error (Io_error (Printexc.to_string exn)))
        end
        else Ok (tcp_flow :> Eio.Flow.two_way_ty r)
      in
      match flow_result with
      | Error e -> Error e
      | Ok flow -> (
          (* Generate key and build upgrade request *)
          let key = generate_key () in
          let expected_accept = compute_accept_key key in
          (* The Host header carries the port when it is not the default
             one for the scheme (RFC 6455 section 4.1). *)
          let host_header =
            if port = default_port then host
            else host ^ ":" ^ string_of_int port
          in
          let headers =
            [
              ("Host", host_header);
              ("Upgrade", "websocket");
              ("Connection", "Upgrade");
              ("Sec-WebSocket-Key", key);
              ("Sec-WebSocket-Version", "13");
            ]
          in
          let headers =
            match protocols with
            | Some ps ->
                ("Sec-WebSocket-Protocol", String.concat ", " ps) :: headers
            | None -> headers
          in
          (* Send HTTP upgrade request *)
          let buf = Buffer.create 256 in
          Buffer.add_string buf "GET ";
          Buffer.add_string buf path;
          Buffer.add_string buf " HTTP/1.1\r\n";
          List.iter
            (fun (k, v) ->
              Buffer.add_string buf k;
              Buffer.add_string buf ": ";
              Buffer.add_string buf v;
              Buffer.add_string buf "\r\n")
            headers;
          Buffer.add_string buf "\r\n";
          try
            Eio.Flow.write flow [ Cstruct.of_string (Buffer.contents buf) ];
            (* Read response headers one byte at a time so that nothing
               past the blank line is consumed. *)
            let response_buf = Buffer.create 1024 in
            let rec read_until_crlf_crlf () =
              let byte = read_exactly flow 1 in
              Buffer.add_string response_buf byte;
              let len = Buffer.length response_buf in
              if
                len >= 4
                && Buffer.nth response_buf (len - 4) = '\r'
                && Buffer.nth response_buf (len - 3) = '\n'
                && Buffer.nth response_buf (len - 2) = '\r'
                && Buffer.nth response_buf (len - 1) = '\n'
              then ()
              else read_until_crlf_crlf ()
            in
            read_until_crlf_crlf ();
            (* Parse status line and headers *)
            let response_str = Buffer.contents response_buf in
            let lines = String.split_on_char '\n' response_str in
            (* Check status line *)
            match lines with
            | status_line :: header_lines -> (
                let status_line = String.trim status_line in
                if
                  not
                    (String.length status_line >= 12
                    && String.sub status_line 9 3 = "101")
                then Error (Protocol_error ("Bad status: " ^ status_line))
                else
                  (* Parse headers into (lowercased name, trimmed value) *)
                  let parse_header line =
                    let line = String.trim line in
                    if line = "" then None
                    else
                      match String.index_opt line ':' with
                      | Some i ->
                          let name =
                            String.lowercase_ascii (String.sub line 0 i)
                          in
                          let value =
                            String.trim
                              (String.sub line (i + 1)
                                 (String.length line - i - 1))
                          in
                          Some (name, value)
                      | None -> None
                  in
                  let headers = List.filter_map parse_header header_lines in
                  let accept = List.assoc_opt "sec-websocket-accept" headers in
                  match accept with
                  | Some a when a = expected_accept ->
                      Ok
                        {
                          flow;
                          closed = false;
                          is_client = true;
                          read_buf = Buffer.create 4096;
                          max_payload_size = default_max_payload_size;
                        }
                  | Some a ->
                      let msg =
                        String.concat ""
                          [
                            "Bad accept key: ";
                            a;
                            " (expected ";
                            expected_accept;
                            ")";
                          ]
                      in
                      Error (Protocol_error msg)
                  | None -> Error (Protocol_error "Missing accept key"))
            | [] -> Error (Protocol_error "Empty response")
          with exn -> Error (Io_error (Printexc.to_string exn))))

(** {1 Server API} *)

(** Complete the server side of the handshake: write the
    [101 Switching Protocols] response for the client key [key] on [flow]
    and return a server-side connection.
    @param max_payload_size per-frame payload limit; defaults to
    {!default_max_payload_size}
    @return [Error (Io_error _)] if writing the response raises. *)
let accept ?max_payload_size ~flow ~key () =
  let max_payload_size =
    Option.value max_payload_size ~default:default_max_payload_size
  in
  let response =
    String.concat ""
      [
        "HTTP/1.1 101 Switching Protocols\r\n";
        "Upgrade: websocket\r\n";
        "Connection: Upgrade\r\n";
        "Sec-WebSocket-Accept: ";
        compute_accept_key key;
        "\r\n\r\n";
      ]
  in
  try
    Eio.Flow.write flow [ Cstruct.of_string response ];
    Ok
      {
        flow = (flow :> Eio.Flow.two_way_ty r);
        closed = false;
        is_client = false;
        read_buf = Buffer.create 4096;
        max_payload_size;
      }
  with exn -> Error (Io_error (Printexc.to_string exn))