+58
lib/codec.ml
+58
lib/codec.ml
···
782
782
else ""
783
783
in
784
784
Ok (header, nodes, user_state))
785
+
786
+
let decode_compress_from_cstruct (buf : Cstruct.t) :
787
+
(int * Cstruct.t, Types.decode_error) result =
788
+
let data = Cstruct.to_string buf in
789
+
let _, msgpack = Msgpck.String.read data in
790
+
match msgpack with
791
+
| Msgpck.Map fields -> (
792
+
let algo =
793
+
match List.assoc_opt (Msgpck.String "Algo") fields with
794
+
| Some (Msgpck.Int i) -> i
795
+
| Some (Msgpck.Int32 i) -> Int32.to_int i
796
+
| _ -> -1
797
+
in
798
+
let compressed_buf =
799
+
match List.assoc_opt (Msgpck.String "Buf") fields with
800
+
| Some (Msgpck.Bytes s) -> Some (Cstruct.of_string s)
801
+
| Some (Msgpck.String s) -> Some (Cstruct.of_string s)
802
+
| _ -> None
803
+
in
804
+
match compressed_buf with
805
+
| Some cs -> Ok (algo, cs)
806
+
| None -> Error (Types.Msgpack_error "missing Buf field"))
807
+
| _ -> Error (Types.Msgpack_error "expected map for compress")
808
+
809
+
let decode_push_pull_msg_cstruct (buf : Cstruct.t) :
810
+
( push_pull_header * push_node_state list * Cstruct.t,
811
+
Types.decode_error )
812
+
result =
813
+
if Cstruct.length buf < 1 then Error Types.Truncated_message
814
+
else
815
+
let data = Cstruct.to_string buf in
816
+
let header_size, header_msgpack = Msgpck.String.read data in
817
+
match decode_push_pull_header header_msgpack with
818
+
| Error e -> Error (Types.Msgpack_error e)
819
+
| Ok header -> (
820
+
let rec read_nodes offset remaining acc =
821
+
if remaining <= 0 then Ok (List.rev acc, offset)
822
+
else if offset >= String.length data then
823
+
Error Types.Truncated_message
824
+
else
825
+
let rest = String.sub data offset (String.length data - offset) in
826
+
let node_size, node_msgpack = Msgpck.String.read rest in
827
+
match decode_push_node_state node_msgpack with
828
+
| Error e -> Error (Types.Msgpack_error e)
829
+
| Ok node ->
830
+
read_nodes (offset + node_size) (remaining - 1) (node :: acc)
831
+
in
832
+
match read_nodes header_size header.pp_nodes [] with
833
+
| Error e -> Error e
834
+
| Ok (nodes, offset) ->
835
+
let user_state =
836
+
if header.pp_user_state_len > 0 && offset < Cstruct.length buf
837
+
then
838
+
Cstruct.sub buf offset
839
+
(min header.pp_user_state_len (Cstruct.length buf - offset))
840
+
else Cstruct.empty
841
+
in
842
+
Ok (header, nodes, user_state))
+96
-67
lib/lzw.ml
+96
-67
lib/lzw.ml
···
13
13
let max_dict_size = 1 lsl max_code_bits
14
14
15
15
type bit_reader = {
16
-
data : string;
16
+
data : Cstruct.t;
17
17
mutable pos : int;
18
-
mutable bit_pos : int;
19
18
mutable bits_buf : int;
20
19
mutable bits_count : int;
21
20
}
22
21
23
-
let make_bit_reader data =
24
-
{ data; pos = 0; bit_pos = 0; bits_buf = 0; bits_count = 0 }
22
+
let make_bit_reader data = { data; pos = 0; bits_buf = 0; bits_count = 0 }
25
23
26
24
let read_bits_lsb reader n =
27
25
while reader.bits_count < n do
28
-
if reader.pos >= String.length reader.data then raise (Failure "eof")
26
+
if reader.pos >= Cstruct.length reader.data then raise Exit
29
27
else begin
30
-
let byte = Char.code reader.data.[reader.pos] in
28
+
let byte = Cstruct.get_uint8 reader.data reader.pos in
31
29
reader.bits_buf <- reader.bits_buf lor (byte lsl reader.bits_count);
32
30
reader.bits_count <- reader.bits_count + 8;
33
31
reader.pos <- reader.pos + 1
···
38
36
reader.bits_count <- reader.bits_count - n;
39
37
result
40
38
41
-
let decompress ?(order = LSB) ?(lit_width = 8) data =
42
-
if order <> LSB then Error (Invalid_code 0)
43
-
else if lit_width <> 8 then Error (Invalid_code 0)
44
-
else
45
-
try
46
-
let reader = make_bit_reader data in
47
-
let output = Buffer.create (String.length data * 2) in
39
+
let decompress_to_buffer ~src ~dst =
40
+
try
41
+
let reader = make_bit_reader src in
42
+
let out_pos = ref 0 in
43
+
let dst_len = Cstruct.length dst in
48
44
49
-
let dict = Array.make max_dict_size "" in
50
-
for i = 0 to 255 do
51
-
dict.(i) <- String.make 1 (Char.chr i)
52
-
done;
53
-
dict.(clear_code) <- "";
54
-
dict.(eof_code) <- "";
45
+
let dict = Array.make max_dict_size (Cstruct.empty, 0) in
46
+
for i = 0 to 255 do
47
+
dict.(i) <- (Cstruct.of_string (String.make 1 (Char.chr i)), 1)
48
+
done;
49
+
dict.(clear_code) <- (Cstruct.empty, 0);
50
+
dict.(eof_code) <- (Cstruct.empty, 0);
55
51
56
-
let dict_size = ref initial_dict_size in
57
-
let code_bits = ref 9 in
58
-
let prev_string = ref "" in
52
+
let dict_size = ref initial_dict_size in
53
+
let code_bits = ref 9 in
54
+
let prev_code = ref (-1) in
59
55
60
-
let add_to_dict s =
61
-
if !dict_size < max_dict_size then begin
62
-
dict.(!dict_size) <- s;
63
-
incr dict_size;
64
-
if !dict_size >= 1 lsl !code_bits && !code_bits < max_code_bits then
65
-
incr code_bits
66
-
end
67
-
in
56
+
let write_entry (entry, len) =
57
+
if !out_pos + len > dst_len then raise (Failure "overflow");
58
+
Cstruct.blit entry 0 dst !out_pos len;
59
+
out_pos := !out_pos + len
60
+
in
61
+
62
+
let add_to_dict first_byte =
63
+
if !dict_size < max_dict_size && !prev_code >= 0 then begin
64
+
let prev_entry, prev_len = dict.(!prev_code) in
65
+
let new_entry = Cstruct.create (prev_len + 1) in
66
+
Cstruct.blit prev_entry 0 new_entry 0 prev_len;
67
+
Cstruct.set_uint8 new_entry prev_len first_byte;
68
+
dict.(!dict_size) <- (new_entry, prev_len + 1);
69
+
incr dict_size;
70
+
if !dict_size >= 1 lsl !code_bits && !code_bits < max_code_bits then
71
+
incr code_bits
72
+
end
73
+
in
74
+
75
+
let reset_dict () =
76
+
dict_size := initial_dict_size;
77
+
code_bits := 9;
78
+
prev_code := -1
79
+
in
80
+
81
+
let rec decode_loop () =
82
+
let code = read_bits_lsb reader !code_bits in
83
+
if code = eof_code then ()
84
+
else if code = clear_code then begin
85
+
reset_dict ();
86
+
decode_loop ()
87
+
end
88
+
else begin
89
+
let entry, len, first_byte =
90
+
if code < !dict_size then
91
+
let e, l = dict.(code) in
92
+
(e, l, Cstruct.get_uint8 e 0)
93
+
else if code = !dict_size && !prev_code >= 0 then (
94
+
let prev_entry, prev_len = dict.(!prev_code) in
95
+
let first = Cstruct.get_uint8 prev_entry 0 in
96
+
let new_entry = Cstruct.create (prev_len + 1) in
97
+
Cstruct.blit prev_entry 0 new_entry 0 prev_len;
98
+
Cstruct.set_uint8 new_entry prev_len first;
99
+
(new_entry, prev_len + 1, first))
100
+
else raise (Failure "invalid")
101
+
in
102
+
write_entry (entry, len);
103
+
add_to_dict first_byte;
104
+
prev_code := code;
105
+
decode_loop ()
106
+
end
107
+
in
68
108
69
-
let reset_dict () =
70
-
dict_size := initial_dict_size;
71
-
code_bits := 9;
72
-
prev_string := ""
73
-
in
109
+
decode_loop ();
110
+
Ok !out_pos
111
+
with
112
+
| Exit -> Error Unexpected_eof
113
+
| Failure msg when msg = "overflow" -> Error Buffer_overflow
114
+
| Failure msg when msg = "invalid" -> Error (Invalid_code 0)
115
+
| _ -> Error (Invalid_code 0)
74
116
75
-
let rec decode_loop () =
76
-
let code = read_bits_lsb reader !code_bits in
77
-
if code = eof_code then ()
78
-
else if code = clear_code then begin
79
-
reset_dict ();
80
-
decode_loop ()
81
-
end
82
-
else begin
83
-
let current_string =
84
-
if code < !dict_size then dict.(code)
85
-
else if code = !dict_size then
86
-
!prev_string ^ String.make 1 !prev_string.[0]
87
-
else
88
-
raise
89
-
(Failure
90
-
(Printf.sprintf "invalid code %d >= %d" code !dict_size))
91
-
in
92
-
Buffer.add_string output current_string;
93
-
if !prev_string <> "" then
94
-
add_to_dict (!prev_string ^ String.make 1 current_string.[0]);
95
-
prev_string := current_string;
96
-
decode_loop ()
97
-
end
98
-
in
117
+
let decompress_cstruct src =
118
+
let estimated_size = max (Cstruct.length src * 4) 4096 in
119
+
let dst = Cstruct.create estimated_size in
120
+
match decompress_to_buffer ~src ~dst with
121
+
| Ok len -> Ok (Cstruct.sub dst 0 len)
122
+
| Error Buffer_overflow -> (
123
+
let larger = Cstruct.create (estimated_size * 4) in
124
+
match decompress_to_buffer ~src ~dst:larger with
125
+
| Ok len -> Ok (Cstruct.sub larger 0 len)
126
+
| Error e -> Error e)
127
+
| Error e -> Error e
99
128
100
-
decode_loop ();
101
-
Ok (Buffer.contents output)
102
-
with
103
-
| Failure msg when msg = "eof" -> Error Unexpected_eof
104
-
| Failure msg
105
-
when String.length msg > 12 && String.sub msg 0 12 = "invalid code" ->
106
-
Error (Invalid_code 0)
107
-
| _ -> Error (Invalid_code 0)
129
+
let decompress ?(order = LSB) ?(lit_width = 8) data =
130
+
if order <> LSB then Error (Invalid_code 0)
131
+
else if lit_width <> 8 then Error (Invalid_code 0)
132
+
else
133
+
let src = Cstruct.of_string data in
134
+
match decompress_cstruct src with
135
+
| Ok cs -> Ok (Cstruct.to_string cs)
136
+
| Error e -> Error e
108
137
109
138
let decompress_lsb8 data = decompress ~order:LSB ~lit_width:8 data
+2
lib/lzw.mli
+2
lib/lzw.mli
···
2
2
type error = Invalid_code of int | Unexpected_eof | Buffer_overflow
3
3
4
4
val error_to_string : error -> string
5
+
val decompress_to_buffer : src:Cstruct.t -> dst:Cstruct.t -> (int, error) result
6
+
val decompress_cstruct : Cstruct.t -> (Cstruct.t, error) result
5
7
6
8
val decompress :
7
9
?order:order -> ?lit_width:int -> string -> (string, error) result
+107
-91
lib/protocol.ml
+107
-91
lib/protocol.ml
···
11
11
probe_index : int Kcas.Loc.t;
12
12
send_pool : Buffer_pool.t;
13
13
recv_pool : Buffer_pool.t;
14
+
tcp_recv_pool : Buffer_pool.t;
15
+
tcp_decompress_pool : Buffer_pool.t;
14
16
udp_sock : [ `Generic ] Eio.Net.datagram_socket_ty Eio.Resource.t;
15
17
tcp_listener : [ `Generic ] Eio.Net.listening_socket_ty Eio.Resource.t;
16
18
event_stream : node_event Eio.Stream.t;
···
369
371
else None
370
372
| _ -> None
371
373
374
+
let decompress_payload_cstruct ~src ~dst =
375
+
match Codec.decode_compress_from_cstruct src with
376
+
| Error _ -> None
377
+
| Ok (algo, compressed) ->
378
+
if algo = 0 then
379
+
match Lzw.decompress_to_buffer ~src:compressed ~dst with
380
+
| Ok len -> Some len
381
+
| Error _ -> None
382
+
else None
383
+
372
384
let handle_tcp_connection t flow =
373
-
let buf = Cstruct.create 65536 in
374
-
match read_exact flow buf 1 with
375
-
| Error _ -> ()
376
-
| Ok () -> (
377
-
let msg_type_byte = Cstruct.get_uint8 buf 0 in
378
-
let get_push_pull_payload () =
379
-
let n = read_available flow (Cstruct.shift buf 1) in
380
-
if n > 0 then Some (Cstruct.sub buf 1 n) else None
381
-
in
382
-
let payload_opt =
383
-
if msg_type_byte = Types.Wire.message_type_to_int Types.Wire.Encrypt_msg
384
-
then
385
-
match get_push_pull_payload () with
386
-
| Some encrypted -> (
387
-
match Crypto.decrypt ~key:t.cipher_key encrypted with
388
-
| Ok decrypted -> Some decrypted
389
-
| Error _ -> None)
390
-
| None -> None
391
-
else if
392
-
msg_type_byte = Types.Wire.message_type_to_int Types.Wire.Compress_msg
393
-
then
394
-
match get_push_pull_payload () with
395
-
| Some compressed -> (
396
-
let data = Cstruct.to_string compressed in
397
-
match decompress_payload data with
398
-
| Some decompressed ->
399
-
if String.length decompressed > 0 then
400
-
let inner_type = Char.code decompressed.[0] in
401
-
if
402
-
inner_type
403
-
= Types.Wire.message_type_to_int Types.Wire.Push_pull_msg
404
-
then
405
-
Some
406
-
(Cstruct.of_string
407
-
(String.sub decompressed 1
408
-
(String.length decompressed - 1)))
409
-
else None
410
-
else None
411
-
| None -> None)
412
-
| None -> None
413
-
else if
414
-
msg_type_byte
415
-
= Types.Wire.message_type_to_int Types.Wire.Has_label_msg
416
-
then
385
+
Buffer_pool.with_buffer t.tcp_recv_pool (fun buf ->
386
+
Buffer_pool.with_buffer t.tcp_decompress_pool (fun decomp_buf ->
417
387
match read_exact flow buf 1 with
418
-
| Error _ -> None
419
-
| Ok () ->
420
-
let label_len = Cstruct.get_uint8 buf 0 in
421
-
if label_len > 0 then
422
-
match read_exact flow buf label_len with
423
-
| Error _ -> None
424
-
| Ok () -> (
425
-
match read_exact flow buf 1 with
426
-
| Error _ -> None
427
-
| Ok () ->
428
-
let inner_type = Cstruct.get_uint8 buf 0 in
429
-
if
430
-
inner_type
431
-
= Types.Wire.message_type_to_int
432
-
Types.Wire.Push_pull_msg
433
-
then get_push_pull_payload ()
434
-
else None)
435
-
else None
436
-
else if
437
-
msg_type_byte
438
-
= Types.Wire.message_type_to_int Types.Wire.Push_pull_msg
439
-
then get_push_pull_payload ()
440
-
else None
441
-
in
442
-
match payload_opt with
443
-
| None -> ()
444
-
| Some payload -> (
445
-
let data = Cstruct.to_string payload in
446
-
match Codec.decode_push_pull_msg data with
447
388
| Error _ -> ()
448
-
| Ok (header, nodes, _user_state) -> (
449
-
merge_remote_state t nodes ~is_join:header.pp_join;
450
-
let resp_header, resp_nodes =
451
-
build_local_state t ~is_join:false
452
-
in
453
-
let response =
454
-
Codec.encode_push_pull_msg ~header:resp_header ~nodes:resp_nodes
455
-
~user_state:""
389
+
| Ok () -> (
390
+
let msg_type_byte = Cstruct.get_uint8 buf 0 in
391
+
let get_push_pull_payload () =
392
+
let n = read_available flow (Cstruct.shift buf 1) in
393
+
if n > 0 then Some (Cstruct.sub buf 1 n) else None
456
394
in
457
-
let resp_buf =
458
-
if t.config.encryption_enabled then
459
-
let plain = Cstruct.of_string response in
460
-
let encrypted =
461
-
Crypto.encrypt ~key:t.cipher_key ~random:t.secure_random
462
-
plain
463
-
in
464
-
encrypted
465
-
else Cstruct.of_string response
395
+
let payload_opt =
396
+
if
397
+
msg_type_byte
398
+
= Types.Wire.message_type_to_int Types.Wire.Encrypt_msg
399
+
then
400
+
match get_push_pull_payload () with
401
+
| Some encrypted -> (
402
+
match Crypto.decrypt ~key:t.cipher_key encrypted with
403
+
| Ok decrypted -> Some decrypted
404
+
| Error _ -> None)
405
+
| None -> None
406
+
else if
407
+
msg_type_byte
408
+
= Types.Wire.message_type_to_int Types.Wire.Compress_msg
409
+
then
410
+
match get_push_pull_payload () with
411
+
| Some compressed -> (
412
+
match
413
+
decompress_payload_cstruct ~src:compressed
414
+
~dst:decomp_buf
415
+
with
416
+
| Some len ->
417
+
if len > 0 then
418
+
let inner_type = Cstruct.get_uint8 decomp_buf 0 in
419
+
if
420
+
inner_type
421
+
= Types.Wire.message_type_to_int
422
+
Types.Wire.Push_pull_msg
423
+
then Some (Cstruct.sub decomp_buf 1 (len - 1))
424
+
else None
425
+
else None
426
+
| None -> None)
427
+
| None -> None
428
+
else if
429
+
msg_type_byte
430
+
= Types.Wire.message_type_to_int Types.Wire.Has_label_msg
431
+
then
432
+
match read_exact flow buf 1 with
433
+
| Error _ -> None
434
+
| Ok () ->
435
+
let label_len = Cstruct.get_uint8 buf 0 in
436
+
if label_len > 0 then
437
+
match read_exact flow buf label_len with
438
+
| Error _ -> None
439
+
| Ok () -> (
440
+
match read_exact flow buf 1 with
441
+
| Error _ -> None
442
+
| Ok () ->
443
+
let inner_type = Cstruct.get_uint8 buf 0 in
444
+
if
445
+
inner_type
446
+
= Types.Wire.message_type_to_int
447
+
Types.Wire.Push_pull_msg
448
+
then get_push_pull_payload ()
449
+
else None)
450
+
else None
451
+
else if
452
+
msg_type_byte
453
+
= Types.Wire.message_type_to_int Types.Wire.Push_pull_msg
454
+
then get_push_pull_payload ()
455
+
else None
466
456
in
467
-
try Eio.Flow.write flow [ resp_buf ] with _ -> ())))
457
+
match payload_opt with
458
+
| None -> ()
459
+
| Some payload -> (
460
+
match Codec.decode_push_pull_msg_cstruct payload with
461
+
| Error _ -> ()
462
+
| Ok (header, nodes, _user_state) -> (
463
+
merge_remote_state t nodes ~is_join:header.pp_join;
464
+
let resp_header, resp_nodes =
465
+
build_local_state t ~is_join:false
466
+
in
467
+
let response =
468
+
Codec.encode_push_pull_msg ~header:resp_header
469
+
~nodes:resp_nodes ~user_state:""
470
+
in
471
+
let resp_buf =
472
+
if t.config.encryption_enabled then
473
+
let plain = Cstruct.of_string response in
474
+
let encrypted =
475
+
Crypto.encrypt ~key:t.cipher_key
476
+
~random:t.secure_random plain
477
+
in
478
+
encrypted
479
+
else Cstruct.of_string response
480
+
in
481
+
try Eio.Flow.write flow [ resp_buf ] with _ -> ())))))
468
482
469
483
let run_tcp_listener t =
470
484
while not (is_shutdown t) do
···
589
603
recv_pool =
590
604
Buffer_pool.create ~size:config.udp_buffer_size
591
605
~count:config.recv_buffer_count;
606
+
tcp_recv_pool = Buffer_pool.create ~size:65536 ~count:4;
607
+
tcp_decompress_pool = Buffer_pool.create ~size:131072 ~count:4;
592
608
udp_sock;
593
609
tcp_listener;
594
610
event_stream = Eio.Stream.create 100;