(* upstream: https://github.com/janestreet/memtrace *)
module Shared_writer_fd = struct
  (** A file descriptor that several writers may share. [lock] serialises
      writes and [closed] is set by [close]; every write checks it first. *)
  type t = { lock : Mutex.t; closed : bool Atomic.t; fd : Unix.file_descr }

  (** [make fd] wraps [fd] in a writer that starts out open. *)
  let make fd =
    let lock = Mutex.create () in
    let closed = Atomic.make false in
    { lock; closed; fd }

  (** Raised by [write_fully] once [close] has been called. *)
  exception Closed

  (** [write_fully t buf ~pos ~len] writes bytes [pos .. pos + len - 1] of
      [buf] to the wrapped descriptor while holding [t.lock]. It re-issues
      [Unix.write] until all [len] bytes have been written. The lock is
      released even when the write raises.
      @raise Closed if [close t] has already been called. *)
  let write_fully t buf ~pos ~len =
    (* Keep calling [Unix.write] until nothing is left to send. *)
    let rec drain off left =
      if left <> 0 then begin
        let sent = Unix.write t.fd buf off left in
        drain (off + sent) (left - sent)
      end
    in
    let locked_write () =
      if Atomic.get t.closed then raise Closed;
      drain pos len
    in
    Mutex.lock t.lock;
    Fun.protect ~finally:(fun () -> Mutex.unlock t.lock) locked_write

  (** [close t] marks [t] as closed under the lock, so that later
      [write_fully] calls raise [Closed]. It does not close the underlying
      file descriptor. *)
  let close t =
    Mutex.lock t.lock;
    Fun.protect
      ~finally:(fun () -> Mutex.unlock t.lock)
      (fun () -> Atomic.set t.closed true)
end
27
module Shared = struct
  (** A cursor over a byte buffer. Accesses happen at [pos] and must end at
      or before [pos_end]. The [Write]/[Read] accessors in this file compare
      only against [pos_end] and then use unchecked byte primitives. They
      therefore rely on [0 <= pos] and [pos_end <= Bytes.length buf]. *)
  type t = { buf : Bytes.t; mutable pos : int; pos_end : int }

  (** [of_bytes buf] is a cursor covering the whole of [buf]. *)
  let of_bytes buf = { buf; pos = 0; pos_end = Bytes.length buf }

  (** [of_bytes_sub buf ~pos ~pos_end] is a cursor over [buf] that starts at
      [pos] (inclusive) and ends at [pos_end] (exclusive).
      @raise Invalid_argument unless [0 <= pos <= pos_end <= Bytes.length buf]. *)
  let of_bytes_sub buf ~pos ~pos_end =
    (* An out-of-range window would let the unchecked accessors read or write
       outside [buf], so reject it up front. *)
    if pos < 0 || pos > pos_end || pos_end > Bytes.length buf then
      invalid_arg "of_bytes_sub: range outside buffer";
    { buf; pos; pos_end }

  (** [remaining b] is the number of bytes between [b.pos] and [b.pos_end]. *)
  let remaining b = b.pos_end - b.pos

  (* Byte-order swaps; used to convert between host order and the
     little-endian layout produced on big-endian hosts. *)
  external bswap_16 : int -> int = "%bswap16"
  external bswap_32 : int32 -> int32 = "%bswap_int32"
  external bswap_64 : int64 -> int64 = "%bswap_int64"
end
39
module Write = struct
  include Shared

  (** [write_fd fd b] sends bytes [0 .. b.pos - 1] of [b.buf] to the shared
      writer [fd].
      @raise Shared_writer_fd.Closed if [fd] has been closed. *)
  let write_fd fd b = Shared_writer_fd.write_fully fd b.buf ~pos:0 ~len:b.pos

  (* Unchecked, host-byte-order stores. Callers must already have checked
     that the target bytes lie inside the buffer. *)
  let put_raw_8 b i v = Bytes.unsafe_set b i (Char.unsafe_chr v)

  external put_raw_16 : Bytes.t -> int -> int -> unit = "%caml_bytes_set16u"
  external put_raw_32 : Bytes.t -> int -> int32 -> unit = "%caml_bytes_set32u"
  external put_raw_64 : Bytes.t -> int -> int64 -> unit = "%caml_bytes_set64u"

  (** Raised when a write would run past [pos_end]. It carries the value of
      [pos] at the failed write. *)
  exception Overflow of int

  (* [@inline never] keeps exception construction out of the inlined fast
     paths (presumably for code size). *)
  let[@inline never] overflow b = Overflow b.pos

  (** [put_8 b v] writes one byte and advances [b.pos].
      @raise Overflow if no byte is left. *)
  let[@inline always] put_8 b v =
    let pos = b.pos in
    let pos' = b.pos + 1 in
    if pos' > b.pos_end then raise (overflow b)
    else (
      put_raw_8 b.buf pos v;
      b.pos <- pos')

  (** [put_16 b v] writes the low 16 bits of [v] little-endian.
      @raise Overflow if fewer than 2 bytes are left. *)
  let[@inline always] put_16 b v =
    let pos = b.pos in
    let pos' = b.pos + 2 in
    if pos' > b.pos_end then raise (overflow b)
    else (
      put_raw_16 b.buf pos (if Sys.big_endian then bswap_16 v else v);
      b.pos <- pos')

  (** [put_32 b v] writes [v] little-endian.
      @raise Overflow if fewer than 4 bytes are left. *)
  let[@inline always] put_32 b v =
    let pos = b.pos in
    let pos' = b.pos + 4 in
    if pos' > b.pos_end then raise (overflow b)
    else (
      put_raw_32 b.buf pos (if Sys.big_endian then bswap_32 v else v);
      b.pos <- pos')

  (** [put_64 b v] writes [v] little-endian.
      @raise Overflow if fewer than 8 bytes are left. *)
  let[@inline always] put_64 b v =
    let pos = b.pos in
    let pos' = b.pos + 8 in
    if pos' > b.pos_end then raise (overflow b)
    else (
      put_raw_64 b.buf pos (if Sys.big_endian then bswap_64 v else v);
      b.pos <- pos')

  (** [put_float b f] writes the IEEE-754 bit pattern of [f] via [put_64]. *)
  let[@inline always] put_float b f = put_64 b (Int64.bits_of_float f)

  (** [put_string b s] writes [s] up to (not including) its first NUL, or all
      of [s] if it has none, followed by a NUL terminator.
      @raise Overflow if the terminated string does not fit. *)
  let put_string b s =
    let slen =
      match String.index_opt s '\000' with
      | Some i -> i
      | None -> String.length s
    in
    if b.pos + slen + 1 > b.pos_end then raise (overflow b);
    Bytes.blit_string s 0 b.buf b.pos slen;
    Bytes.unsafe_set b.buf (b.pos + slen) '\000';
    b.pos <- b.pos + slen + 1

  (* Slow path of [put_vint]: a tag byte followed by the smallest field that
     holds [v]. 253 means an unsigned 16-bit value, 254 a 32-bit value and
     255 a 64-bit value. *)
  let[@inline never] put_vint_big b v =
    if v = v land 0xffff then (
      put_8 b 253;
      put_16 b v)
    else if v = Int32.to_int (Int32.of_int v) then (
      put_8 b 254;
      put_32 b (Int32.of_int v))
    else (
      put_8 b 255;
      put_64 b (Int64.of_int v))

  (** [put_vint b v] writes [v] in a variable-length encoding. Values
      0..252 take a single byte; all others use [put_vint_big]. *)
  let[@inline always] put_vint b v =
    if 0 <= v && v <= 252 then put_8 b v else put_vint_big b v

  (** Offsets returned by the [skip_*] functions, to be passed later to the
      matching [update_*] function. *)
  type position_8 = int
  type position_16 = int
  type position_32 = int
  type position_64 = int
  type position_float = int

  (** [skip_N b] reserves N/8 bytes, advances [b.pos] and returns the offset
      of the reserved bytes.
      @raise Overflow if there is not enough space. *)
  let[@inline always] skip_8 b =
    let pos = b.pos in
    let pos' = b.pos + 1 in
    if pos' > b.pos_end then raise (overflow b);
    b.pos <- pos';
    pos

  let[@inline always] skip_16 b =
    let pos = b.pos in
    let pos' = b.pos + 2 in
    if pos' > b.pos_end then raise (overflow b);
    b.pos <- pos';
    pos

  let[@inline always] skip_32 b =
    let pos = b.pos in
    let pos' = b.pos + 4 in
    if pos' > b.pos_end then raise (overflow b);
    b.pos <- pos';
    pos

  let[@inline always] skip_64 b =
    let pos = b.pos in
    let pos' = b.pos + 8 in
    if pos' > b.pos_end then raise (overflow b);
    b.pos <- pos';
    pos

  let skip_float = skip_64

  (** [update_N b pos v] overwrites the bytes at [pos] (e.g. an offset
      returned by [skip_N]) with [v]. It does not move [b.pos]. Multi-byte
      values use the same little-endian layout as [put_N], so [Read.get_N]
      decodes them on every host. Bounds are checked with [assert] before
      the unchecked store. *)
  let update_8 b pos v =
    assert (0 <= pos && pos + 1 <= b.pos_end);
    put_raw_8 b.buf pos v

  let update_16 b pos v =
    assert (0 <= pos && pos + 2 <= b.pos_end);
    (* Swap on big-endian hosts, exactly as [put_16] does. *)
    put_raw_16 b.buf pos (if Sys.big_endian then bswap_16 v else v)

  let update_32 b pos v =
    assert (0 <= pos && pos + 4 <= b.pos_end);
    put_raw_32 b.buf pos (if Sys.big_endian then bswap_32 v else v)

  let update_64 b pos v =
    assert (0 <= pos && pos + 8 <= b.pos_end);
    put_raw_64 b.buf pos (if Sys.big_endian then bswap_64 v else v)

  let update_float b pos f = update_64 b pos (Int64.bits_of_float f)
end
167
module Read = struct
  include Shared

  (** [read_into fd buf off] reads from [fd] into [buf], starting at offset
      [off]. It retries after short reads until [buf] is full or
      [Unix.read] returns 0 (end of file). The result is a cursor over [buf]
      from 0 to the number of bytes now held. *)
  let rec read_into fd buf off =
    if off = Bytes.length buf then { buf; pos = 0; pos_end = off }
    else (
      assert (0 <= off && off <= Bytes.length buf);
      let n = Unix.read fd buf off (Bytes.length buf - off) in
      if n = 0 then (* EOF *)
        { buf; pos = 0; pos_end = off }
      else (* Short read *)
        read_into fd buf (off + n))

  (** [read_fd fd buf] fills [buf] from [fd], starting at offset 0. *)
  let read_fd fd buf = read_into fd buf 0

  (** [refill_fd fd b] moves the unread bytes of [b] to the front of [b.buf]
      and then reads more data from [fd] after them. The returned cursor
      shares [b.buf] with [b], whose contents are overwritten. *)
  let refill_fd fd b =
    let len = remaining b in
    Bytes.blit b.buf b.pos b.buf 0 len;
    read_into fd b.buf len

  (** [split b len] returns [(prefix, rest)] without modifying [b].
      [prefix] covers the next [len] bytes of [b] and [rest] covers
      everything after them. [len] is clamped to the range
      [0 .. remaining b]. *)
  let split b len =
    (* Clamp below at 0 too: a negative [len] (e.g. one decoded from corrupt
       input) would otherwise move [rest.pos] before [b.pos], possibly below
       0, where the unchecked [get_*] loads would read outside the window. *)
    let len = max 0 (min (remaining b) len) in
    ({ b with pos_end = b.pos + len }, { b with pos = b.pos + len })

  (** A cursor over zero bytes. *)
  let empty = { buf = Bytes.make 0 '?'; pos = 0; pos_end = 0 }

  (* Unchecked, host-byte-order loads. Callers must check bounds first. *)
  external get_raw_16 : Bytes.t -> int -> int = "%caml_bytes_get16u"
  external get_raw_32 : Bytes.t -> int -> int32 = "%caml_bytes_get32u"
  external get_raw_64 : Bytes.t -> int -> int64 = "%caml_bytes_get64u"

  (** Raised when a read would run past [pos_end]. It carries the value of
      [pos] at the failed read. *)
  exception Underflow of int

  (* [@inline never] keeps exception construction out of the inlined fast
     paths (presumably for code size). *)
  let[@inline never] underflow b = Underflow b.pos

  (** [get_8 b] reads one unsigned byte and advances [b.pos].
      @raise Underflow (without advancing) if no byte is left. *)
  let[@inline always] get_8 b =
    let pos = b.pos in
    let pos' = b.pos + 1 in
    if pos' > b.pos_end then raise (underflow b);
    b.pos <- pos';
    Char.code (Bytes.unsafe_get b.buf pos)

  (** [get_16 b] reads an unsigned little-endian 16-bit value.
      @raise Underflow (without advancing) if fewer than 2 bytes are left. *)
  let[@inline always] get_16 b =
    let pos = b.pos in
    let pos' = b.pos + 2 in
    if pos' > b.pos_end then raise (underflow b);
    b.pos <- pos';
    if Sys.big_endian then bswap_16 (get_raw_16 b.buf pos)
    else get_raw_16 b.buf pos

  (** [get_32 b] reads a little-endian 32-bit value.
      @raise Underflow (without advancing) if fewer than 4 bytes are left. *)
  let[@inline always] get_32 b =
    let pos = b.pos in
    let pos' = b.pos + 4 in
    if pos' > b.pos_end then raise (underflow b);
    b.pos <- pos';
    if Sys.big_endian then bswap_32 (get_raw_32 b.buf pos)
    else get_raw_32 b.buf pos

  (** [get_64 b] reads a little-endian 64-bit value.
      @raise Underflow (without advancing) if fewer than 8 bytes are left. *)
  let[@inline always] get_64 b =
    let pos = b.pos in
    let pos' = b.pos + 8 in
    if pos' > b.pos_end then raise (underflow b);
    b.pos <- pos';
    if Sys.big_endian then bswap_64 (get_raw_64 b.buf pos)
    else get_raw_64 b.buf pos

  (** [get_float b] reads a float stored as its 64-bit IEEE-754 pattern. *)
  let[@inline always] get_float b = Int64.float_of_bits (get_64 b)

  (** [get_string b] reads bytes up to the next NUL and returns them without
      the terminator. [b.pos] is moved just past the NUL.
      @raise Underflow if no NUL occurs before [pos_end]. In that case
      [b.pos] is left unchanged, matching the other getters. *)
  let get_string b =
    let start = b.pos in
    (* Find the terminator first, so that a failed read consumes nothing. *)
    let rec find_nul i =
      if i >= b.pos_end then raise (underflow b)
      else if Bytes.unsafe_get b.buf i = '\000' then i
      else find_nul (i + 1)
    in
    let nul = find_nul start in
    b.pos <- nul + 1;
    Bytes.sub_string b.buf start (nul - start)

  (* Slow path of [get_vint]: [c] is the tag byte written by
     [Write.put_vint_big]. *)
  let[@inline never] get_vint_big b c =
    match c with
    | 253 -> get_16 b
    | 254 -> Int32.to_int (get_32 b)
    | 255 -> Int64.to_int (get_64 b)
    | _ -> assert false

  (** [get_vint b] decodes a value written by [Write.put_vint]. *)
  let[@inline always] get_vint b =
    match get_8 b with c when c < 253 -> c | c -> get_vint_big b c
end
253
let () =
  (* Human-readable messages for the buffer bounds exceptions; every other
     exception is left to the remaining registered printers. *)
  let describe = function
    | Write.Overflow n ->
        Some (Printf.sprintf "Buffer overflow at position %d" n)
    | Read.Underflow n ->
        Some (Printf.sprintf "Buffer underflow at position %d" n)
    | _ -> None
  in
  (Printexc.register_printer [@ocaml.alert "-unsafe_multidomain"]) describe