(* A database layer inspired by Caqti and Ecto. *)
(** Tuning knobs for the cursor-based functions ([cursor_fold], [cursor_iter]). *)
type config = {
  batch_size : int;
      (** Number of rows requested per [FETCH] round trip. *)
  max_rows : int option;
      (** [Some n] stops after [n] rows have been visited; [None] reads the
          whole result. *)
}

(** 1000 rows per batch and no overall row limit. *)
let default_config : config = { max_rows = None; batch_size = 1000 }

(** Streaming query interface: rows are handed to a callback one at a time
    instead of being returned as one materialized list.

    Every function takes a connection, the SQL text and the [params] for
    that statement, and reports driver failures through [error]. *)
module type STREAM = sig
  type connection
  type error

  (** [fold conn sql ~params ~init ~f] threads an accumulator, starting at
      [init], through every row produced by [sql]. *)
  val fold :
    connection ->
    string ->
    params:Driver.Value.t array ->
    init:'acc ->
    f:('acc -> Driver.row -> 'acc) ->
    ('acc, error) result

  (** [iter conn sql ~params ~f] calls [f] on every row for its effect. *)
  val iter :
    connection ->
    string ->
    params:Driver.Value.t array ->
    f:(Driver.row -> unit) ->
    (unit, error) result

  (** [fold_map conn sql ~params ~f] applies [f] to every row and collects
      the results; in [Make] the list follows the order rows were delivered. *)
  val fold_map :
    connection ->
    string ->
    params:Driver.Value.t array ->
    f:(Driver.row -> 'a) ->
    ('a list, error) result

  (** [cursor_fold conn ~config sql ~params ~init ~f] is like [fold], but in
      [Make] the query runs through a declared cursor and rows are fetched in
      batches of [config.batch_size], stopping after [config.max_rows] rows
      when that is [Some _].
      NOTE(review): [Make] issues its own BEGIN/COMMIT around the cursor —
      confirm callers never invoke this inside an already-open transaction. *)
  val cursor_fold :
    connection ->
    config:config ->
    string ->
    params:Driver.Value.t array ->
    init:'acc ->
    f:('acc -> Driver.row -> 'acc) ->
    ('acc, error) result

  (** [cursor_iter] is [cursor_fold] specialised to a unit accumulator:
      [f] is called on every row for its effect. *)
  val cursor_iter :
    connection ->
    config:config ->
    string ->
    params:Driver.Value.t array ->
    f:(Driver.row -> unit) ->
    (unit, error) result
end

(** [Make (D)] implements [STREAM] on top of the driver [D].

    [fold], [iter] and [fold_map] delegate to the driver's query functions.
    The [cursor_*] variants declare a cursor for the query and pull rows with
    [FETCH] in batches, so only one batch of rows is held at a time. *)
module Make (D : Driver.S) :
  STREAM with type connection = D.connection and type error = D.error = struct
  type connection = D.connection
  type error = D.error

  let fold conn sql ~params ~init ~f = D.query_fold conn sql ~params ~init ~f
  let iter conn sql ~params ~f = D.query_iter conn sql ~params ~f

  (* Cons in reverse (O(1) per row), then restore delivery order once. *)
  let fold_map conn sql ~params ~f =
    D.query_fold conn sql ~params ~init:[] ~f:(fun acc row -> f row :: acc)
    |> Result.map List.rev

  (* One counter per functor application; [Atomic] keeps generated names
     distinct even if several domains open cursors concurrently. *)
  let cursor_counter = Atomic.make 0

  let generate_cursor_name () =
    let n = Atomic.fetch_and_add cursor_counter 1 in
    Printf.sprintf "repodb_cursor_%d" n

  (* Run a statement that takes no parameters (transaction control etc.). *)
  let exec_simple conn sql = D.exec conn sql ~params:[||]

  (* Best-effort ROLLBACK for failure paths. The caller needs the ORIGINAL
     failure, so an error from the rollback itself is deliberately dropped. *)
  let rollback_quietly conn =
    ignore (exec_simple conn "ROLLBACK" : (unit, error) result)

  (** [cursor_fold conn ~config sql ~params ~init ~f] runs
      [BEGIN; DECLARE <cursor> CURSOR FOR sql] and then folds [f] over rows
      obtained with [FETCH <limit> FROM <cursor>].

      - At most [config.batch_size] rows are requested per FETCH; a value
        below 1 is treated as 1.
      - When [config.max_rows = Some n], at most [n] rows are visited.
      - On success the cursor is closed and the transaction committed. A
        failing CLOSE or COMMIT is reported as [Error].
      - On any driver error the transaction is rolled back and that error is
        returned.
      - If [f] or the driver raises, the transaction is rolled back and the
        exception is re-raised with its original backtrace.

      NOTE(review): the BEGIN/COMMIT here is not nested. If [conn] is already
      inside a transaction, the COMMIT/ROLLBACK would end the caller's
      transaction — confirm against the database/driver semantics. *)
  let cursor_fold conn ~config sql ~params ~init ~f =
    let cursor_name = generate_cursor_name () in
    let declare_sql =
      Printf.sprintf "DECLARE %s CURSOR FOR %s" cursor_name sql
    in
    (* A non-positive batch size used to end the loop before the first
       FETCH, silently yielding [init]; always request at least one row. *)
    let batch_size = max 1 config.batch_size in
    (* [remaining] is [None] when unbounded, otherwise the number of rows
       [config.max_rows] still permits. *)
    let rec fetch_loop acc remaining =
      let limit =
        match remaining with
        | None -> batch_size
        | Some rem -> min batch_size rem
      in
      if limit <= 0 then Ok acc
      else
        let fetch_sql = Printf.sprintf "FETCH %d FROM %s" limit cursor_name in
        match D.query conn fetch_sql ~params:[||] with
        | Error e -> Error e
        | Ok [] -> Ok acc
        | Ok rows ->
            let acc = List.fold_left f acc rows in
            let fetched = List.length rows in
            (* A short batch means the cursor is exhausted: skip the extra
               round trip that could only return no rows. *)
            if fetched < limit then Ok acc
            else
              fetch_loop acc (Option.map (fun rem -> rem - fetched) remaining)
    in
    match exec_simple conn "BEGIN" with
    | Error e -> Error e
    | Ok () -> (
        match D.exec conn declare_sql ~params with
        | Error e ->
            rollback_quietly conn;
            Error e
        | Ok () -> (
            match fetch_loop init config.max_rows with
            | exception exn ->
                (* [f] or the driver raised: end the transaction before
                   propagating, so the connection is not left inside it. *)
                let bt = Printexc.get_raw_backtrace () in
                rollback_quietly conn;
                Printexc.raise_with_backtrace exn bt
            | Error e ->
                (* A failed statement typically aborts the transaction
                   (PostgreSQL semantics), so roll back rather than attempt
                   CLOSE/COMMIT as before. *)
                rollback_quietly conn;
                Error e
            | Ok acc -> (
                match exec_simple conn ("CLOSE " ^ cursor_name) with
                | Error e ->
                    rollback_quietly conn;
                    Error e
                | Ok () ->
                    (* Surface a failed COMMIT instead of reporting success. *)
                    exec_simple conn "COMMIT" |> Result.map (fun () -> acc))))

  let cursor_iter conn ~config sql ~params ~f =
    cursor_fold conn ~config sql ~params ~init:() ~f:(fun () row -> f row)
end

(** Combinators over rows that are already materialized as a list. *)
module Sync = struct
  let fold ~rows ~init ~f = List.fold_left f init rows

  (** Like [fold], but [f] also receives the 0-based index of each row. *)
  let fold_idx ~rows ~init ~f =
    let rec loop acc i = function
      | [] -> acc
      | row :: rest -> loop (f acc i row) (i + 1) rest
    in
    loop init 0 rows

  let iter ~rows ~f = List.iter f rows

  (** Like [iter], but [f] also receives the 0-based index of each row. *)
  let iter_idx ~rows ~f = List.iteri f rows

  let map ~rows ~f = List.map f rows
  let filter_map ~rows ~f = List.filter_map f rows
  let find ~rows ~f = List.find_opt f rows
  let exists ~rows ~f = List.exists f rows
  let for_all ~rows ~f = List.for_all f rows
  let partition ~rows ~f = List.partition f rows

  (** [take n rows] is the first [n] elements of [rows] (all of them if the
      list is shorter, none if [n <= 0]). Tail-recursive. *)
  let take n rows =
    let rec aux acc n = function
      | [] -> List.rev acc
      | _ when n <= 0 -> List.rev acc
      | x :: rest -> aux (x :: acc) (n - 1) rest
    in
    aux [] n rows

  (** [drop n rows] is [rows] without its first [n] elements. *)
  let drop n rows =
    let rec aux n = function
      | [] -> []
      | l when n <= 0 -> l
      | _ :: rest -> aux (n - 1) rest
    in
    aux n rows

  (** [chunks n rows] splits [rows] into consecutive groups of [n] elements,
      preserving order; the last group may be shorter, and [chunks n []] is
      [[]].
      @raise Invalid_argument if [n <= 0] (this previously produced a
      spurious leading empty chunk). *)
  let chunks n rows =
    if n <= 0 then invalid_arg "Sync.chunks: chunk size must be positive";
    let rec aux acc current count = function
      | [] -> (
          match current with
          | [] -> List.rev acc
          | _ -> List.rev (List.rev current :: acc))
      | x :: rest ->
          if count >= n then aux (List.rev current :: acc) [ x ] 1 rest
          else aux acc (x :: current) (count + 1) rest
    in
    aux [] [] 0 rows

  let to_seq rows = List.to_seq rows
end

(** [Stdlib.Seq.t] versions of the row combinators.

    This module shadows [Stdlib.Seq], so the standard module is always named
    by its full path inside it. *)
module Seq = struct
  let of_list = List.to_seq
  let fold ~seq ~init ~f = Stdlib.Seq.fold_left f init seq
  let iter ~seq ~f = Stdlib.Seq.iter f seq
  let map ~seq ~f = Stdlib.Seq.map f seq
  let filter ~seq ~f = Stdlib.Seq.filter f seq
  let filter_map ~seq ~f = Stdlib.Seq.filter_map f seq

  (** [take n seq] yields at most the first [n] elements of [seq] (none if
      [n <= 0]); elements are forced only as the result is consumed. *)
  let take n seq =
    let rec aux n seq () =
      if n <= 0 then Stdlib.Seq.Nil
      else
        match seq () with
        | Stdlib.Seq.Nil -> Stdlib.Seq.Nil
        | Stdlib.Seq.Cons (x, rest) -> Stdlib.Seq.Cons (x, aux (n - 1) rest)
    in
    aux n seq

  (** [drop n seq] skips the first [n] elements of [seq]. Unlike [take] this
      is eager: the skipped elements are forced when [drop] is called. *)
  let drop n seq =
    let rec aux n seq =
      if n <= 0 then seq
      else
        match seq () with
        | Stdlib.Seq.Nil -> Stdlib.Seq.empty
        | Stdlib.Seq.Cons (_, rest) -> aux (n - 1) rest
    in
    aux n seq

  (** [chunks n seq] groups [seq] into consecutive lists of [n] elements; the
      last list may be shorter. Each source node is forced exactly once, so
      this is safe on ephemeral (single-use) sequences.
      @raise Invalid_argument if [n <= 0] (previously an empty sequence was
      returned, silently dropping every element). *)
  let chunks n seq =
    if n <= 0 then invalid_arg "Seq.chunks: chunk size must be positive";
    let rec take_chunk acc count seq =
      if count >= n then (List.rev acc, seq)
      else
        match seq () with
        | Stdlib.Seq.Nil -> (List.rev acc, Stdlib.Seq.empty)
        | Stdlib.Seq.Cons (x, rest) -> take_chunk (x :: acc) (count + 1) rest
    in
    let rec aux seq () =
      match seq () with
      | Stdlib.Seq.Nil -> Stdlib.Seq.Nil
      | Stdlib.Seq.Cons (x, rest) ->
          (* Start the chunk from the node just forced. Forcing [seq] a
             second time would drop (or recompute) this element. *)
          let chunk, rest = take_chunk [ x ] 1 rest in
          Stdlib.Seq.Cons (chunk, aux rest)
    in
    aux seq

  let find ~seq ~f =
    let rec aux seq =
      match seq () with
      | Stdlib.Seq.Nil -> None
      | Stdlib.Seq.Cons (x, rest) -> if f x then Some x else aux rest
    in
    aux seq

  let exists ~seq ~f =
    let rec aux seq =
      match seq () with
      | Stdlib.Seq.Nil -> false
      | Stdlib.Seq.Cons (x, rest) -> f x || aux rest
    in
    aux seq

  let for_all ~seq ~f =
    let rec aux seq =
      match seq () with
      | Stdlib.Seq.Nil -> true
      | Stdlib.Seq.Cons (x, rest) -> f x && aux rest
    in
    aux seq
end

(** [fold_raw ~query_fold ~sql ~init ~f] hands [sql], [init] and [f] to the
    caller-supplied [query_fold] and returns its result unchanged. *)
let fold_raw ~(query_fold : 'sql -> init:'acc -> f:'step -> 'res)
    ~(sql : 'sql) ~(init : 'acc) ~(f : 'step) : 'res =
  query_fold sql ~init ~f
(** [iter_raw ~query_iter ~sql ~f] hands [sql] and [f] to the caller-supplied
    [query_iter] and returns whatever it returns. *)
let iter_raw ~(query_iter : 'sql -> f:'step -> 'res) ~(sql : 'sql)
    ~(f : 'step) : 'res =
  query_iter sql ~f