+2
bench/client/bench_client.ml
+2
bench/client/bench_client.ml
···
141
141
| Hcs.Websocket.Connection_closed -> "Connection closed"
142
142
| Hcs.Websocket.Protocol_error s -> "Protocol error: " ^ s
143
143
| Hcs.Websocket.Io_error s -> "IO error: " ^ s
144
+
| Hcs.Websocket.Payload_too_large n ->
145
+
"Payload too large: " ^ string_of_int n ^ " bytes"
144
146
in
145
147
Printf.eprintf "WebSocket connect error: %s\n%!" err_msg
146
148
| Ok ws ->
+4
bin/hc.ml
+4
bin/hc.ml
···
44
44
| Websocket.Connection_closed -> "Connection closed"
45
45
| Websocket.Protocol_error s -> "Protocol error: " ^ s
46
46
| Websocket.Io_error s -> "IO error: " ^ s
47
+
| Websocket.Payload_too_large n ->
48
+
"Payload too large: " ^ string_of_int n ^ " bytes"
47
49
in
48
50
Printf.eprintf "WebSocket error: %s\n" msg;
49
51
exit 1
···
74
76
| Error (Websocket.Protocol_error s) ->
75
77
Printf.eprintf "Protocol error: %s\n" s
76
78
| Error (Websocket.Io_error s) -> Printf.eprintf "IO error: %s\n" s
79
+
| Error (Websocket.Payload_too_large n) ->
80
+
Printf.eprintf "Payload too large: %d bytes\n" n
77
81
in
78
82
(* Only enter receive loop if we sent a message (otherwise just test connection) *)
79
83
if Option.is_some ws_message then recv_loop ();
+3
-2
lib/dune
+3
-2
lib/dune
+205
-98
lib/h1_server.ml
+205
-98
lib/h1_server.ml
···
12
12
13
13
(** {1 Read Buffer Pool} *)
14
14
15
+
(** Lock-free buffer pool using Treiber stack via Kcas for thread-safe pooling.
16
+
*)
15
17
module Read_buffer_pool : sig
16
18
val acquire : unit -> Bigstringaf.t * Cstruct.t
17
19
val release : Bigstringaf.t -> unit
18
20
end = struct
19
21
let buffer_size = 0x4000
22
+
let max_pooled = 256
23
+
let pool : Bigstringaf.t list Kcas.Loc.t = Kcas.Loc.make []
24
+
let pool_size : int Kcas.Loc.t = Kcas.Loc.make 0
20
25
21
26
let acquire () =
22
-
let buf = Bigstringaf.create buffer_size in
27
+
let buf =
28
+
Kcas.Xt.commit
29
+
{
30
+
tx =
31
+
(fun ~xt ->
32
+
match Kcas.Xt.get ~xt pool with
33
+
| [] -> None
34
+
| buf :: rest ->
35
+
Kcas.Xt.set ~xt pool rest;
36
+
Kcas.Xt.set ~xt pool_size (Kcas.Xt.get ~xt pool_size - 1);
37
+
Some buf);
38
+
}
39
+
in
40
+
let buf =
41
+
match buf with Some b -> b | None -> Bigstringaf.create buffer_size
42
+
in
23
43
(buf, Cstruct.of_bigarray buf ~off:0 ~len:buffer_size)
24
44
25
-
let release _ = ()
45
+
let release buf =
46
+
Kcas.Xt.commit
47
+
{
48
+
tx =
49
+
(fun ~xt ->
50
+
let size = Kcas.Xt.get ~xt pool_size in
51
+
if size < max_pooled then begin
52
+
Kcas.Xt.set ~xt pool (buf :: Kcas.Xt.get ~xt pool);
53
+
Kcas.Xt.set ~xt pool_size (size + 1)
54
+
end);
55
+
}
26
56
end
27
57
28
58
(** {1 Configuration} *)
···
292
322
| `GET | `HEAD | `DELETE | `OPTIONS | `CONNECT | `TRACE -> true
293
323
| `POST | `PUT | `PATCH | `Other _ -> false
294
324
295
-
(** Create a lazy body reader from H1's body reader *)
296
-
let make_body_reader (h1_body : H1.Body.Reader.t) : body_reader =
325
+
exception Body_too_large
326
+
327
+
let make_body_reader ?max_body_size (h1_body : H1.Body.Reader.t) : body_reader =
297
328
let read_called = ref false in
298
329
let closed = ref false in
299
-
let chunks = ref [] in
330
+
let body_buffer = Buffer.create 4096 in
331
+
let current_size = ref 0 in
332
+
let too_large = ref false in
300
333
let done_promise, done_resolver = Eio.Promise.create () in
301
334
302
-
(* Start reading in background - will block until first chunk or EOF *)
335
+
let check_size len =
336
+
match max_body_size with
337
+
| Some max when Int64.of_int (!current_size + len) > max ->
338
+
too_large := true;
339
+
H1.Body.Reader.close h1_body;
340
+
Eio.Promise.resolve done_resolver ();
341
+
false
342
+
| _ ->
343
+
current_size := !current_size + len;
344
+
true
345
+
in
346
+
303
347
let rec schedule_read () =
304
-
if not !closed then
348
+
if (not !closed) && not !too_large then
305
349
H1.Body.Reader.schedule_read h1_body
306
350
~on_eof:(fun () -> Eio.Promise.resolve done_resolver ())
307
351
~on_read:(fun buf ~off ~len ->
308
-
(* Store chunk as Cstruct to avoid copying bigstring *)
309
-
let chunk = Cstruct.of_bigarray ~off ~len buf |> Cstruct.to_string in
310
-
chunks := chunk :: !chunks;
311
-
schedule_read ())
352
+
if check_size len then begin
353
+
Buffer.add_string body_buffer (Bigstringaf.substring buf ~off ~len);
354
+
schedule_read ()
355
+
end)
312
356
in
313
357
314
358
{
···
321
365
else begin
322
366
schedule_read ();
323
367
Eio.Promise.await done_promise;
324
-
String.concat "" (List.rev !chunks)
368
+
if !too_large then raise Body_too_large
369
+
else Buffer.contents body_buffer
325
370
end
326
371
end);
327
372
read_stream =
328
373
(fun () ->
329
-
(* For streaming, we read one chunk at a time *)
330
-
if !closed then None
374
+
if !closed || !too_large then None
331
375
else begin
332
376
let chunk_promise, chunk_resolver = Eio.Promise.create () in
333
377
let got_chunk = ref false in
···
336
380
if not !got_chunk then Eio.Promise.resolve chunk_resolver None)
337
381
~on_read:(fun buf ~off ~len ->
338
382
got_chunk := true;
339
-
let cs = Cstruct.of_bigarray ~off ~len buf in
340
-
Eio.Promise.resolve chunk_resolver (Some cs));
383
+
if check_size len then begin
384
+
let cs = Cstruct.of_bigarray ~off ~len buf in
385
+
Eio.Promise.resolve chunk_resolver (Some cs)
386
+
end
387
+
else Eio.Promise.resolve chunk_resolver None);
341
388
Eio.Promise.await chunk_promise
342
389
end);
343
390
close =
···
356
403
close = (fun () -> ());
357
404
}
358
405
359
-
let handle_connection handler flow =
406
+
let respond_413 reqd =
407
+
let body = "Request Entity Too Large" in
408
+
let headers =
409
+
H1.Headers.of_list
410
+
[
411
+
("Date", Date_cache.get ());
412
+
("Content-Length", string_of_int (String.length body));
413
+
("Connection", "close");
414
+
]
415
+
in
416
+
let resp = H1.Response.create ~headers (`Code 413) in
417
+
H1.Reqd.respond_with_string reqd resp body
418
+
419
+
let respond_408 reqd =
420
+
let body = "Request Timeout" in
421
+
let headers =
422
+
H1.Headers.of_list
423
+
[
424
+
("Date", Date_cache.get ());
425
+
("Content-Length", string_of_int (String.length body));
426
+
("Connection", "close");
427
+
]
428
+
in
429
+
let resp = H1.Response.create ~headers (`Code 408) in
430
+
H1.Reqd.respond_with_string reqd resp body
431
+
432
+
let handle_connection ?clock ?read_timeout ?request_timeout ?max_body_size
433
+
handler flow =
360
434
let read_buffer, read_cstruct = Read_buffer_pool.acquire () in
361
435
Fun.protect ~finally:(fun () -> Read_buffer_pool.release read_buffer)
362
436
@@ fun () ->
···
364
438
let req = H1.Reqd.request reqd in
365
439
let h1_body = H1.Reqd.request_body reqd in
366
440
367
-
(* Create lazy body reader *)
368
441
let body_reader =
369
442
if method_has_no_body req.H1.Request.meth then begin
370
443
H1.Body.Reader.close h1_body;
371
444
empty_body_reader ()
372
445
end
373
-
else make_body_reader h1_body
446
+
else make_body_reader ?max_body_size h1_body
374
447
in
375
448
376
-
(* Build request with lazy body *)
377
449
let request =
378
450
{
379
451
meth = req.H1.Request.meth;
···
383
455
}
384
456
in
385
457
386
-
(* Call user handler *)
387
-
let response = handler request in
388
-
389
-
(* Ensure body is closed if not read *)
390
-
body_reader.close ();
391
-
392
-
let date_header = ("Date", Date_cache.get ()) in
393
-
let filter_reserved headers =
394
-
List.filter
395
-
(fun (k, _) ->
396
-
let lk = String.lowercase_ascii k in
397
-
lk <> "content-length" && lk <> "date")
398
-
headers
458
+
let handle_request () =
459
+
match (clock, request_timeout) with
460
+
| Some clock, Some timeout ->
461
+
Eio.Time.with_timeout clock timeout (fun () ->
462
+
let response = handler request in
463
+
body_reader.close ();
464
+
Ok response)
465
+
| _ ->
466
+
let response = handler request in
467
+
body_reader.close ();
468
+
Ok response
399
469
in
400
470
401
-
match response.response_body with
402
-
| Body_string body ->
403
-
let content_length = String.length body in
404
-
let headers =
405
-
H1.Headers.of_list
406
-
(date_header
407
-
:: ("Content-Length", string_of_int content_length)
408
-
:: filter_reserved response.headers)
471
+
match handle_request () with
472
+
| Error `Timeout ->
473
+
body_reader.close ();
474
+
respond_408 reqd
475
+
| exception Body_too_large ->
476
+
body_reader.close ();
477
+
respond_413 reqd
478
+
| Ok response -> (
479
+
let date_header = ("Date", Date_cache.get ()) in
480
+
let filter_reserved headers =
481
+
List.filter
482
+
(fun (k, _) ->
483
+
let lk = String.lowercase_ascii k in
484
+
lk <> "content-length" && lk <> "date")
485
+
headers
409
486
in
410
-
let resp = H1.Response.create ~headers response.status in
411
-
H1.Reqd.respond_with_string reqd resp body
412
-
| Body_bigstring bstr ->
413
-
let content_length = Bigstringaf.length bstr in
414
-
let headers =
415
-
H1.Headers.of_list
416
-
(date_header
417
-
:: ("Content-Length", string_of_int content_length)
418
-
:: filter_reserved response.headers)
419
-
in
420
-
let resp = H1.Response.create ~headers response.status in
421
-
H1.Reqd.respond_with_bigstring reqd resp bstr
422
-
| Body_prebuilt { h1_response; body } ->
423
-
let headers =
424
-
H1.Headers.add h1_response.headers "Date" (Date_cache.get ())
425
-
in
426
-
let resp = { h1_response with H1.Response.headers } in
427
-
H1.Reqd.respond_with_bigstring reqd resp body
428
-
| Body_cached_prebuilt cached ->
429
-
let resp = get_cached_response cached in
430
-
H1.Reqd.respond_with_bigstring reqd resp cached.body
431
-
| Body_stream { content_length; next } ->
432
-
let headers =
433
-
match content_length with
434
-
| Some len ->
487
+
488
+
match response.response_body with
489
+
| Body_string body ->
490
+
let content_length = String.length body in
491
+
let headers =
435
492
H1.Headers.of_list
436
493
(date_header
437
-
:: ("Content-Length", Int64.to_string len)
494
+
:: ("Content-Length", string_of_int content_length)
438
495
:: filter_reserved response.headers)
439
-
| None ->
496
+
in
497
+
let resp = H1.Response.create ~headers response.status in
498
+
H1.Reqd.respond_with_string reqd resp body
499
+
| Body_bigstring bstr ->
500
+
let content_length = Bigstringaf.length bstr in
501
+
let headers =
440
502
H1.Headers.of_list
441
503
(date_header
442
-
:: ("Transfer-Encoding", "chunked")
504
+
:: ("Content-Length", string_of_int content_length)
443
505
:: filter_reserved response.headers)
444
-
in
445
-
let resp = H1.Response.create ~headers response.status in
446
-
let body_writer = H1.Reqd.respond_with_streaming reqd resp in
447
-
let rec write_chunks () =
448
-
match next () with
449
-
| None -> H1.Body.Writer.close body_writer
450
-
| Some cs ->
451
-
H1.Body.Writer.write_bigstring body_writer ~off:0
452
-
~len:(Cstruct.length cs) (Cstruct.to_bigarray cs);
453
-
let flushed, resolve = Eio.Promise.create () in
454
-
H1.Body.Writer.flush body_writer (fun () ->
455
-
Eio.Promise.resolve resolve ());
456
-
Eio.Promise.await flushed;
457
-
write_chunks ()
458
-
in
459
-
write_chunks ()
506
+
in
507
+
let resp = H1.Response.create ~headers response.status in
508
+
H1.Reqd.respond_with_bigstring reqd resp bstr
509
+
| Body_prebuilt { h1_response; body } ->
510
+
let headers =
511
+
H1.Headers.add h1_response.headers "Date" (Date_cache.get ())
512
+
in
513
+
let resp = { h1_response with H1.Response.headers } in
514
+
H1.Reqd.respond_with_bigstring reqd resp body
515
+
| Body_cached_prebuilt cached ->
516
+
let resp = get_cached_response cached in
517
+
H1.Reqd.respond_with_bigstring reqd resp cached.body
518
+
| Body_stream { content_length; next } ->
519
+
let headers =
520
+
match content_length with
521
+
| Some len ->
522
+
H1.Headers.of_list
523
+
(date_header
524
+
:: ("Content-Length", Int64.to_string len)
525
+
:: filter_reserved response.headers)
526
+
| None ->
527
+
H1.Headers.of_list
528
+
(date_header
529
+
:: ("Transfer-Encoding", "chunked")
530
+
:: filter_reserved response.headers)
531
+
in
532
+
let resp = H1.Response.create ~headers response.status in
533
+
let body_writer = H1.Reqd.respond_with_streaming reqd resp in
534
+
let rec write_chunks () =
535
+
match next () with
536
+
| None -> H1.Body.Writer.close body_writer
537
+
| Some cs ->
538
+
H1.Body.Writer.write_bigstring body_writer ~off:0
539
+
~len:(Cstruct.length cs) (Cstruct.to_bigarray cs);
540
+
let flushed, resolve = Eio.Promise.create () in
541
+
H1.Body.Writer.flush body_writer (fun () ->
542
+
Eio.Promise.resolve resolve ());
543
+
Eio.Promise.await flushed;
544
+
write_chunks ()
545
+
in
546
+
write_chunks ())
460
547
in
461
548
462
549
let error_handler ?request:_ _error start_response =
···
469
556
470
557
let shutdown = ref false in
471
558
559
+
let do_read () =
560
+
match (clock, read_timeout) with
561
+
| Some clock, Some timeout ->
562
+
Eio.Time.with_timeout clock timeout (fun () ->
563
+
Ok (Eio.Flow.single_read flow read_cstruct))
564
+
| _ -> Ok (Eio.Flow.single_read flow read_cstruct)
565
+
in
566
+
472
567
let rec read_loop () =
473
568
if not !shutdown then
474
569
match H1.Server_connection.next_read_operation conn with
475
570
| `Read -> (
476
-
try
477
-
let n = Eio.Flow.single_read flow read_cstruct in
478
-
let _ = H1.Server_connection.read conn read_buffer ~off:0 ~len:n in
479
-
read_loop ()
480
-
with End_of_file ->
481
-
let _ =
482
-
H1.Server_connection.read_eof conn read_buffer ~off:0 ~len:0
483
-
in
484
-
shutdown := true)
571
+
match do_read () with
572
+
| Error `Timeout -> shutdown := true
573
+
| exception End_of_file ->
574
+
let _ =
575
+
H1.Server_connection.read_eof conn read_buffer ~off:0 ~len:0
576
+
in
577
+
shutdown := true
578
+
| Ok n ->
579
+
let _ =
580
+
H1.Server_connection.read conn read_buffer ~off:0 ~len:n
581
+
in
582
+
read_loop ())
485
583
| `Yield -> H1.Server_connection.yield_reader conn read_loop
486
584
| `Close | `Upgrade -> shutdown := true
487
585
in
···
510
608
511
609
Fiber.both read_loop write_loop
512
610
513
-
let run ~sw ~net ?(config = default_config) handler =
611
+
let run ~sw ~net ?clock ?(config = default_config) handler =
514
612
ensure_gc_tuned ();
515
613
let addr = `Tcp (Eio.Net.Ipaddr.V4.any, config.port) in
516
614
let socket =
···
518
616
~reuse_port:config.reuse_port net addr
519
617
in
520
618
traceln "Server listening on port %d" config.port;
619
+
let max_body_size = config.max_body_size in
620
+
let read_timeout = Some config.read_timeout in
621
+
let request_timeout = Some config.request_timeout in
521
622
let connection_handler flow _addr =
522
623
if config.tcp_nodelay then set_tcp_nodelay flow;
523
-
handle_connection handler flow
624
+
handle_connection ?clock ?read_timeout ?request_timeout ?max_body_size
625
+
handler flow
524
626
in
525
627
let on_error exn = traceln "Connection error: %s" (Printexc.to_string exn) in
526
628
Eio.Net.run_server socket connection_handler
527
629
~max_connections:config.max_connections ~on_error
528
630
529
-
let run_parallel ~sw ~net ~domain_mgr ?(config = default_config) handler =
631
+
let run_parallel ~sw ~net ~domain_mgr ?clock ?(config = default_config) handler
632
+
=
530
633
ensure_gc_tuned ();
531
634
let domain_count = max 1 config.domain_count in
532
635
let addr = `Tcp (Eio.Net.Ipaddr.V4.any, config.port) in
···
535
638
~reuse_port:config.reuse_port net addr
536
639
in
537
640
traceln "Server listening on port %d (%d domains)" config.port domain_count;
641
+
let max_body_size = config.max_body_size in
642
+
let read_timeout = Some config.read_timeout in
643
+
let request_timeout = Some config.request_timeout in
538
644
let connection_handler flow _addr =
539
645
if config.tcp_nodelay then set_tcp_nodelay flow;
540
-
handle_connection handler flow
646
+
handle_connection ?clock ?read_timeout ?request_timeout ?max_body_size
647
+
handler flow
541
648
in
542
649
let on_error exn = traceln "Connection error: %s" (Printexc.to_string exn) in
543
650
if domain_count <= 1 then
+200
-80
lib/h2_server.ml
+200
-80
lib/h2_server.ml
···
60
60
let make_h2_response ?(status = `OK) headers =
61
61
H2.Response.create ~headers status
62
62
63
-
(** {1 Internal helpers} *)
63
+
(** Lock-free buffer pool using Treiber stack via Kcas for thread-safe pooling.
64
+
*)
65
+
module Read_buffer_pool : sig
66
+
val acquire : unit -> Bigstringaf.t * Cstruct.t
67
+
val release : Bigstringaf.t -> unit
68
+
end = struct
69
+
let buffer_size = 0x4000
70
+
let max_pooled = 256
71
+
let pool : Bigstringaf.t list Kcas.Loc.t = Kcas.Loc.make []
72
+
let pool_size : int Kcas.Loc.t = Kcas.Loc.make 0
73
+
74
+
let acquire () =
75
+
let buf =
76
+
Kcas.Xt.commit
77
+
{
78
+
tx =
79
+
(fun ~xt ->
80
+
match Kcas.Xt.get ~xt pool with
81
+
| [] -> None
82
+
| buf :: rest ->
83
+
Kcas.Xt.set ~xt pool rest;
84
+
Kcas.Xt.set ~xt pool_size (Kcas.Xt.get ~xt pool_size - 1);
85
+
Some buf);
86
+
}
87
+
in
88
+
let buf =
89
+
match buf with Some b -> b | None -> Bigstringaf.create buffer_size
90
+
in
91
+
(buf, Cstruct.of_bigarray buf ~off:0 ~len:buffer_size)
92
+
93
+
let release buf =
94
+
Kcas.Xt.commit
95
+
{
96
+
tx =
97
+
(fun ~xt ->
98
+
let size = Kcas.Xt.get ~xt pool_size in
99
+
if size < max_pooled then begin
100
+
Kcas.Xt.set ~xt pool (buf :: Kcas.Xt.get ~xt pool);
101
+
Kcas.Xt.set ~xt pool_size (size + 1)
102
+
end);
103
+
}
104
+
end
64
105
65
106
let set_tcp_nodelay flow =
66
107
match Eio_unix.Resource.fd_opt flow with
···
81
122
82
123
(** {1 Connection handling} *)
83
124
84
-
let handle_connection handler flow =
85
-
let read_buffer_size = 0x4000 in
86
-
let read_buffer = Bigstringaf.create read_buffer_size in
125
+
type body_result = Ok_body of string | Body_too_large | Missing_path
126
+
127
+
let respond_error reqd status body =
128
+
let headers =
129
+
H2.Headers.of_list
130
+
[
131
+
("content-length", string_of_int (String.length body));
132
+
("date", Date_cache.get ());
133
+
]
134
+
in
135
+
let resp = H2.Response.create ~headers status in
136
+
H2.Reqd.respond_with_string reqd resp body
87
137
138
+
let handle_connection ?clock ?read_timeout ?request_timeout ?max_body_size
139
+
handler flow =
140
+
let read_buffer, read_cstruct = Read_buffer_pool.acquire () in
141
+
Fun.protect ~finally:(fun () -> Read_buffer_pool.release read_buffer)
142
+
@@ fun () ->
88
143
let request_handler reqd =
89
144
let req = H2.Reqd.request reqd in
90
145
let body_reader = H2.Reqd.request_body reqd in
91
146
92
-
let body =
147
+
let body_result =
93
148
match req.meth with
94
149
| `GET | `HEAD ->
95
150
H2.Body.Reader.close body_reader;
96
-
""
151
+
Ok_body ""
97
152
| `POST | `PUT | `DELETE | `CONNECT | `OPTIONS | `TRACE | `Other _ ->
98
153
let body_buffer = Buffer.create 4096 in
154
+
let current_size = ref 0 in
155
+
let too_large = ref false in
99
156
let body_done_promise, body_done_resolver = Eio.Promise.create () in
100
157
let rec read_body () =
101
158
H2.Body.Reader.schedule_read body_reader
102
159
~on_eof:(fun () -> Eio.Promise.resolve body_done_resolver ())
103
160
~on_read:(fun buf ~off ~len ->
104
-
Buffer.add_string body_buffer
105
-
(Bigstringaf.substring buf ~off ~len);
106
-
read_body ())
161
+
let new_size = !current_size + len in
162
+
match max_body_size with
163
+
| Some max when Int64.of_int new_size > max ->
164
+
too_large := true;
165
+
H2.Body.Reader.close body_reader;
166
+
Eio.Promise.resolve body_done_resolver ()
167
+
| _ ->
168
+
current_size := new_size;
169
+
Buffer.add_string body_buffer
170
+
(Bigstringaf.substring buf ~off ~len);
171
+
read_body ())
107
172
in
108
173
read_body ();
109
174
Eio.Promise.await body_done_promise;
110
-
Buffer.contents body_buffer
175
+
if !too_large then Body_too_large
176
+
else Ok_body (Buffer.contents body_buffer)
111
177
in
112
178
113
-
let target =
114
-
match H2.Headers.get req.headers ":path" with Some p -> p | None -> "/"
115
-
in
116
-
117
-
let request = { meth = req.meth; target; headers = req.headers; body } in
118
-
let response = handler request in
119
-
120
-
match response.response_body with
121
-
| Body_string body ->
122
-
let headers =
123
-
H2.Headers.of_list
124
-
(("content-length", string_of_int (String.length body))
125
-
:: response.headers)
126
-
in
127
-
let resp = H2.Response.create ~headers response.status in
128
-
H2.Reqd.respond_with_string reqd resp body
129
-
| Body_bigstring bstr ->
130
-
let headers =
131
-
H2.Headers.of_list
132
-
(("content-length", string_of_int (Bigstringaf.length bstr))
133
-
:: response.headers)
134
-
in
135
-
let resp = H2.Response.create ~headers response.status in
136
-
H2.Reqd.respond_with_bigstring reqd resp bstr
137
-
| Body_prebuilt { h2_response; body } ->
138
-
H2.Reqd.respond_with_bigstring reqd h2_response body
139
-
| Body_stream { content_length; next } ->
140
-
let headers =
141
-
match content_length with
142
-
| Some len ->
143
-
H2.Headers.of_list
144
-
(("content-length", Int64.to_string len) :: response.headers)
145
-
| None -> H2.Headers.of_list response.headers
146
-
in
147
-
let resp = H2.Response.create ~headers response.status in
148
-
let body_writer = H2.Reqd.respond_with_streaming reqd resp in
149
-
let rec write_chunks () =
150
-
match next () with
151
-
| None -> H2.Body.Writer.close body_writer
152
-
| Some cs ->
153
-
H2.Body.Writer.write_bigstring body_writer ~off:0
154
-
~len:(Cstruct.length cs) (Cstruct.to_bigarray cs);
155
-
let flushed, resolve = Eio.Promise.create () in
156
-
H2.Body.Writer.flush body_writer (fun _result ->
157
-
Eio.Promise.resolve resolve ());
158
-
Eio.Promise.await flushed;
159
-
write_chunks ()
179
+
match body_result with
180
+
| Body_too_large ->
181
+
respond_error reqd (`Code 413) "Request Entity Too Large"
182
+
| Missing_path -> respond_error reqd `Bad_request "Missing :path header"
183
+
| Ok_body body -> (
184
+
let target =
185
+
match H2.Headers.get req.headers ":path" with
186
+
| Some p -> Some p
187
+
| None -> None
160
188
in
161
-
write_chunks ()
189
+
match target with
190
+
| None -> respond_error reqd `Bad_request "Missing :path header"
191
+
| Some target -> (
192
+
let request =
193
+
{ meth = req.meth; target; headers = req.headers; body }
194
+
in
195
+
let handler_result =
196
+
match (clock, request_timeout) with
197
+
| Some clock, Some timeout ->
198
+
Eio.Time.with_timeout clock timeout (fun () ->
199
+
Ok (handler request))
200
+
| _ -> Ok (handler request)
201
+
in
202
+
match handler_result with
203
+
| Error `Timeout -> respond_error reqd (`Code 408) "Request Timeout"
204
+
| Ok response -> (
205
+
let date_header = ("date", Date_cache.get ()) in
206
+
match response.response_body with
207
+
| Body_string body ->
208
+
let headers =
209
+
H2.Headers.of_list
210
+
(date_header
211
+
:: ("content-length", string_of_int (String.length body))
212
+
:: response.headers)
213
+
in
214
+
let resp = H2.Response.create ~headers response.status in
215
+
H2.Reqd.respond_with_string reqd resp body
216
+
| Body_bigstring bstr ->
217
+
let headers =
218
+
H2.Headers.of_list
219
+
(date_header
220
+
:: ( "content-length",
221
+
string_of_int (Bigstringaf.length bstr) )
222
+
:: response.headers)
223
+
in
224
+
let resp = H2.Response.create ~headers response.status in
225
+
H2.Reqd.respond_with_bigstring reqd resp bstr
226
+
| Body_prebuilt { h2_response; body } ->
227
+
let headers =
228
+
H2.Headers.add h2_response.H2.Response.headers "date"
229
+
(Date_cache.get ())
230
+
in
231
+
let resp = { h2_response with H2.Response.headers } in
232
+
H2.Reqd.respond_with_bigstring reqd resp body
233
+
| Body_stream { content_length; next } ->
234
+
let headers =
235
+
match content_length with
236
+
| Some len ->
237
+
H2.Headers.of_list
238
+
(date_header
239
+
:: ("content-length", Int64.to_string len)
240
+
:: response.headers)
241
+
| None ->
242
+
H2.Headers.of_list (date_header :: response.headers)
243
+
in
244
+
let resp = H2.Response.create ~headers response.status in
245
+
let body_writer =
246
+
H2.Reqd.respond_with_streaming reqd resp
247
+
in
248
+
let rec write_chunks () =
249
+
match next () with
250
+
| None -> H2.Body.Writer.close body_writer
251
+
| Some cs ->
252
+
H2.Body.Writer.write_bigstring body_writer ~off:0
253
+
~len:(Cstruct.length cs) (Cstruct.to_bigarray cs);
254
+
let flushed, resolve = Eio.Promise.create () in
255
+
H2.Body.Writer.flush body_writer (fun _result ->
256
+
Eio.Promise.resolve resolve ());
257
+
Eio.Promise.await flushed;
258
+
write_chunks ()
259
+
in
260
+
write_chunks ())))
162
261
in
163
262
164
263
let error_handler ?request:_ _error start_response =
165
-
let resp_body = start_response H2.Headers.empty in
264
+
let resp_body =
265
+
start_response (H2.Headers.of_list [ ("date", Date_cache.get ()) ])
266
+
in
166
267
H2.Body.Writer.write_string resp_body "Internal Server Error";
167
268
H2.Body.Writer.close resp_body
168
269
in
···
171
272
172
273
let shutdown = ref false in
173
274
275
+
let do_read () =
276
+
match (clock, read_timeout) with
277
+
| Some clock, Some timeout ->
278
+
Eio.Time.with_timeout clock timeout (fun () ->
279
+
Ok (Eio.Flow.single_read flow read_cstruct))
280
+
| _ -> Ok (Eio.Flow.single_read flow read_cstruct)
281
+
in
282
+
174
283
let read_loop () =
175
284
let rec loop () =
176
285
if not !shutdown then
177
286
match H2.Server_connection.next_read_operation conn with
178
287
| `Read -> (
179
-
let cs =
180
-
Cstruct.of_bigarray read_buffer ~off:0 ~len:read_buffer_size
181
-
in
182
-
try
183
-
let n = Eio.Flow.single_read flow cs in
184
-
let _ =
185
-
H2.Server_connection.read conn read_buffer ~off:0 ~len:n
186
-
in
187
-
loop ()
188
-
with End_of_file ->
189
-
let _ =
190
-
H2.Server_connection.read_eof conn read_buffer ~off:0 ~len:0
191
-
in
192
-
shutdown := true)
288
+
match do_read () with
289
+
| Error `Timeout -> shutdown := true
290
+
| exception End_of_file ->
291
+
let _ =
292
+
H2.Server_connection.read_eof conn read_buffer ~off:0 ~len:0
293
+
in
294
+
shutdown := true
295
+
| Ok n ->
296
+
let _ =
297
+
H2.Server_connection.read conn read_buffer ~off:0 ~len:n
298
+
in
299
+
loop ())
193
300
| `Close -> shutdown := true
194
301
in
195
302
loop ()
···
221
328
222
329
(** {1 Public API} *)
223
330
224
-
let run ~sw ~net ?(config = H1_server.default_config) handler =
331
+
let run ~sw ~net ?clock ?(config = H1_server.default_config) handler =
225
332
let addr = `Tcp (Eio.Net.Ipaddr.V4.any, config.port) in
226
333
let socket =
227
334
Eio.Net.listen ~sw ~backlog:config.backlog ~reuse_addr:config.reuse_addr
228
335
~reuse_port:config.reuse_port net addr
229
336
in
230
337
traceln "HTTP/2 Server listening on port %d" config.port;
338
+
let max_body_size = config.max_body_size in
339
+
let read_timeout = Some config.read_timeout in
340
+
let request_timeout = Some config.request_timeout in
231
341
let connection_handler flow _addr =
232
342
if config.tcp_nodelay then set_tcp_nodelay flow;
233
-
handle_connection handler flow
343
+
handle_connection ?clock ?read_timeout ?request_timeout ?max_body_size
344
+
handler flow
234
345
in
235
346
let on_error exn = traceln "Connection error: %s" (Printexc.to_string exn) in
236
347
Eio.Net.run_server socket connection_handler
237
348
~max_connections:config.max_connections ~on_error
238
349
239
-
let run_tls ~sw ~net ?(config = H1_server.default_config) ~tls_config handler =
350
+
let run_tls ~sw ~net ?clock ?(config = H1_server.default_config) ~tls_config
351
+
handler =
240
352
let addr = `Tcp (Eio.Net.Ipaddr.V4.any, config.port) in
241
353
let socket =
242
354
Eio.Net.listen ~sw ~backlog:config.backlog ~reuse_addr:config.reuse_addr
243
355
~reuse_port:config.reuse_port net addr
244
356
in
245
357
traceln "HTTP/2 Server (TLS) listening on port %d" config.port;
358
+
let max_body_size = config.max_body_size in
359
+
let read_timeout = Some config.read_timeout in
360
+
let request_timeout = Some config.request_timeout in
246
361
let connection_handler flow _addr =
247
362
if config.tcp_nodelay then set_tcp_nodelay flow;
248
363
match Tls_config.Server.to_tls_config tls_config with
···
250
365
| Ok tls_cfg -> (
251
366
try
252
367
let tls_flow = Tls_eio.server_of_flow tls_cfg flow in
253
-
handle_connection handler tls_flow
368
+
handle_connection ?clock ?read_timeout ?request_timeout ?max_body_size
369
+
handler tls_flow
254
370
with
255
371
| Tls_eio.Tls_failure failure ->
256
372
traceln "TLS error: %s" (Tls_config.failure_to_string failure)
···
260
376
Eio.Net.run_server socket connection_handler
261
377
~max_connections:config.max_connections ~on_error
262
378
263
-
let run_parallel ~sw ~net ~domain_mgr ?(config = H1_server.default_config)
264
-
handler =
379
+
let run_parallel ~sw ~net ~domain_mgr ?clock
380
+
?(config = H1_server.default_config) handler =
265
381
let domain_count = max 1 config.domain_count in
266
382
let addr = `Tcp (Eio.Net.Ipaddr.V4.any, config.port) in
267
383
let socket =
···
270
386
in
271
387
traceln "HTTP/2 Server listening on port %d (%d domains)" config.port
272
388
domain_count;
389
+
let max_body_size = config.max_body_size in
390
+
let read_timeout = Some config.read_timeout in
391
+
let request_timeout = Some config.request_timeout in
273
392
let connection_handler flow _addr =
274
393
if config.tcp_nodelay then set_tcp_nodelay flow;
275
-
handle_connection handler flow
394
+
handle_connection ?clock ?read_timeout ?request_timeout ?max_body_size
395
+
handler flow
276
396
in
277
397
let on_error exn = traceln "Connection error: %s" (Printexc.to_string exn) in
278
398
if domain_count <= 1 then
+14
-2
lib/plug/basic_auth.ml
+14
-2
lib/plug/basic_auth.ml
···
2
2
3
3
Implements RFC 7617 Basic authentication. *)
4
4
5
+
let secure_compare a b =
6
+
let len_a = String.length a in
7
+
let len_b = String.length b in
8
+
if len_a <> len_b then false
9
+
else
10
+
let result = ref 0 in
11
+
for i = 0 to len_a - 1 do
12
+
result := !result lor (Char.code a.[i] lxor Char.code b.[i])
13
+
done;
14
+
!result = 0
15
+
5
16
(** Decode base64 credentials from Authorization header *)
6
17
let decode_credentials auth_header =
7
18
let prefix = "Basic " in
···
65
76
@param username Expected username
66
77
@param password Expected password *)
67
78
let create_static ~realm ~username ~password : Core.t =
68
-
create ~realm ~validate:(fun u p -> u = username && p = password)
79
+
create ~realm ~validate:(fun u p ->
80
+
secure_compare u username && secure_compare p password)
69
81
70
82
(** Create basic auth plug with credential map.
71
83
···
75
87
=
76
88
create ~realm ~validate:(fun u p ->
77
89
match Hashtbl.find_opt credentials u with
78
-
| Some expected -> expected = p
90
+
| Some expected -> secure_compare expected p
79
91
| None -> false)
+1
-8
lib/plug/csrf.ml
+1
-8
lib/plug/csrf.ml
···
2
2
3
3
Implements Double Submit Cookie pattern for CSRF protection. *)
4
4
5
-
(** Generate a random CSRF token *)
6
-
let generate_token () =
7
-
let bytes = Stdlib.Bytes.create 32 in
8
-
for i = 0 to 31 do
9
-
Stdlib.Bytes.set bytes i (Char.chr (Random.int 256))
10
-
done;
11
-
Base64.encode_string ~alphabet:Base64.uri_safe_alphabet
12
-
(Stdlib.Bytes.to_string bytes)
5
+
let generate_token () = Secure_random.token ~bytes:32 ()
13
6
14
7
(** Extract token from cookie header *)
15
8
let get_cookie_token ~cookie_name (req : Server.request) : string option =
+1
-6
lib/plug/session.ml
+1
-6
lib/plug/session.ml
···
163
163
Buffer.add_string buf same_site;
164
164
Buffer.contents buf
165
165
166
-
let generate_id () =
167
-
let b = Bytes.create 16 in
168
-
for i = 0 to 15 do
169
-
Bytes.set b i (Char.chr (Random.int 256))
170
-
done;
171
-
Token.b64_encode (Bytes.unsafe_to_string b)
166
+
let generate_id () = Secure_random.token ~bytes:16 ()
172
167
173
168
(** Create session plug with configurable storage and cookie options. *)
174
169
let create ~store ?(cookie_name = "_session") ?(secure = true)
+1
-7
lib/plug/token.ml
+1
-7
lib/plug/token.ml
···
146
146
if String.length secret >= 32 then String.sub secret 0 32
147
147
else Digestif.SHA256.(digest_string secret |> to_raw_string)
148
148
in
149
-
let nonce =
150
-
let b = Bytes.create 12 in
151
-
for i = 0 to 11 do
152
-
Bytes.set b i (Char.chr (Random.int 256))
153
-
done;
154
-
Bytes.unsafe_to_string b
155
-
in
149
+
let nonce = Secure_random.generate 12 in
156
150
let key = Mirage_crypto.AES.GCM.of_secret aes_key in
157
151
let ciphertext =
158
152
Mirage_crypto.AES.GCM.authenticate_encrypt ~key ~nonce data
+59
lib/secure_random.ml
+59
lib/secure_random.ml
···
1
+
let initialized = ref false
2
+
3
+
let ensure_initialized () =
4
+
if not !initialized then begin
5
+
Mirage_crypto_rng_unix.use_default ();
6
+
initialized := true
7
+
end
8
+
9
+
let generate n =
10
+
ensure_initialized ();
11
+
Mirage_crypto_rng.generate n
12
+
13
+
let generate_bytes n = Bytes.of_string (Mirage_crypto_rng.generate n)
14
+
15
+
let int bound =
16
+
if bound <= 0 then invalid_arg "Secure_random.int: bound must be positive";
17
+
let rec sample () =
18
+
let s = Mirage_crypto_rng.generate 8 in
19
+
let v =
20
+
(Char.code s.[0] lsl 56)
21
+
lor (Char.code s.[1] lsl 48)
22
+
lor (Char.code s.[2] lsl 40)
23
+
lor (Char.code s.[3] lsl 32)
24
+
lor (Char.code s.[4] lsl 24)
25
+
lor (Char.code s.[5] lsl 16)
26
+
lor (Char.code s.[6] lsl 8)
27
+
lor Char.code s.[7]
28
+
in
29
+
let v = v land max_int in
30
+
let limit = max_int - (max_int mod bound) in
31
+
if v < limit then v mod bound else sample ()
32
+
in
33
+
sample ()
34
+
35
+
let token ?(bytes = 32) () =
36
+
let raw = generate bytes in
37
+
Base64.encode_string ~pad:false ~alphabet:Base64.uri_safe_alphabet raw
38
+
39
+
let uuid () =
40
+
let bytes = generate_bytes 16 in
41
+
Bytes.set bytes 6
42
+
(Char.chr (Char.code (Bytes.get bytes 6) land 0x0f lor 0x40));
43
+
Bytes.set bytes 8
44
+
(Char.chr (Char.code (Bytes.get bytes 8) land 0x3f lor 0x80));
45
+
let hex = Bytes.create 36 in
46
+
let hex_chars = "0123456789abcdef" in
47
+
let pos = ref 0 in
48
+
for i = 0 to 15 do
49
+
if i = 4 || i = 6 || i = 8 || i = 10 then begin
50
+
Bytes.set hex !pos '-';
51
+
incr pos
52
+
end;
53
+
let b = Char.code (Bytes.get bytes i) in
54
+
Bytes.set hex !pos hex_chars.[b lsr 4];
55
+
incr pos;
56
+
Bytes.set hex !pos hex_chars.[b land 0x0f];
57
+
incr pos
58
+
done;
59
+
Bytes.to_string hex
+281
-191
lib/server.ml
+281
-191
lib/server.ml
···
207
207
type handler = request -> response
208
208
type ws_handler = Websocket.t -> unit
209
209
210
+
type ws_config = {
211
+
origin_policy : Websocket.origin_policy;
212
+
max_payload_size : int;
213
+
}
214
+
215
+
let default_ws_config =
216
+
{
217
+
origin_policy = Websocket.Allow_all;
218
+
max_payload_size = Websocket.default_max_payload_size;
219
+
}
220
+
210
221
let respond ?(status = `OK) ?(headers = []) body =
211
222
Response.make ~status ~headers body
212
223
···
289
300
(** {1 Internal: HTTP/1.1 Connection Handler} *)
290
301
291
302
module H1_handler = struct
292
-
let handle ~handler ~ws_handler ~initial_data flow =
303
+
let handle ~handler ~ws_handler ?(ws_config = default_ws_config)
304
+
?max_body_size ~initial_data flow =
293
305
let buffer_size = 16384 in
294
306
let read_buffer = Bigstringaf.create buffer_size in
295
307
let read_cstruct =
···
306
318
(* Check for WebSocket upgrade *)
307
319
if Option.is_some ws_handler && Websocket.is_upgrade_request req.headers
308
320
then begin
309
-
match Websocket.get_websocket_key req.headers with
310
-
| Some key ->
311
-
ws_upgrade := Some key;
321
+
match
322
+
Websocket.validate_origin ~policy:ws_config.origin_policy req.headers
323
+
with
324
+
| Error reason ->
312
325
H1.Body.Reader.close h1_body;
313
-
let accept = Websocket.compute_accept_key key in
326
+
let body = "Forbidden: " ^ reason in
314
327
let headers =
315
328
H1.Headers.of_list
316
329
[
317
-
("upgrade", "websocket");
318
-
("connection", "Upgrade");
319
-
("sec-websocket-accept", accept);
330
+
("date", Date_cache.get ());
331
+
("content-length", string_of_int (String.length body));
320
332
]
321
333
in
322
-
H1.Reqd.respond_with_upgrade reqd headers
323
-
| None ->
324
-
H1.Body.Reader.close h1_body;
325
-
let headers =
326
-
H1.Headers.of_list
327
-
[ ("date", Date_cache.get ()); ("content-length", "11") ]
328
-
in
329
-
let resp = H1.Response.create ~headers `Bad_request in
330
-
H1.Reqd.respond_with_string reqd resp "Bad Request"
334
+
let resp = H1.Response.create ~headers `Forbidden in
335
+
H1.Reqd.respond_with_string reqd resp body
336
+
| Ok () -> (
337
+
match Websocket.validate_websocket_version req.headers with
338
+
| Error _reason ->
339
+
H1.Body.Reader.close h1_body;
340
+
let body = "Upgrade Required" in
341
+
let headers =
342
+
H1.Headers.of_list
343
+
[
344
+
("date", Date_cache.get ());
345
+
("content-length", string_of_int (String.length body));
346
+
( "sec-websocket-version",
347
+
Websocket.supported_websocket_version );
348
+
]
349
+
in
350
+
let resp = H1.Response.create ~headers (`Code 426) in
351
+
H1.Reqd.respond_with_string reqd resp body
352
+
| Ok () -> (
353
+
match Websocket.get_websocket_key req.headers with
354
+
| Some key ->
355
+
ws_upgrade := Some key;
356
+
H1.Body.Reader.close h1_body;
357
+
let accept = Websocket.compute_accept_key key in
358
+
let headers =
359
+
H1.Headers.of_list
360
+
[
361
+
("upgrade", "websocket");
362
+
("connection", "Upgrade");
363
+
("sec-websocket-accept", accept);
364
+
]
365
+
in
366
+
H1.Reqd.respond_with_upgrade reqd headers
367
+
| None ->
368
+
H1.Body.Reader.close h1_body;
369
+
let headers =
370
+
H1.Headers.of_list
371
+
[
372
+
("date", Date_cache.get ()); ("content-length", "11");
373
+
]
374
+
in
375
+
let resp = H1.Response.create ~headers `Bad_request in
376
+
H1.Reqd.respond_with_string reqd resp "Bad Request"))
331
377
end
332
378
else begin
333
379
(* Regular HTTP/1.1 request *)
334
380
(* Read body for POST/PUT, skip for GET/HEAD *)
335
-
let body =
381
+
let body_result =
336
382
match req.meth with
337
383
| `GET | `HEAD | `DELETE | `OPTIONS | `CONNECT | `TRACE ->
338
384
H1.Body.Reader.close h1_body;
339
-
""
385
+
Ok ""
340
386
| `POST | `PUT | `Other _ ->
341
387
let body_buffer = Buffer.create 4096 in
388
+
let body_size = ref 0 in
389
+
let too_large = ref false in
342
390
let body_done, resolver = Eio.Promise.create () in
343
391
let rec read_body () =
344
392
H1.Body.Reader.schedule_read h1_body
345
393
~on_eof:(fun () -> Eio.Promise.resolve resolver ())
346
394
~on_read:(fun buf ~off ~len ->
347
-
Buffer.add_string body_buffer
348
-
(Bigstringaf.substring buf ~off ~len);
349
-
read_body ())
395
+
let new_size = !body_size + len in
396
+
match max_body_size with
397
+
| Some max when Int64.of_int new_size > max ->
398
+
too_large := true;
399
+
H1.Body.Reader.close h1_body;
400
+
Eio.Promise.resolve resolver ()
401
+
| _ ->
402
+
body_size := new_size;
403
+
Buffer.add_string body_buffer
404
+
(Bigstringaf.substring buf ~off ~len);
405
+
read_body ())
350
406
in
351
407
read_body ();
352
408
Eio.Promise.await body_done;
353
-
Buffer.contents body_buffer
409
+
if !too_large then Error `Body_too_large
410
+
else Ok (Buffer.contents body_buffer)
354
411
in
355
-
356
-
let request =
357
-
{
358
-
meth = req.meth;
359
-
target = req.target;
360
-
headers = h1_headers_to_list req.headers;
361
-
body;
362
-
version = HTTP_1_1;
363
-
}
364
-
in
365
-
let response : Response.t = handler request in
366
-
367
-
let date_header = ("date", Date_cache.get ()) in
368
-
match response.Response.body with
369
-
| Response.Prebuilt_body prebuilt -> Prebuilt.respond_h1 reqd prebuilt
370
-
| Response.Empty ->
412
+
match body_result with
413
+
| Error `Body_too_large ->
414
+
let body = "Request body too large" in
371
415
let headers =
372
416
H1.Headers.of_list
373
-
(date_header :: ("content-length", "0") :: response.headers)
374
-
in
375
-
let resp = H1.Response.create ~headers response.status in
376
-
H1.Reqd.respond_with_string reqd resp ""
377
-
| Response.String body ->
378
-
let headers =
379
-
H1.Headers.of_list
380
-
(date_header
381
-
:: ("content-length", string_of_int (String.length body))
382
-
:: response.headers)
417
+
[
418
+
("date", Date_cache.get ());
419
+
("content-length", string_of_int (String.length body));
420
+
]
383
421
in
384
-
let resp = H1.Response.create ~headers response.status in
422
+
let resp = H1.Response.create ~headers (`Code 413) in
385
423
H1.Reqd.respond_with_string reqd resp body
386
-
| Response.Bigstring body ->
387
-
let headers =
388
-
H1.Headers.of_list
389
-
(date_header
390
-
:: ("content-length", string_of_int (Bigstringaf.length body))
391
-
:: response.headers)
424
+
| Ok body -> (
425
+
let request =
426
+
{
427
+
meth = req.meth;
428
+
target = req.target;
429
+
headers = h1_headers_to_list req.headers;
430
+
body;
431
+
version = HTTP_1_1;
432
+
}
392
433
in
393
-
let resp = H1.Response.create ~headers response.status in
394
-
H1.Reqd.respond_with_bigstring reqd resp body
395
-
| Response.Cstruct cs ->
396
-
let len = Cstruct.length cs in
397
-
let headers =
398
-
H1.Headers.of_list
399
-
(date_header
400
-
:: ("content-length", string_of_int len)
401
-
:: response.headers)
402
-
in
403
-
let resp = H1.Response.create ~headers response.status in
404
-
let body_writer = H1.Reqd.respond_with_streaming reqd resp in
405
-
H1.Body.Writer.write_bigstring body_writer ~off:cs.off ~len
406
-
cs.buffer;
407
-
H1.Body.Writer.close body_writer
408
-
| Response.Stream { content_length; next } ->
409
-
let headers =
410
-
match content_length with
411
-
| Some len ->
434
+
let response : Response.t = handler request in
435
+
436
+
let date_header = ("date", Date_cache.get ()) in
437
+
match response.Response.body with
438
+
| Response.Prebuilt_body prebuilt ->
439
+
Prebuilt.respond_h1 reqd prebuilt
440
+
| Response.Empty ->
441
+
let headers =
442
+
H1.Headers.of_list
443
+
(date_header :: ("content-length", "0") :: response.headers)
444
+
in
445
+
let resp = H1.Response.create ~headers response.status in
446
+
H1.Reqd.respond_with_string reqd resp ""
447
+
| Response.String body ->
448
+
let headers =
412
449
H1.Headers.of_list
413
450
(date_header
414
-
:: ("content-length", Int64.to_string len)
451
+
:: ("content-length", string_of_int (String.length body))
415
452
:: response.headers)
416
-
| None ->
453
+
in
454
+
let resp = H1.Response.create ~headers response.status in
455
+
H1.Reqd.respond_with_string reqd resp body
456
+
| Response.Bigstring body ->
457
+
let headers =
417
458
H1.Headers.of_list
418
459
(date_header
419
-
:: ("transfer-encoding", "chunked")
460
+
:: ( "content-length",
461
+
string_of_int (Bigstringaf.length body) )
462
+
:: response.headers)
463
+
in
464
+
let resp = H1.Response.create ~headers response.status in
465
+
H1.Reqd.respond_with_bigstring reqd resp body
466
+
| Response.Cstruct cs ->
467
+
let len = Cstruct.length cs in
468
+
let headers =
469
+
H1.Headers.of_list
470
+
(date_header
471
+
:: ("content-length", string_of_int len)
420
472
:: response.headers)
421
-
in
422
-
let resp = H1.Response.create ~headers response.status in
423
-
let body_writer = H1.Reqd.respond_with_streaming reqd resp in
424
-
let rec write_chunks () =
425
-
match next () with
426
-
| None -> H1.Body.Writer.close body_writer
427
-
| Some cs ->
428
-
H1.Body.Writer.write_bigstring body_writer ~off:0
429
-
~len:(Cstruct.length cs) (Cstruct.to_bigarray cs);
430
-
(* Flush to ensure data is sent immediately (required for SSE) *)
431
-
let flushed, resolve = Eio.Promise.create () in
432
-
H1.Body.Writer.flush body_writer (fun () ->
433
-
Eio.Promise.resolve resolve ());
434
-
Eio.Promise.await flushed;
435
-
write_chunks ()
436
-
in
437
-
write_chunks ()
473
+
in
474
+
let resp = H1.Response.create ~headers response.status in
475
+
let body_writer = H1.Reqd.respond_with_streaming reqd resp in
476
+
H1.Body.Writer.write_bigstring body_writer ~off:cs.off ~len
477
+
cs.buffer;
478
+
H1.Body.Writer.close body_writer
479
+
| Response.Stream { content_length; next } ->
480
+
let headers =
481
+
match content_length with
482
+
| Some len ->
483
+
H1.Headers.of_list
484
+
(date_header
485
+
:: ("content-length", Int64.to_string len)
486
+
:: response.headers)
487
+
| None ->
488
+
H1.Headers.of_list
489
+
(date_header
490
+
:: ("transfer-encoding", "chunked")
491
+
:: response.headers)
492
+
in
493
+
let resp = H1.Response.create ~headers response.status in
494
+
let body_writer = H1.Reqd.respond_with_streaming reqd resp in
495
+
let rec write_chunks () =
496
+
match next () with
497
+
| None -> H1.Body.Writer.close body_writer
498
+
| Some cs ->
499
+
H1.Body.Writer.write_bigstring body_writer ~off:0
500
+
~len:(Cstruct.length cs) (Cstruct.to_bigarray cs);
501
+
(* Flush to ensure data is sent immediately (required for SSE) *)
502
+
let flushed, resolve = Eio.Promise.create () in
503
+
H1.Body.Writer.flush body_writer (fun () ->
504
+
Eio.Promise.resolve resolve ());
505
+
Eio.Promise.await flushed;
506
+
write_chunks ()
507
+
in
508
+
write_chunks ())
438
509
end
439
510
in
440
511
···
517
588
closed = false;
518
589
is_client = false;
519
590
read_buf = Buffer.create 4096;
591
+
max_payload_size = ws_config.max_payload_size;
520
592
}
521
593
in
522
594
(try ws_h ws with _ -> ());
···
525
597
| None -> ()
526
598
527
599
(** Direct H1 handler - no protocol detection, no initial data buffering *)
528
-
let handle_direct ~handler flow =
600
+
let handle_direct ?max_body_size ~handler flow =
529
601
let buffer_size = 16384 in
530
602
let read_buffer = Bigstringaf.create buffer_size in
531
603
let read_cstruct =
···
537
609
let h1_body = H1.Reqd.request_body reqd in
538
610
539
611
(* Read body for POST/PUT, skip for GET/HEAD *)
540
-
let body =
612
+
let body_result =
541
613
match req.meth with
542
614
| `GET | `HEAD | `DELETE | `OPTIONS | `CONNECT | `TRACE ->
543
615
H1.Body.Reader.close h1_body;
544
-
""
616
+
Ok ""
545
617
| `POST | `PUT | `Other _ ->
546
618
let body_buffer = Buffer.create 4096 in
619
+
let body_size = ref 0 in
620
+
let too_large = ref false in
547
621
let body_done, resolver = Eio.Promise.create () in
548
622
let rec read_body () =
549
623
H1.Body.Reader.schedule_read h1_body
550
624
~on_eof:(fun () -> Eio.Promise.resolve resolver ())
551
625
~on_read:(fun buf ~off ~len ->
552
-
Buffer.add_string body_buffer
553
-
(Bigstringaf.substring buf ~off ~len);
554
-
read_body ())
626
+
let new_size = !body_size + len in
627
+
match max_body_size with
628
+
| Some max when Int64.of_int new_size > max ->
629
+
too_large := true;
630
+
H1.Body.Reader.close h1_body;
631
+
Eio.Promise.resolve resolver ()
632
+
| _ ->
633
+
body_size := new_size;
634
+
Buffer.add_string body_buffer
635
+
(Bigstringaf.substring buf ~off ~len);
636
+
read_body ())
555
637
in
556
638
read_body ();
557
639
Eio.Promise.await body_done;
558
-
Buffer.contents body_buffer
559
-
in
560
-
561
-
let request =
562
-
{
563
-
meth = req.meth;
564
-
target = req.target;
565
-
headers = h1_headers_to_list req.headers;
566
-
body;
567
-
version = HTTP_1_1;
568
-
}
640
+
if !too_large then Error `Body_too_large
641
+
else Ok (Buffer.contents body_buffer)
569
642
in
570
-
let response : Response.t = handler request in
571
-
572
-
let date_header = ("date", Date_cache.get ()) in
573
-
match response.Response.body with
574
-
| Response.Prebuilt_body prebuilt -> Prebuilt.respond_h1 reqd prebuilt
575
-
| Response.Empty ->
576
-
let headers =
577
-
H1.Headers.of_list
578
-
(date_header :: ("content-length", "0") :: response.headers)
579
-
in
580
-
let resp = H1.Response.create ~headers response.status in
581
-
H1.Reqd.respond_with_string reqd resp ""
582
-
| Response.String body ->
643
+
match body_result with
644
+
| Error `Body_too_large ->
645
+
let body = "Request body too large" in
583
646
let headers =
584
647
H1.Headers.of_list
585
-
(date_header
586
-
:: ("content-length", string_of_int (String.length body))
587
-
:: response.headers)
648
+
[
649
+
("date", Date_cache.get ());
650
+
("content-length", string_of_int (String.length body));
651
+
]
588
652
in
589
-
let resp = H1.Response.create ~headers response.status in
653
+
let resp = H1.Response.create ~headers (`Code 413) in
590
654
H1.Reqd.respond_with_string reqd resp body
591
-
| Response.Bigstring body ->
592
-
let headers =
593
-
H1.Headers.of_list
594
-
(date_header
595
-
:: ("content-length", string_of_int (Bigstringaf.length body))
596
-
:: response.headers)
655
+
| Ok body -> (
656
+
let request =
657
+
{
658
+
meth = req.meth;
659
+
target = req.target;
660
+
headers = h1_headers_to_list req.headers;
661
+
body;
662
+
version = HTTP_1_1;
663
+
}
597
664
in
598
-
let resp = H1.Response.create ~headers response.status in
599
-
H1.Reqd.respond_with_bigstring reqd resp body
600
-
| Response.Cstruct cs ->
601
-
let len = Cstruct.length cs in
602
-
let headers =
603
-
H1.Headers.of_list
604
-
(date_header
605
-
:: ("content-length", string_of_int len)
606
-
:: response.headers)
607
-
in
608
-
let resp = H1.Response.create ~headers response.status in
609
-
let body_writer = H1.Reqd.respond_with_streaming reqd resp in
610
-
H1.Body.Writer.write_bigstring body_writer ~off:cs.off ~len cs.buffer;
611
-
H1.Body.Writer.close body_writer
612
-
| Response.Stream { content_length; next } ->
613
-
let headers =
614
-
match content_length with
615
-
| Some len ->
665
+
let response : Response.t = handler request in
666
+
667
+
let date_header = ("date", Date_cache.get ()) in
668
+
match response.Response.body with
669
+
| Response.Prebuilt_body prebuilt -> Prebuilt.respond_h1 reqd prebuilt
670
+
| Response.Empty ->
671
+
let headers =
672
+
H1.Headers.of_list
673
+
(date_header :: ("content-length", "0") :: response.headers)
674
+
in
675
+
let resp = H1.Response.create ~headers response.status in
676
+
H1.Reqd.respond_with_string reqd resp ""
677
+
| Response.String body ->
678
+
let headers =
616
679
H1.Headers.of_list
617
680
(date_header
618
-
:: ("content-length", Int64.to_string len)
681
+
:: ("content-length", string_of_int (String.length body))
619
682
:: response.headers)
620
-
| None ->
683
+
in
684
+
let resp = H1.Response.create ~headers response.status in
685
+
H1.Reqd.respond_with_string reqd resp body
686
+
| Response.Bigstring body ->
687
+
let headers =
621
688
H1.Headers.of_list
622
689
(date_header
623
-
:: ("transfer-encoding", "chunked")
690
+
:: ("content-length", string_of_int (Bigstringaf.length body))
624
691
:: response.headers)
625
-
in
626
-
let resp = H1.Response.create ~headers response.status in
627
-
let body_writer = H1.Reqd.respond_with_streaming reqd resp in
628
-
let rec write_chunks () =
629
-
match next () with
630
-
| None -> H1.Body.Writer.close body_writer
631
-
| Some cs ->
632
-
H1.Body.Writer.write_bigstring body_writer ~off:cs.off
633
-
~len:(Cstruct.length cs) cs.buffer;
634
-
let flushed, resolve = Eio.Promise.create () in
635
-
H1.Body.Writer.flush body_writer (fun () ->
636
-
Eio.Promise.resolve resolve ());
637
-
Eio.Promise.await flushed;
638
-
write_chunks ()
639
-
in
640
-
write_chunks ()
692
+
in
693
+
let resp = H1.Response.create ~headers response.status in
694
+
H1.Reqd.respond_with_bigstring reqd resp body
695
+
| Response.Cstruct cs ->
696
+
let len = Cstruct.length cs in
697
+
let headers =
698
+
H1.Headers.of_list
699
+
(date_header
700
+
:: ("content-length", string_of_int len)
701
+
:: response.headers)
702
+
in
703
+
let resp = H1.Response.create ~headers response.status in
704
+
let body_writer = H1.Reqd.respond_with_streaming reqd resp in
705
+
H1.Body.Writer.write_bigstring body_writer ~off:cs.off ~len
706
+
cs.buffer;
707
+
H1.Body.Writer.close body_writer
708
+
| Response.Stream { content_length; next } ->
709
+
let headers =
710
+
match content_length with
711
+
| Some len ->
712
+
H1.Headers.of_list
713
+
(date_header
714
+
:: ("content-length", Int64.to_string len)
715
+
:: response.headers)
716
+
| None ->
717
+
H1.Headers.of_list
718
+
(date_header
719
+
:: ("transfer-encoding", "chunked")
720
+
:: response.headers)
721
+
in
722
+
let resp = H1.Response.create ~headers response.status in
723
+
let body_writer = H1.Reqd.respond_with_streaming reqd resp in
724
+
let rec write_chunks () =
725
+
match next () with
726
+
| None -> H1.Body.Writer.close body_writer
727
+
| Some cs ->
728
+
H1.Body.Writer.write_bigstring body_writer ~off:cs.off
729
+
~len:(Cstruct.length cs) cs.buffer;
730
+
let flushed, resolve = Eio.Promise.create () in
731
+
H1.Body.Writer.flush body_writer (fun () ->
732
+
Eio.Promise.resolve resolve ());
733
+
Eio.Promise.await flushed;
734
+
write_chunks ()
735
+
in
736
+
write_chunks ())
641
737
in
642
738
643
739
let error_handler ?request:_ _error start_response =
···
891
987
let handle ~config ~handler ~ws_handler tls_cfg flow =
892
988
try
893
989
let tls_flow = Tls_eio.server_of_flow tls_cfg flow in
990
+
let max_body_size = config.max_body_size in
894
991
match config.protocol with
895
-
| Http1_only ->
896
-
(* No ALPN check, direct H1 *)
897
-
H1_handler.handle_direct ~handler tls_flow
898
-
| Http2_only ->
899
-
(* No ALPN check, direct H2 *)
900
-
H2_handler.handle_direct ~handler tls_flow
992
+
| Http1_only -> H1_handler.handle_direct ?max_body_size ~handler tls_flow
993
+
| Http2_only -> H2_handler.handle_direct ~handler tls_flow
901
994
| Auto | Auto_websocket -> (
902
-
(* Check ALPN negotiated protocol *)
903
995
match Tls_config.negotiated_protocol tls_flow with
904
996
| Some Tls_config.HTTP_2 -> H2_handler.handle_direct ~handler tls_flow
905
997
| Some Tls_config.HTTP_1_1 | None ->
906
998
if config.protocol = Auto_websocket then
907
999
H1_handler.handle ~handler ~ws_handler:(Some ws_handler)
908
-
~initial_data:"" tls_flow
909
-
else H1_handler.handle_direct ~handler tls_flow)
1000
+
?max_body_size ~initial_data:"" tls_flow
1001
+
else H1_handler.handle_direct ?max_body_size ~handler tls_flow)
910
1002
with
911
1003
| Tls_eio.Tls_failure failure ->
912
1004
traceln "TLS error: %s" (Tls_config.failure_to_string failure)
···
916
1008
(** {1 Internal: Connection Handler} *)
917
1009
918
1010
let handle_connection ~config ~handler ~ws_handler flow =
1011
+
let max_body_size = config.max_body_size in
919
1012
match config.protocol with
920
-
| Http1_only ->
921
-
(* Fastest path: direct H1, no detection *)
922
-
H1_handler.handle_direct ~handler flow
923
-
| Http2_only ->
924
-
(* Direct H2 (h2c) *)
925
-
H2_handler.handle_direct ~handler flow
1013
+
| Http1_only -> H1_handler.handle_direct ?max_body_size ~handler flow
1014
+
| Http2_only -> H2_handler.handle_direct ~handler flow
926
1015
| Auto | Auto_websocket -> (
927
-
(* Peek to detect protocol *)
928
1016
match peek_bytes flow h2_preface_prefix_len with
929
-
| Error `Eof -> () (* Client disconnected immediately *)
1017
+
| Error `Eof -> ()
930
1018
| Error (`Exn exn) ->
931
1019
traceln "Connection error: %s" (Printexc.to_string exn)
932
1020
| Ok initial_data ->
···
934
1022
H2_handler.handle ~handler ~initial_data flow
935
1023
else if config.protocol = Auto_websocket then
936
1024
H1_handler.handle ~handler ~ws_handler:(Some ws_handler)
937
-
~initial_data flow
938
-
else H1_handler.handle ~handler ~ws_handler:None ~initial_data flow)
1025
+
?max_body_size ~initial_data flow
1026
+
else
1027
+
H1_handler.handle ~handler ~ws_handler:None ?max_body_size
1028
+
~initial_data flow)
939
1029
940
1030
(** {1 Public API} *)
941
1031
+89
-46
lib/websocket.ml
+89
-46
lib/websocket.ml
···
100
100
in
101
101
{ opcode = Close; extension = 0; final = true; content }
102
102
103
+
let default_max_payload_size = 16 * 1024 * 1024
104
+
103
105
type t = {
104
106
flow : Eio.Flow.two_way_ty r;
105
107
mutable closed : bool;
106
-
is_client : bool; (** Client must mask frames *)
108
+
is_client : bool;
107
109
read_buf : Buffer.t;
110
+
max_payload_size : int;
108
111
}
109
-
(** WebSocket connection *)
110
112
111
-
(** Error type *)
112
-
type error = Connection_closed | Protocol_error of string | Io_error of string
113
+
type error =
114
+
| Connection_closed
115
+
| Protocol_error of string
116
+
| Io_error of string
117
+
| Payload_too_large of int
113
118
114
119
(** {1 Cryptographic helpers} *)
115
120
···
121
126
(** Compute the Sec-WebSocket-Accept value *)
122
127
let compute_accept_key key = b64_encoded_sha1sum (key ^ websocket_uuid)
123
128
124
-
(** {1 Random number generation for masking} *)
125
-
126
129
module Rng = struct
127
-
let initialized = ref false
128
-
129
-
let init () =
130
-
if not !initialized then begin
131
-
Random.self_init ();
132
-
initialized := true
133
-
end
134
-
135
-
(** Generate n random bytes *)
136
-
let generate n =
137
-
init ();
138
-
let buf = Bytes.create n in
139
-
for i = 0 to n - 1 do
140
-
Bytes.set buf i (Char.chr (Random.int 256))
141
-
done;
142
-
Bytes.to_string buf
130
+
let generate n = Secure_random.generate n
143
131
end
144
132
145
133
(** {1 Frame parsing/serialization} *)
···
204
192
loop 0;
205
193
Cstruct.to_string buf
206
194
207
-
(** Parse a frame from flow *)
208
-
let read_frame ~is_client flow =
195
+
let read_frame ~is_client ~max_payload_size flow =
209
196
try
210
-
(* Read first 2 bytes *)
211
197
let header = read_exactly flow 2 in
212
198
let b0 = Char.code header.[0] in
213
199
let b1 = Char.code header.[1] in
···
218
204
let masked = b1 land 0x80 <> 0 in
219
205
let len0 = b1 land 0x7f in
220
206
221
-
(* Server receiving from client: frames must be masked
222
-
Client receiving from server: frames must not be masked *)
223
207
if (not is_client) && not masked then
224
208
Error (Protocol_error "Client frames must be masked")
225
209
else if is_client && masked then
226
210
Error (Protocol_error "Server frames must not be masked")
227
211
else begin
228
-
(* Read extended length if needed *)
229
212
let len =
230
213
if len0 < 126 then len0
231
214
else if len0 = 126 then begin
···
233
216
(Char.code ext.[0] lsl 8) lor Char.code ext.[1]
234
217
end
235
218
else begin
236
-
(* 64-bit length *)
237
219
let ext = read_exactly flow 8 in
238
220
let len = ref 0 in
239
221
for i = 0 to 7 do
···
243
225
end
244
226
in
245
227
246
-
(* Control frames cannot be fragmented and max 125 bytes *)
247
-
if Opcode.is_control opcode && ((not final) || len > 125) then
228
+
if len > max_payload_size then Error (Payload_too_large len)
229
+
else if Opcode.is_control opcode && ((not final) || len > 125) then
248
230
Error (Protocol_error "Invalid control frame")
249
231
else begin
250
-
(* Read mask key if present *)
251
232
let mask_key = if masked then Some (read_exactly flow 4) else None in
252
-
253
-
(* Read payload *)
254
233
let content = if len > 0 then read_exactly flow len else "" in
255
234
let content =
256
235
match mask_key with
257
236
| Some key -> xor_mask key content
258
237
| None -> content
259
238
in
260
-
261
239
Ok { opcode; extension; final; content }
262
240
end
263
241
end
···
295
273
let send_pong t ?(content = "") () =
296
274
send t (make_frame ~opcode:Pong ~content ())
297
275
298
-
(** Receive a frame *)
299
276
let recv t =
300
277
if t.closed then Error Connection_closed
301
278
else
302
-
match read_frame ~is_client:t.is_client t.flow with
279
+
match
280
+
read_frame ~is_client:t.is_client ~max_payload_size:t.max_payload_size
281
+
t.flow
282
+
with
303
283
| Ok frame ->
304
-
(* Handle control frames *)
305
284
(match frame.opcode with
306
285
| Close ->
307
286
t.closed <- true;
308
-
(* Echo close frame back *)
309
287
ignore (send t (close_frame (-1)))
310
-
| Ping ->
311
-
(* Auto-respond to pings with pong *)
312
-
ignore (send_pong t ~content:frame.content ())
288
+
| Ping -> ignore (send_pong t ~content:frame.content ())
313
289
| _ -> ());
314
290
Ok frame
315
291
| Error e ->
···
374
350
(** Get the Sec-WebSocket-Key from request headers *)
375
351
let get_websocket_key headers = H1.Headers.get headers "sec-websocket-key"
376
352
353
+
let supported_websocket_version = "13"
354
+
355
+
let get_websocket_version headers =
356
+
H1.Headers.get headers "sec-websocket-version"
357
+
358
+
let validate_websocket_version headers =
359
+
match get_websocket_version headers with
360
+
| Some v when v = supported_websocket_version -> Ok ()
361
+
| Some v -> Error ("Unsupported WebSocket version: " ^ v)
362
+
| None -> Error "Missing Sec-WebSocket-Version header"
363
+
364
+
(** Get the Origin header from request headers *)
365
+
let get_origin headers = H1.Headers.get headers "origin"
366
+
367
+
(** Origin policy for WebSocket connections.
368
+
- [`Allow_all] accepts connections from any origin (NOT RECOMMENDED for
369
+
production)
370
+
- [`Allow_list origins] only accepts connections from the specified origins
371
+
- [`Allow_same_origin] only accepts connections where Origin matches the
372
+
Host header *)
373
+
type origin_policy = Allow_all | Allow_list of string list | Allow_same_origin
374
+
375
+
(** Validate Origin header against policy.
376
+
@param policy The origin validation policy
377
+
@param headers The request headers (must contain Origin, may contain Host)
378
+
@return [Ok ()] if origin is allowed, [Error reason] if rejected *)
379
+
let validate_origin ~policy headers =
380
+
match policy with
381
+
| Allow_all -> Ok ()
382
+
| Allow_list allowed -> (
383
+
match get_origin headers with
384
+
| None ->
385
+
(* Missing Origin header - could be same-origin request or non-browser client.
386
+
For security, we require Origin for Allow_list policy. *)
387
+
Error "Missing Origin header"
388
+
| Some origin ->
389
+
let origin_lower = String.lowercase_ascii origin in
390
+
if
391
+
List.exists
392
+
(fun allowed -> String.lowercase_ascii allowed = origin_lower)
393
+
allowed
394
+
then Ok ()
395
+
else Error ("Origin not allowed: " ^ origin))
396
+
| Allow_same_origin -> (
397
+
match (get_origin headers, H1.Headers.get headers "host") with
398
+
| None, _ ->
399
+
(* No Origin header - likely same-origin or non-browser, allow it *)
400
+
Ok ()
401
+
| Some origin, Some host ->
402
+
(* Extract host from origin URL (e.g., "https://example.com" -> "example.com") *)
403
+
let origin_host =
404
+
let uri = Uri.of_string origin in
405
+
match Uri.host uri with
406
+
| Some h -> (
407
+
match Uri.port uri with
408
+
| Some p -> h ^ ":" ^ string_of_int p
409
+
| None -> h)
410
+
| None -> origin
411
+
in
412
+
if String.lowercase_ascii origin_host = String.lowercase_ascii host
413
+
then Ok ()
414
+
else Error ("Cross-origin request: " ^ origin ^ " vs " ^ host)
415
+
| Some origin, None ->
416
+
Error ("Missing Host header for origin check: " ^ origin))
417
+
377
418
(** Generate random base64-encoded key for client handshake *)
378
419
let generate_key () = Base64.encode_exn (Rng.generate 16)
379
420
···
520
561
header_lines
521
562
in
522
563
523
-
(* Validate accept key *)
524
564
let accept = List.assoc_opt "sec-websocket-accept" headers in
525
565
match accept with
526
566
| Some a when a = expected_accept ->
···
530
570
closed = false;
531
571
is_client = true;
532
572
read_buf = Buffer.create 4096;
573
+
max_payload_size = default_max_payload_size;
533
574
}
534
575
| Some a ->
535
576
let buf =
···
548
589
549
590
(** {1 Server API} *)
550
591
551
-
(** Accept a WebSocket upgrade from an HTTP connection. Returns a WebSocket
552
-
connection after sending the upgrade response. *)
553
-
let accept ~flow ~key =
592
+
let accept ?max_payload_size ~flow ~key () =
593
+
let max_payload_size =
594
+
Option.value max_payload_size ~default:default_max_payload_size
595
+
in
554
596
let accept = compute_accept_key key in
555
597
let buf = Buffer.create (String.length accept + 128) in
556
598
Buffer.add_string buf "HTTP/1.1 101 Switching Protocols\r\n";
···
568
610
closed = false;
569
611
is_client = false;
570
612
read_buf = Buffer.create 4096;
613
+
max_payload_size;
571
614
}
572
615
with exn -> Error (Io_error (Printexc.to_string exn))
+9
-3
test/test_websocket.ml
+9
-3
test/test_websocket.ml
···
62
62
(match e with
63
63
| Hcs.Websocket.Connection_closed -> "closed"
64
64
| Hcs.Websocket.Protocol_error s -> "protocol: " ^ s
65
-
| Hcs.Websocket.Io_error s -> "io: " ^ s))
65
+
| Hcs.Websocket.Io_error s -> "io: " ^ s
66
+
| Hcs.Websocket.Payload_too_large n ->
67
+
"payload too large: " ^ string_of_int n))
66
68
| Error _ -> Eio.traceln " FAIL: send error");
67
69
Hcs.Websocket.close ws
68
70
| Error e ->
···
70
72
(match e with
71
73
| Hcs.Websocket.Connection_closed -> "closed"
72
74
| Hcs.Websocket.Protocol_error s -> "protocol: " ^ s
73
-
| Hcs.Websocket.Io_error s -> "io: " ^ s));
75
+
| Hcs.Websocket.Io_error s -> "io: " ^ s
76
+
| Hcs.Websocket.Payload_too_large n ->
77
+
"payload too large: " ^ string_of_int n));
74
78
75
79
Eio.traceln "Test 3: WebSocket on /ws/chat path...";
76
80
(match
···
93
97
(match e with
94
98
| Hcs.Websocket.Connection_closed -> "closed"
95
99
| Hcs.Websocket.Protocol_error s -> "protocol: " ^ s
96
-
| Hcs.Websocket.Io_error s -> "io: " ^ s));
100
+
| Hcs.Websocket.Io_error s -> "io: " ^ s
101
+
| Hcs.Websocket.Payload_too_large n ->
102
+
"payload too large: " ^ string_of_int n));
97
103
98
104
Eio.traceln "Test 4: Multiple messages...";
99
105
(match