SWIM protocol implementation in OCaml, interoperable with the membership library and the Serf CLI

perf: optimize memory allocation with buffer pools and Cstruct

- Add TCP buffer pools to Protocol.t (64 KiB receive and 128 KiB decompress buffers, 4 of each)
- Update LZW to work directly with Cstruct (zero-copy decompression)
- Add decompress_to_buffer and decompress_cstruct functions to LZW
- Add Cstruct-based codec functions (decode_compress_from_cstruct,
  decode_push_pull_msg_cstruct)
- Update handle_tcp_connection to use buffer pools instead of
  per-connection allocation
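- Eliminate Cstruct.to_string conversions in handle_tcp_connection (the new Cstruct codec functions still convert to a string internally for Msgpck parsing)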

Closes: swim-hrd

Changed files
+263 -158
lib
+58
lib/codec.ml
··· 782 782 else "" 783 783 in 784 784 Ok (header, nodes, user_state)) 785 + 786 + let decode_compress_from_cstruct (buf : Cstruct.t) : 787 + (int * Cstruct.t, Types.decode_error) result = 788 + let data = Cstruct.to_string buf in 789 + let _, msgpack = Msgpck.String.read data in 790 + match msgpack with 791 + | Msgpck.Map fields -> ( 792 + let algo = 793 + match List.assoc_opt (Msgpck.String "Algo") fields with 794 + | Some (Msgpck.Int i) -> i 795 + | Some (Msgpck.Int32 i) -> Int32.to_int i 796 + | _ -> -1 797 + in 798 + let compressed_buf = 799 + match List.assoc_opt (Msgpck.String "Buf") fields with 800 + | Some (Msgpck.Bytes s) -> Some (Cstruct.of_string s) 801 + | Some (Msgpck.String s) -> Some (Cstruct.of_string s) 802 + | _ -> None 803 + in 804 + match compressed_buf with 805 + | Some cs -> Ok (algo, cs) 806 + | None -> Error (Types.Msgpack_error "missing Buf field")) 807 + | _ -> Error (Types.Msgpack_error "expected map for compress") 808 + 809 + let decode_push_pull_msg_cstruct (buf : Cstruct.t) : 810 + ( push_pull_header * push_node_state list * Cstruct.t, 811 + Types.decode_error ) 812 + result = 813 + if Cstruct.length buf < 1 then Error Types.Truncated_message 814 + else 815 + let data = Cstruct.to_string buf in 816 + let header_size, header_msgpack = Msgpck.String.read data in 817 + match decode_push_pull_header header_msgpack with 818 + | Error e -> Error (Types.Msgpack_error e) 819 + | Ok header -> ( 820 + let rec read_nodes offset remaining acc = 821 + if remaining <= 0 then Ok (List.rev acc, offset) 822 + else if offset >= String.length data then 823 + Error Types.Truncated_message 824 + else 825 + let rest = String.sub data offset (String.length data - offset) in 826 + let node_size, node_msgpack = Msgpck.String.read rest in 827 + match decode_push_node_state node_msgpack with 828 + | Error e -> Error (Types.Msgpack_error e) 829 + | Ok node -> 830 + read_nodes (offset + node_size) (remaining - 1) (node :: acc) 831 + in 832 + match read_nodes 
header_size header.pp_nodes [] with 833 + | Error e -> Error e 834 + | Ok (nodes, offset) -> 835 + let user_state = 836 + if header.pp_user_state_len > 0 && offset < Cstruct.length buf 837 + then 838 + Cstruct.sub buf offset 839 + (min header.pp_user_state_len (Cstruct.length buf - offset)) 840 + else Cstruct.empty 841 + in 842 + Ok (header, nodes, user_state))
+96 -67
lib/lzw.ml
··· 13 13 let max_dict_size = 1 lsl max_code_bits 14 14 15 15 type bit_reader = { 16 - data : string; 16 + data : Cstruct.t; 17 17 mutable pos : int; 18 - mutable bit_pos : int; 19 18 mutable bits_buf : int; 20 19 mutable bits_count : int; 21 20 } 22 21 23 - let make_bit_reader data = 24 - { data; pos = 0; bit_pos = 0; bits_buf = 0; bits_count = 0 } 22 + let make_bit_reader data = { data; pos = 0; bits_buf = 0; bits_count = 0 } 25 23 26 24 let read_bits_lsb reader n = 27 25 while reader.bits_count < n do 28 - if reader.pos >= String.length reader.data then raise (Failure "eof") 26 + if reader.pos >= Cstruct.length reader.data then raise Exit 29 27 else begin 30 - let byte = Char.code reader.data.[reader.pos] in 28 + let byte = Cstruct.get_uint8 reader.data reader.pos in 31 29 reader.bits_buf <- reader.bits_buf lor (byte lsl reader.bits_count); 32 30 reader.bits_count <- reader.bits_count + 8; 33 31 reader.pos <- reader.pos + 1 ··· 38 36 reader.bits_count <- reader.bits_count - n; 39 37 result 40 38 41 - let decompress ?(order = LSB) ?(lit_width = 8) data = 42 - if order <> LSB then Error (Invalid_code 0) 43 - else if lit_width <> 8 then Error (Invalid_code 0) 44 - else 45 - try 46 - let reader = make_bit_reader data in 47 - let output = Buffer.create (String.length data * 2) in 39 + let decompress_to_buffer ~src ~dst = 40 + try 41 + let reader = make_bit_reader src in 42 + let out_pos = ref 0 in 43 + let dst_len = Cstruct.length dst in 48 44 49 - let dict = Array.make max_dict_size "" in 50 - for i = 0 to 255 do 51 - dict.(i) <- String.make 1 (Char.chr i) 52 - done; 53 - dict.(clear_code) <- ""; 54 - dict.(eof_code) <- ""; 45 + let dict = Array.make max_dict_size (Cstruct.empty, 0) in 46 + for i = 0 to 255 do 47 + dict.(i) <- (Cstruct.of_string (String.make 1 (Char.chr i)), 1) 48 + done; 49 + dict.(clear_code) <- (Cstruct.empty, 0); 50 + dict.(eof_code) <- (Cstruct.empty, 0); 55 51 56 - let dict_size = ref initial_dict_size in 57 - let code_bits = ref 9 in 58 - let 
prev_string = ref "" in 52 + let dict_size = ref initial_dict_size in 53 + let code_bits = ref 9 in 54 + let prev_code = ref (-1) in 59 55 60 - let add_to_dict s = 61 - if !dict_size < max_dict_size then begin 62 - dict.(!dict_size) <- s; 63 - incr dict_size; 64 - if !dict_size >= 1 lsl !code_bits && !code_bits < max_code_bits then 65 - incr code_bits 66 - end 67 - in 56 + let write_entry (entry, len) = 57 + if !out_pos + len > dst_len then raise (Failure "overflow"); 58 + Cstruct.blit entry 0 dst !out_pos len; 59 + out_pos := !out_pos + len 60 + in 61 + 62 + let add_to_dict first_byte = 63 + if !dict_size < max_dict_size && !prev_code >= 0 then begin 64 + let prev_entry, prev_len = dict.(!prev_code) in 65 + let new_entry = Cstruct.create (prev_len + 1) in 66 + Cstruct.blit prev_entry 0 new_entry 0 prev_len; 67 + Cstruct.set_uint8 new_entry prev_len first_byte; 68 + dict.(!dict_size) <- (new_entry, prev_len + 1); 69 + incr dict_size; 70 + if !dict_size >= 1 lsl !code_bits && !code_bits < max_code_bits then 71 + incr code_bits 72 + end 73 + in 74 + 75 + let reset_dict () = 76 + dict_size := initial_dict_size; 77 + code_bits := 9; 78 + prev_code := -1 79 + in 80 + 81 + let rec decode_loop () = 82 + let code = read_bits_lsb reader !code_bits in 83 + if code = eof_code then () 84 + else if code = clear_code then begin 85 + reset_dict (); 86 + decode_loop () 87 + end 88 + else begin 89 + let entry, len, first_byte = 90 + if code < !dict_size then 91 + let e, l = dict.(code) in 92 + (e, l, Cstruct.get_uint8 e 0) 93 + else if code = !dict_size && !prev_code >= 0 then ( 94 + let prev_entry, prev_len = dict.(!prev_code) in 95 + let first = Cstruct.get_uint8 prev_entry 0 in 96 + let new_entry = Cstruct.create (prev_len + 1) in 97 + Cstruct.blit prev_entry 0 new_entry 0 prev_len; 98 + Cstruct.set_uint8 new_entry prev_len first; 99 + (new_entry, prev_len + 1, first)) 100 + else raise (Failure "invalid") 101 + in 102 + write_entry (entry, len); 103 + add_to_dict first_byte; 104 
+ prev_code := code; 105 + decode_loop () 106 + end 107 + in 68 108 69 - let reset_dict () = 70 - dict_size := initial_dict_size; 71 - code_bits := 9; 72 - prev_string := "" 73 - in 109 + decode_loop (); 110 + Ok !out_pos 111 + with 112 + | Exit -> Error Unexpected_eof 113 + | Failure msg when msg = "overflow" -> Error Buffer_overflow 114 + | Failure msg when msg = "invalid" -> Error (Invalid_code 0) 115 + | _ -> Error (Invalid_code 0) 74 116 75 - let rec decode_loop () = 76 - let code = read_bits_lsb reader !code_bits in 77 - if code = eof_code then () 78 - else if code = clear_code then begin 79 - reset_dict (); 80 - decode_loop () 81 - end 82 - else begin 83 - let current_string = 84 - if code < !dict_size then dict.(code) 85 - else if code = !dict_size then 86 - !prev_string ^ String.make 1 !prev_string.[0] 87 - else 88 - raise 89 - (Failure 90 - (Printf.sprintf "invalid code %d >= %d" code !dict_size)) 91 - in 92 - Buffer.add_string output current_string; 93 - if !prev_string <> "" then 94 - add_to_dict (!prev_string ^ String.make 1 current_string.[0]); 95 - prev_string := current_string; 96 - decode_loop () 97 - end 98 - in 117 + let decompress_cstruct src = 118 + let estimated_size = max (Cstruct.length src * 4) 4096 in 119 + let dst = Cstruct.create estimated_size in 120 + match decompress_to_buffer ~src ~dst with 121 + | Ok len -> Ok (Cstruct.sub dst 0 len) 122 + | Error Buffer_overflow -> ( 123 + let larger = Cstruct.create (estimated_size * 4) in 124 + match decompress_to_buffer ~src ~dst:larger with 125 + | Ok len -> Ok (Cstruct.sub larger 0 len) 126 + | Error e -> Error e) 127 + | Error e -> Error e 99 128 100 - decode_loop (); 101 - Ok (Buffer.contents output) 102 - with 103 - | Failure msg when msg = "eof" -> Error Unexpected_eof 104 - | Failure msg 105 - when String.length msg > 12 && String.sub msg 0 12 = "invalid code" -> 106 - Error (Invalid_code 0) 107 - | _ -> Error (Invalid_code 0) 129 + let decompress ?(order = LSB) ?(lit_width = 8) data = 
130 + if order <> LSB then Error (Invalid_code 0) 131 + else if lit_width <> 8 then Error (Invalid_code 0) 132 + else 133 + let src = Cstruct.of_string data in 134 + match decompress_cstruct src with 135 + | Ok cs -> Ok (Cstruct.to_string cs) 136 + | Error e -> Error e 108 137 109 138 let decompress_lsb8 data = decompress ~order:LSB ~lit_width:8 data
+2
lib/lzw.mli
··· 2 2 type error = Invalid_code of int | Unexpected_eof | Buffer_overflow 3 3 4 4 val error_to_string : error -> string 5 + val decompress_to_buffer : src:Cstruct.t -> dst:Cstruct.t -> (int, error) result 6 + val decompress_cstruct : Cstruct.t -> (Cstruct.t, error) result 5 7 6 8 val decompress : 7 9 ?order:order -> ?lit_width:int -> string -> (string, error) result
+107 -91
lib/protocol.ml
··· 11 11 probe_index : int Kcas.Loc.t; 12 12 send_pool : Buffer_pool.t; 13 13 recv_pool : Buffer_pool.t; 14 + tcp_recv_pool : Buffer_pool.t; 15 + tcp_decompress_pool : Buffer_pool.t; 14 16 udp_sock : [ `Generic ] Eio.Net.datagram_socket_ty Eio.Resource.t; 15 17 tcp_listener : [ `Generic ] Eio.Net.listening_socket_ty Eio.Resource.t; 16 18 event_stream : node_event Eio.Stream.t; ··· 369 371 else None 370 372 | _ -> None 371 373 374 + let decompress_payload_cstruct ~src ~dst = 375 + match Codec.decode_compress_from_cstruct src with 376 + | Error _ -> None 377 + | Ok (algo, compressed) -> 378 + if algo = 0 then 379 + match Lzw.decompress_to_buffer ~src:compressed ~dst with 380 + | Ok len -> Some len 381 + | Error _ -> None 382 + else None 383 + 372 384 let handle_tcp_connection t flow = 373 - let buf = Cstruct.create 65536 in 374 - match read_exact flow buf 1 with 375 - | Error _ -> () 376 - | Ok () -> ( 377 - let msg_type_byte = Cstruct.get_uint8 buf 0 in 378 - let get_push_pull_payload () = 379 - let n = read_available flow (Cstruct.shift buf 1) in 380 - if n > 0 then Some (Cstruct.sub buf 1 n) else None 381 - in 382 - let payload_opt = 383 - if msg_type_byte = Types.Wire.message_type_to_int Types.Wire.Encrypt_msg 384 - then 385 - match get_push_pull_payload () with 386 - | Some encrypted -> ( 387 - match Crypto.decrypt ~key:t.cipher_key encrypted with 388 - | Ok decrypted -> Some decrypted 389 - | Error _ -> None) 390 - | None -> None 391 - else if 392 - msg_type_byte = Types.Wire.message_type_to_int Types.Wire.Compress_msg 393 - then 394 - match get_push_pull_payload () with 395 - | Some compressed -> ( 396 - let data = Cstruct.to_string compressed in 397 - match decompress_payload data with 398 - | Some decompressed -> 399 - if String.length decompressed > 0 then 400 - let inner_type = Char.code decompressed.[0] in 401 - if 402 - inner_type 403 - = Types.Wire.message_type_to_int Types.Wire.Push_pull_msg 404 - then 405 - Some 406 - (Cstruct.of_string 407 - 
(String.sub decompressed 1 408 - (String.length decompressed - 1))) 409 - else None 410 - else None 411 - | None -> None) 412 - | None -> None 413 - else if 414 - msg_type_byte 415 - = Types.Wire.message_type_to_int Types.Wire.Has_label_msg 416 - then 385 + Buffer_pool.with_buffer t.tcp_recv_pool (fun buf -> 386 + Buffer_pool.with_buffer t.tcp_decompress_pool (fun decomp_buf -> 417 387 match read_exact flow buf 1 with 418 - | Error _ -> None 419 - | Ok () -> 420 - let label_len = Cstruct.get_uint8 buf 0 in 421 - if label_len > 0 then 422 - match read_exact flow buf label_len with 423 - | Error _ -> None 424 - | Ok () -> ( 425 - match read_exact flow buf 1 with 426 - | Error _ -> None 427 - | Ok () -> 428 - let inner_type = Cstruct.get_uint8 buf 0 in 429 - if 430 - inner_type 431 - = Types.Wire.message_type_to_int 432 - Types.Wire.Push_pull_msg 433 - then get_push_pull_payload () 434 - else None) 435 - else None 436 - else if 437 - msg_type_byte 438 - = Types.Wire.message_type_to_int Types.Wire.Push_pull_msg 439 - then get_push_pull_payload () 440 - else None 441 - in 442 - match payload_opt with 443 - | None -> () 444 - | Some payload -> ( 445 - let data = Cstruct.to_string payload in 446 - match Codec.decode_push_pull_msg data with 447 388 | Error _ -> () 448 - | Ok (header, nodes, _user_state) -> ( 449 - merge_remote_state t nodes ~is_join:header.pp_join; 450 - let resp_header, resp_nodes = 451 - build_local_state t ~is_join:false 452 - in 453 - let response = 454 - Codec.encode_push_pull_msg ~header:resp_header ~nodes:resp_nodes 455 - ~user_state:"" 389 + | Ok () -> ( 390 + let msg_type_byte = Cstruct.get_uint8 buf 0 in 391 + let get_push_pull_payload () = 392 + let n = read_available flow (Cstruct.shift buf 1) in 393 + if n > 0 then Some (Cstruct.sub buf 1 n) else None 456 394 in 457 - let resp_buf = 458 - if t.config.encryption_enabled then 459 - let plain = Cstruct.of_string response in 460 - let encrypted = 461 - Crypto.encrypt ~key:t.cipher_key 
~random:t.secure_random 462 - plain 463 - in 464 - encrypted 465 - else Cstruct.of_string response 395 + let payload_opt = 396 + if 397 + msg_type_byte 398 + = Types.Wire.message_type_to_int Types.Wire.Encrypt_msg 399 + then 400 + match get_push_pull_payload () with 401 + | Some encrypted -> ( 402 + match Crypto.decrypt ~key:t.cipher_key encrypted with 403 + | Ok decrypted -> Some decrypted 404 + | Error _ -> None) 405 + | None -> None 406 + else if 407 + msg_type_byte 408 + = Types.Wire.message_type_to_int Types.Wire.Compress_msg 409 + then 410 + match get_push_pull_payload () with 411 + | Some compressed -> ( 412 + match 413 + decompress_payload_cstruct ~src:compressed 414 + ~dst:decomp_buf 415 + with 416 + | Some len -> 417 + if len > 0 then 418 + let inner_type = Cstruct.get_uint8 decomp_buf 0 in 419 + if 420 + inner_type 421 + = Types.Wire.message_type_to_int 422 + Types.Wire.Push_pull_msg 423 + then Some (Cstruct.sub decomp_buf 1 (len - 1)) 424 + else None 425 + else None 426 + | None -> None) 427 + | None -> None 428 + else if 429 + msg_type_byte 430 + = Types.Wire.message_type_to_int Types.Wire.Has_label_msg 431 + then 432 + match read_exact flow buf 1 with 433 + | Error _ -> None 434 + | Ok () -> 435 + let label_len = Cstruct.get_uint8 buf 0 in 436 + if label_len > 0 then 437 + match read_exact flow buf label_len with 438 + | Error _ -> None 439 + | Ok () -> ( 440 + match read_exact flow buf 1 with 441 + | Error _ -> None 442 + | Ok () -> 443 + let inner_type = Cstruct.get_uint8 buf 0 in 444 + if 445 + inner_type 446 + = Types.Wire.message_type_to_int 447 + Types.Wire.Push_pull_msg 448 + then get_push_pull_payload () 449 + else None) 450 + else None 451 + else if 452 + msg_type_byte 453 + = Types.Wire.message_type_to_int Types.Wire.Push_pull_msg 454 + then get_push_pull_payload () 455 + else None 466 456 in 467 - try Eio.Flow.write flow [ resp_buf ] with _ -> ()))) 457 + match payload_opt with 458 + | None -> () 459 + | Some payload -> ( 460 + match 
Codec.decode_push_pull_msg_cstruct payload with 461 + | Error _ -> () 462 + | Ok (header, nodes, _user_state) -> ( 463 + merge_remote_state t nodes ~is_join:header.pp_join; 464 + let resp_header, resp_nodes = 465 + build_local_state t ~is_join:false 466 + in 467 + let response = 468 + Codec.encode_push_pull_msg ~header:resp_header 469 + ~nodes:resp_nodes ~user_state:"" 470 + in 471 + let resp_buf = 472 + if t.config.encryption_enabled then 473 + let plain = Cstruct.of_string response in 474 + let encrypted = 475 + Crypto.encrypt ~key:t.cipher_key 476 + ~random:t.secure_random plain 477 + in 478 + encrypted 479 + else Cstruct.of_string response 480 + in 481 + try Eio.Flow.write flow [ resp_buf ] with _ -> ()))))) 468 482 469 483 let run_tcp_listener t = 470 484 while not (is_shutdown t) do ··· 589 603 recv_pool = 590 604 Buffer_pool.create ~size:config.udp_buffer_size 591 605 ~count:config.recv_buffer_count; 606 + tcp_recv_pool = Buffer_pool.create ~size:65536 ~count:4; 607 + tcp_decompress_pool = Buffer_pool.create ~size:131072 ~count:4; 592 608 udp_sock; 593 609 tcp_listener; 594 610 event_stream = Eio.Stream.create 100;