(* Upstream: https://github.com/janestreet/memtrace (branch: main) *)
(** A file descriptor shared between writers, with writes serialised by a
    mutex. After {!close} has been called, further writes raise {!Closed}.
    {!close} only marks the writer as closed; it does not close the
    underlying [Unix.file_descr]. *)
module Shared_writer_fd = struct
  type t = { lock : Mutex.t; closed : bool Atomic.t; fd : Unix.file_descr }

  let make fd = { lock = Mutex.create (); closed = Atomic.make false; fd }

  exception Closed

  (* Keep calling [Unix.write] until all [len] bytes are out, in case it
     reports a partial write. *)
  let rec write_fully fd buf ~pos ~len =
    if len = 0 then ()
    else
      let written = Unix.write fd buf pos len in
      write_fully fd buf ~pos:(pos + written) ~len:(len - written)

  (** [write_fully t buf ~pos ~len] writes bytes [pos .. pos + len - 1] of
      [buf] to [t]'s descriptor while holding [t]'s lock. The lock is
      released even if the write raises.
      @raise Closed if [close t] has already been called. *)
  let write_fully t buf ~pos ~len =
    Mutex.lock t.lock;
    Fun.protect
      (fun () ->
        if Atomic.get t.closed then raise Closed;
        write_fully t.fd buf ~pos ~len)
      ~finally:(fun () -> Mutex.unlock t.lock)

  (** Mark [t] as closed. Because [write_fully] holds the lock for the whole
      write, this waits for any in-progress write to finish first. The
      descriptor itself is left open. *)
  let close t =
    Mutex.lock t.lock;
    Atomic.set t.closed true;
    Mutex.unlock t.lock
end

(** Buffer representation shared by {!Write} and {!Read}: the window
    [\[pos, pos_end)] of [buf], where [pos] is the cursor. *)
module Shared = struct
  type t = { buf : Bytes.t; mutable pos : int; pos_end : int }

  (** A buffer over all of [buf], with the cursor at 0. *)
  let of_bytes buf = { buf; pos = 0; pos_end = Bytes.length buf }

  (** A buffer over [\[pos, pos_end)] of [buf].
      @raise Invalid_argument unless [0 <= pos <= pos_end <= Bytes.length buf].
      The accessors in {!Write} and {!Read} use unchecked primitives bounded
      only by [pos_end], so an out-of-range window must be rejected here. *)
  let of_bytes_sub buf ~pos ~pos_end =
    if pos < 0 || pos > pos_end || pos_end > Bytes.length buf then
      invalid_arg "Buf.of_bytes_sub";
    { buf; pos; pos_end }

  (** Number of bytes between the cursor and [pos_end]. *)
  let remaining b = b.pos_end - b.pos

  external bswap_16 : int -> int = "%bswap16"
  external bswap_32 : int32 -> int32 = "%bswap_int32"
  external bswap_64 : int64 -> int64 = "%bswap_int64"
end

(** Cursor-based binary writer. Multi-byte values are stored little-endian:
    the raw primitives use host byte order, so values are byte-swapped on
    big-endian hosts. Every [put_*] / [skip_*] checks bounds against
    [pos_end]. If the value does not fit, it raises {!Overflow} with the
    current cursor position and leaves the buffer unchanged. *)
module Write = struct
  include Shared

  (** Write bytes [\[0, b.pos)] of [b.buf] to [fd]. This always starts at
      offset 0 of the underlying bytes, whatever start position the buffer
      was created with. *)
  let write_fd fd b = Shared_writer_fd.write_fully fd b.buf ~pos:0 ~len:b.pos

  let put_raw_8 b i v = Bytes.unsafe_set b i (Char.unsafe_chr v)

  (* Unchecked host-byte-order stores; callers perform the bounds checks. *)
  external put_raw_16 : Bytes.t -> int -> int -> unit = "%caml_bytes_set16u"
  external put_raw_32 : Bytes.t -> int -> int32 -> unit = "%caml_bytes_set32u"
  external put_raw_64 : Bytes.t -> int -> int64 -> unit = "%caml_bytes_set64u"

  exception Overflow of int

  (* [@inline never], presumably to keep the inlined fast paths small. *)
  let[@inline never] overflow b = Overflow b.pos

  let[@inline always] put_8 b v =
    let pos = b.pos in
    let pos' = pos + 1 in
    if pos' > b.pos_end then raise (overflow b)
    else (
      put_raw_8 b.buf pos v;
      b.pos <- pos')

  let[@inline always] put_16 b v =
    let pos = b.pos in
    let pos' = pos + 2 in
    if pos' > b.pos_end then raise (overflow b)
    else (
      put_raw_16 b.buf pos (if Sys.big_endian then bswap_16 v else v);
      b.pos <- pos')

  let[@inline always] put_32 b v =
    let pos = b.pos in
    let pos' = pos + 4 in
    if pos' > b.pos_end then raise (overflow b)
    else (
      put_raw_32 b.buf pos (if Sys.big_endian then bswap_32 v else v);
      b.pos <- pos')

  let[@inline always] put_64 b v =
    let pos = b.pos in
    let pos' = pos + 8 in
    if pos' > b.pos_end then raise (overflow b)
    else (
      put_raw_64 b.buf pos (if Sys.big_endian then bswap_64 v else v);
      b.pos <- pos')

  (** Store a float as its IEEE-754 64-bit pattern. *)
  let[@inline always] put_float b f = put_64 b (Int64.bits_of_float f)

  (** Write [s] followed by a NUL terminator. If [s] itself contains a NUL,
      only the part before the first NUL is written. *)
  let put_string b s =
    let slen =
      match String.index_opt s '\000' with
      | Some i -> i
      | None -> String.length s
    in
    if b.pos + slen + 1 > b.pos_end then raise (overflow b);
    Bytes.blit_string s 0 b.buf b.pos slen;
    Bytes.unsafe_set b.buf (b.pos + slen) '\000';
    b.pos <- b.pos + slen + 1

  (* Slow path of [put_vint] for values outside [0, 252]:
     - tag 253 followed by 16 bits, for [0 <= v <= 0xffff];
     - tag 254 followed by 32 bits, for values fitting a signed int32;
     - tag 255 followed by 64 bits, otherwise. *)
  let[@inline never] put_vint_big b v =
    if v = v land 0xffff then (
      put_8 b 253;
      put_16 b v)
    else if v = Int32.to_int (Int32.of_int v) then (
      put_8 b 254;
      put_32 b (Int32.of_int v))
    else (
      put_8 b 255;
      put_64 b (Int64.of_int v))

  (** Variable-length integer: one byte for [0 <= v <= 252], otherwise a tag
      byte (253/254/255) followed by a 16/32/64-bit payload. *)
  let[@inline always] put_vint b v =
    if 0 <= v && v <= 252 then put_8 b v else put_vint_big b v

  (* Positions returned by [skip_*], to be filled in later with [update_*]. *)
  type position_8 = int
  type position_16 = int
  type position_32 = int
  type position_64 = int
  type position_float = int

  (** Reserve 1 byte and return its position for a later [update_8]. *)
  let[@inline always] skip_8 b =
    let pos = b.pos in
    let pos' = pos + 1 in
    if pos' > b.pos_end then raise (overflow b);
    b.pos <- pos';
    pos

  (** Reserve 2 bytes and return their position for a later [update_16]. *)
  let[@inline always] skip_16 b =
    let pos = b.pos in
    let pos' = pos + 2 in
    if pos' > b.pos_end then raise (overflow b);
    b.pos <- pos';
    pos

  (** Reserve 4 bytes and return their position for a later [update_32]. *)
  let[@inline always] skip_32 b =
    let pos = b.pos in
    let pos' = pos + 4 in
    if pos' > b.pos_end then raise (overflow b);
    b.pos <- pos';
    pos

  (** Reserve 8 bytes and return their position for a later [update_64]. *)
  let[@inline always] skip_64 b =
    let pos = b.pos in
    let pos' = pos + 8 in
    if pos' > b.pos_end then raise (overflow b);
    b.pos <- pos';
    pos

  let skip_float = skip_64

  (* The [update_*] functions back-patch a slot reserved by [skip_*]. They
     must use the same byte order as [put_*] so that [Read.get_*] decodes
     the patched value correctly on big-endian hosts too. *)
  let update_8 b pos v =
    assert (0 <= pos && pos + 1 <= b.pos_end);
    put_raw_8 b.buf pos v

  let update_16 b pos v =
    assert (0 <= pos && pos + 2 <= b.pos_end);
    put_raw_16 b.buf pos (if Sys.big_endian then bswap_16 v else v)

  let update_32 b pos v =
    assert (0 <= pos && pos + 4 <= b.pos_end);
    put_raw_32 b.buf pos (if Sys.big_endian then bswap_32 v else v)

  let update_64 b pos v =
    assert (0 <= pos && pos + 8 <= b.pos_end);
    put_raw_64 b.buf pos (if Sys.big_endian then bswap_64 v else v)

  let update_float b pos f = update_64 b pos (Int64.bits_of_float f)
end

(** Cursor-based binary reader, the inverse of {!Write}. Multi-byte values
    are decoded as little-endian. Every [get_*] checks bounds against
    [pos_end]. If there are not enough bytes, it raises {!Underflow} with
    the current cursor position and leaves the cursor unchanged. *)
module Read = struct
  include Shared

  (** [read_into fd buf off] reads from [fd] into [buf], starting at offset
      [off], until [buf] is full or [fd] reports end of file. It returns a
      buffer over [\[0, n)] of [buf], where [n] is the number of valid bytes. *)
  let rec read_into fd buf off =
    if off = Bytes.length buf then { buf; pos = 0; pos_end = off }
    else (
      assert (0 <= off && off <= Bytes.length buf);
      let n = Unix.read fd buf off (Bytes.length buf - off) in
      if n = 0 then (* EOF *)
        { buf; pos = 0; pos_end = off }
      else (* Short read: keep filling *)
        read_into fd buf (off + n))

  (** Fill [buf] from [fd], starting at offset 0. *)
  let read_fd fd buf = read_into fd buf 0

  (** Move the unread bytes of [b] to the front of [b.buf], then fill the
      rest of [b.buf] from [fd]. The result shares [b.buf], whose contents
      are overwritten, so [b] itself should not be read from afterwards. *)
  let refill_fd fd b =
    let len = remaining b in
    Bytes.blit b.buf b.pos b.buf 0 len;
    read_into fd b.buf len

  (** [split b len] returns [(front, rest)]. [front] covers the next
      [min (remaining b) len] bytes of [b]; [rest] covers everything after
      them. A negative [len] is treated as 0, so [front] is never inverted
      and [rest] never rewinds the cursor. *)
  let split b len =
    let len = max 0 (min (remaining b) len) in
    ({ b with pos_end = b.pos + len }, { b with pos = b.pos + len })

  let empty = { buf = Bytes.empty; pos = 0; pos_end = 0 }

  (* Unchecked host-byte-order loads; callers perform the bounds checks. *)
  external get_raw_16 : Bytes.t -> int -> int = "%caml_bytes_get16u"
  external get_raw_32 : Bytes.t -> int -> int32 = "%caml_bytes_get32u"
  external get_raw_64 : Bytes.t -> int -> int64 = "%caml_bytes_get64u"

  exception Underflow of int

  (* [@inline never], presumably to keep the inlined fast paths small. *)
  let[@inline never] underflow b = Underflow b.pos

  let[@inline always] get_8 b =
    let pos = b.pos in
    let pos' = pos + 1 in
    if pos' > b.pos_end then raise (underflow b);
    b.pos <- pos';
    Char.code (Bytes.unsafe_get b.buf pos)

  let[@inline always] get_16 b =
    let pos = b.pos in
    let pos' = pos + 2 in
    if pos' > b.pos_end then raise (underflow b);
    b.pos <- pos';
    if Sys.big_endian then bswap_16 (get_raw_16 b.buf pos)
    else get_raw_16 b.buf pos

  let[@inline always] get_32 b =
    let pos = b.pos in
    let pos' = pos + 4 in
    if pos' > b.pos_end then raise (underflow b);
    b.pos <- pos';
    if Sys.big_endian then bswap_32 (get_raw_32 b.buf pos)
    else get_raw_32 b.buf pos

  let[@inline always] get_64 b =
    let pos = b.pos in
    let pos' = pos + 8 in
    if pos' > b.pos_end then raise (underflow b);
    b.pos <- pos';
    if Sys.big_endian then bswap_64 (get_raw_64 b.buf pos)
    else get_raw_64 b.buf pos

  let[@inline always] get_float b = Int64.float_of_bits (get_64 b)

  (** Read a NUL-terminated string and return it without the terminator.
      The terminator is located before the cursor moves. If no NUL occurs
      before [pos_end], {!Underflow} is raised with the cursor still at the
      start of the string, so a caller can refill and retry. *)
  let get_string b =
    let start = b.pos in
    let rec find_nul i =
      if i >= b.pos_end then raise (underflow b)
      else if Char.equal (Bytes.unsafe_get b.buf i) '\000' then i
      else find_nul (i + 1)
    in
    let nul = find_nul start in
    b.pos <- nul + 1;
    Bytes.sub_string b.buf start (nul - start)

  (* Payload decoder for [get_vint]; [c] is the tag byte (253, 254 or 255)
     written by [Write.put_vint_big]. *)
  let[@inline never] get_vint_big b c =
    match c with
    | 253 -> get_16 b
    | 254 -> Int32.to_int (get_32 b)
    | 255 -> Int64.to_int (get_64 b)
    | _ -> assert false

  (** Decode a variable-length integer written by [Write.put_vint]. *)
  let[@inline always] get_vint b =
    let c = get_8 b in
    if c < 253 then c else get_vint_big b c
end

let () =
  (Printexc.register_printer [@ocaml.alert "-unsafe_multidomain"]) (function
    | Write.Overflow n -> Some ("Buffer overflow at position " ^ string_of_int n)
    | Read.Underflow n ->
        Some ("Buffer underflow at position " ^ string_of_int n)
    | _ -> None)