Punycode (RFC3492) in OCaml
1(*---------------------------------------------------------------------------
2 Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved.
3 SPDX-License-Identifier: ISC
4 ---------------------------------------------------------------------------*)
5
6(* RFC 3492 Punycode Implementation *)
7
8(* {1 Bootstring Parameters for Punycode (RFC 3492 Section 5)} *)
9
(** Bootstring parameter values that define Punycode (RFC 3492, section 5). *)
let base = 36
and tmin = 1
and tmax = 26
and skew = 38
and damp = 700
and initial_bias = 72
and initial_n = 128 (* 0x80: the first non-basic (non-ASCII) code point *)

(** Separates the copied basic code points from the encoded deltas. *)
let delimiter = '-'

(** Prefix that marks a Punycode-encoded label. *)
let ace_prefix = "xn--"

(** Maximum label length, in bytes, enforced when encoding labels. *)
let max_label_length = 63
20
21(* {1 Position Tracking} *)
22
(** A location in the input being processed.
    [byte_offset] is an offset in bytes. [char_index] counts code points. *)
type position = { byte_offset : int; char_index : int }

let position_byte_offset { byte_offset; _ } = byte_offset
let position_char_index { char_index; _ } = char_index

(** Prints a position as ["byte B, char C"]. *)
let pp_position ppf { byte_offset; char_index } =
  Format.fprintf ppf "byte %d, char %d" byte_offset char_index
30
31(* {1 Error Types} *)
32
(** Why a Punycode or label operation failed. *)
type error_reason =
  | Overflow of position  (** Arithmetic overflow in the coder state. *)
  | Invalid_character of position * Uchar.t
      (** A code point that is not allowed at this point of the input. *)
  | Invalid_digit of position * char
      (** A byte that is not a Punycode digit. *)
  | Unexpected_end of position  (** Input ended in the middle of a delta. *)
  | Invalid_utf8 of position  (** Malformed UTF-8 in the input string. *)
  | Label_too_long of int  (** Offending label length, in bytes. *)
  | Empty_label  (** A zero-length label. *)

(** Human-readable rendering of an [error_reason]. *)
let pp_error_reason ppf reason =
  match reason with
  | Overflow pos ->
      Format.fprintf ppf "arithmetic overflow at %a" pp_position pos
  | Invalid_character (pos, u) ->
      let code = Uchar.to_int u in
      Format.fprintf ppf "invalid character U+%04X at %a" code pp_position pos
  | Invalid_digit (pos, c) ->
      let code = Char.code c in
      Format.fprintf ppf "invalid Punycode digit '%c' (0x%02X) at %a" c code
        pp_position pos
  | Unexpected_end pos ->
      Format.fprintf ppf "unexpected end of input at %a" pp_position pos
  | Invalid_utf8 pos ->
      Format.fprintf ppf "invalid UTF-8 sequence at %a" pp_position pos
  | Label_too_long len ->
      Format.fprintf ppf "label too long: %d bytes (max %d)" len
        max_label_length
  | Empty_label -> Format.pp_print_string ppf "empty label"
58 | Empty_label -> Format.fprintf fmt "empty label"
59
(** Raised, with an [error_reason], when a Punycode or label operation
    fails. *)
exception Error of error_reason

(** [error_reason_to_string r] renders [r] using [pp_error_reason]. *)
let error_reason_to_string reason =
  Format.asprintf "%a" pp_error_reason reason

(* Make uncaught [Error]s readable in [Printexc.to_string] and backtraces. *)
let () =
  Printexc.register_printer (fun exn ->
      match exn with
      | Error reason ->
          Some (Format.asprintf "Punycode.Error: %a" pp_error_reason reason)
      | _ -> None)
67
68(* {1 Error Constructors} *)
69
(* Each helper wraps its reason in [Error] and raises it, so the return
   type is polymorphic and the helpers can be used in any branch. *)
let overflow pos = Error (Overflow pos) |> raise
let invalid_character pos u = Error (Invalid_character (pos, u)) |> raise
let invalid_digit pos c = Error (Invalid_digit (pos, c)) |> raise
let unexpected_end pos = Error (Unexpected_end pos) |> raise
let invalid_utf8 pos = Error (Invalid_utf8 pos) |> raise
let label_too_long len = Error (Label_too_long len) |> raise
let empty_label () = Error Empty_label |> raise
77
78(* {1 Case Flags} *)
79
(** Per-code-point case annotation used by mixed-case Punycode. *)
type case_flag =
  | Uppercase
  | Lowercase
81
82(* {1 Basic Predicates} *)
83
(** [is_basic u] is [true] iff [u] is an ASCII code point (U+0000..U+007F). *)
let is_basic u = Uchar.to_int u <= 0x7F

(** [is_ascii_string s] is [true] iff no byte of [s] is 0x80 or above. *)
let is_ascii_string s = not (String.exists (fun c -> Char.code c > 0x7F) s)

(** [has_ace_prefix s] is [true] iff [s] starts with ["xn--"], where the
    letters are matched case-insensitively. *)
let has_ace_prefix s =
  String.length s >= 4
  && Char.lowercase_ascii s.[0] = 'x'
  && Char.lowercase_ascii s.[1] = 'n'
  && String.sub s 2 2 = "--"
94
95(* {1 Digit Encoding/Decoding (RFC 3492 Section 5)}
96
97 Digit values:
98 - 0-25: a-z (or A-Z)
99 - 26-35: 0-9
100*)
101
102let encode_digit d case_flag =
103 if d < 26 then Char.chr (d + if case_flag = Uppercase then 0x41 else 0x61)
104 else Char.chr (d - 26 + 0x30)
105
106let decode_digit c =
107 let code = Char.code c in
108 if code >= 0x30 && code <= 0x39 then Some (code - 0x30 + 26)
109 (* '0'-'9' -> 26-35 *)
110 else if code >= 0x41 && code <= 0x5A then Some (code - 0x41)
111 (* 'A'-'Z' -> 0-25 *)
112 else if code >= 0x61 && code <= 0x7A then Some (code - 0x61)
113 (* 'a'-'z' -> 0-25 *)
114 else None
115
116(* Check if a character is "flagged" (uppercase) for case annotation *)
117let is_flagged c =
118 let code = Char.code c in
119 code >= 0x41 && code <= 0x5A (* 'A'-'Z' *)
120
121(* {1 Bias Adaptation (RFC 3492 Section 6.1)} *)
122
123let adapt ~delta ~numpoints ~firsttime =
124 let delta = if firsttime then delta / damp else delta / 2 in
125 let delta = delta + (delta / numpoints) in
126 let threshold = (base - tmin) * tmax / 2 in
127 let rec loop delta k =
128 if delta > threshold then loop (delta / (base - tmin)) (k + base)
129 else k + ((base - tmin + 1) * delta / (delta + skew))
130 in
131 loop delta 0
132
133(* {1 Overflow-Safe Arithmetic}
134
135 RFC 3492 Section 6.4: Use detection to avoid overflow.
136 A + B overflows iff B > maxint - A
137 A + B*C overflows iff B > (maxint - A) / C
138*)
139
140let max_int_value = max_int
141
142let safe_mul_add a b c pos =
143 if c = 0 then a
144 else if b > (max_int_value - a) / c then overflow pos
145 else a + (b * c)
146
147(* {1 UTF-8 to Code Points Conversion} *)
148
149let utf8_to_codepoints s =
150 let len = String.length s in
151 let acc = ref [] in
152 let byte_offset = ref 0 in
153 let char_index = ref 0 in
154 while !byte_offset < len do
155 let pos = { byte_offset = !byte_offset; char_index = !char_index } in
156 let dec = String.get_utf_8_uchar s !byte_offset in
157 if Uchar.utf_decode_is_valid dec then begin
158 acc := Uchar.utf_decode_uchar dec :: !acc;
159 byte_offset := !byte_offset + Uchar.utf_decode_length dec;
160 incr char_index
161 end
162 else invalid_utf8 pos
163 done;
164 Array.of_list (List.rev !acc)
165
166(* {1 Code Points to UTF-8 Conversion} *)
167
168let codepoints_to_utf8 codepoints =
169 let buf = Buffer.create (Array.length codepoints * 2) in
170 Array.iter (Buffer.add_utf_8_uchar buf) codepoints;
171 Buffer.contents buf
172
173(* {1 Punycode Encoding (RFC 3492 Section 6.3)} *)
174
175let encode_impl codepoints case_flags =
176 let input_length = Array.length codepoints in
177 if input_length = 0 then ""
178 else begin
179 let output = Buffer.create (input_length * 2) in
180
181 (* Copy basic code points to output *)
182 let basic_count = ref 0 in
183 for j = 0 to input_length - 1 do
184 let cp = codepoints.(j) in
185 if is_basic cp then begin
186 let c = Uchar.to_int cp in
187 let case =
188 match case_flags with Some flags -> flags.(j) | None -> Lowercase
189 in
190 (* Preserve or apply case for ASCII letters *)
191 let c' =
192 if c >= 0x41 && c <= 0x5A then (* 'A'-'Z' *)
193 if case = Lowercase then c + 0x20 else c
194 else if c >= 0x61 && c <= 0x7A then (* 'a'-'z' *)
195 if case = Uppercase then c - 0x20 else c
196 else c
197 in
198 Buffer.add_char output (Char.chr c');
199 incr basic_count
200 end
201 done;
202
203 let b = !basic_count in
204 let h = ref b in
205
206 (* Add delimiter if there were basic code points *)
207 if b > 0 then Buffer.add_char output delimiter;
208
209 (* Main encoding loop *)
210 let n = ref initial_n in
211 let delta = ref 0 in
212 let bias = ref initial_bias in
213
214 while !h < input_length do
215 (* Find minimum code point >= n *)
216 let m =
217 Array.fold_left
218 (fun acc cp ->
219 let cp_val = Uchar.to_int cp in
220 if cp_val >= !n && cp_val < acc then cp_val else acc)
221 max_int_value codepoints
222 in
223
224 (* Increase delta to advance state to <m, 0> *)
225 let pos = { byte_offset = 0; char_index = !h } in
226 delta := safe_mul_add !delta (m - !n) (!h + 1) pos;
227 n := m;
228
229 (* Process each code point *)
230 for j = 0 to input_length - 1 do
231 let cp = Uchar.to_int codepoints.(j) in
232 let pos = { byte_offset = 0; char_index = j } in
233
234 if cp < !n then begin
235 incr delta;
236 if !delta = 0 then (* Overflow *)
237 overflow pos
238 end
239 else if cp = !n then begin
240 (* Encode delta as variable-length integer *)
241 let q = ref !delta in
242 let k = ref base in
243 let done_encoding = ref false in
244
245 while not !done_encoding do
246 let t =
247 if !k <= !bias then tmin
248 else if !k >= !bias + tmax then tmax
249 else !k - !bias
250 in
251 if !q < t then begin
252 (* Output final digit *)
253 let case =
254 match case_flags with
255 | Some flags -> flags.(j)
256 | None -> Lowercase
257 in
258 Buffer.add_char output (encode_digit !q case);
259 done_encoding := true
260 end
261 else begin
262 (* Output intermediate digit and continue *)
263 let digit = t + ((!q - t) mod (base - t)) in
264 Buffer.add_char output (encode_digit digit Lowercase);
265 q := (!q - t) / (base - t);
266 k := !k + base
267 end
268 done;
269
270 bias := adapt ~delta:!delta ~numpoints:(!h + 1) ~firsttime:(!h = b);
271 delta := 0;
272 incr h
273 end
274 done;
275
276 incr delta;
277 incr n
278 done;
279
280 Buffer.contents output
281 end
282
283let encode codepoints = encode_impl codepoints None
284
285let encode_with_case codepoints case_flags =
286 if Array.length codepoints <> Array.length case_flags then
287 invalid_arg "encode_with_case: array lengths must match";
288 encode_impl codepoints (Some case_flags)
289
290(* {1 Punycode Decoding (RFC 3492 Section 6.2)} *)
291
292let decode_impl input =
293 let input_length = String.length input in
294 if input_length = 0 then ([||], [||])
295 else begin
296 (* Find last delimiter *)
297 let b = Option.value ~default:0 (String.rindex_opt input delimiter) in
298
299 (* Copy basic code points and extract case flags *)
300 let output = ref [] in
301 let case_output = ref [] in
302
303 for j = 0 to b - 1 do
304 let c = input.[j] in
305 let pos = { byte_offset = j; char_index = j } in
306 let code = Char.code c in
307 if code >= 0x80 then
308 invalid_character pos (Uchar.of_int code)
309 else begin
310 output := Uchar.of_int code :: !output;
311 case_output :=
312 (if is_flagged c then Uppercase else Lowercase) :: !case_output
313 end
314 done;
315
316 let output = ref (Array.of_list (List.rev !output)) in
317 let case_output = ref (Array.of_list (List.rev !case_output)) in
318
319 (* Main decoding loop *)
320 let n = ref initial_n in
321 let i = ref 0 in
322 let bias = ref initial_bias in
323 let in_pos = ref (if b > 0 then b + 1 else 0) in
324
325 while !in_pos < input_length do
326 let oldi = !i in
327 let w = ref 1 in
328 let k = ref base in
329 let done_decoding = ref false in
330
331 while not !done_decoding do
332 let pos =
333 { byte_offset = !in_pos; char_index = Array.length !output }
334 in
335
336 if !in_pos >= input_length then
337 unexpected_end pos
338 else begin
339 let c = input.[!in_pos] in
340 incr in_pos;
341
342 match decode_digit c with
343 | None -> invalid_digit pos c
344 | Some digit ->
345 (* i = i + digit * w, with overflow check *)
346 i := safe_mul_add !i digit !w pos;
347
348 let t =
349 if !k <= !bias then tmin
350 else if !k >= !bias + tmax then tmax
351 else !k - !bias
352 in
353
354 if digit < t then
355 (* Record case flag from this final digit *)
356 done_decoding := true
357 else begin
358 (* w = w * (base - t), with overflow check *)
359 let base_minus_t = base - t in
360 if !w > max_int_value / base_minus_t then
361 overflow pos
362 else begin
363 w := !w * base_minus_t;
364 k := !k + base
365 end
366 end
367 end
368 done;
369
370 let out_len = Array.length !output in
371 bias :=
372 adapt ~delta:(!i - oldi) ~numpoints:(out_len + 1)
373 ~firsttime:(oldi = 0);
374
375 let pos = { byte_offset = !in_pos - 1; char_index = out_len } in
376
377 (* n = n + i / (out_len + 1), with overflow check *)
378 let increment = !i / (out_len + 1) in
379 if increment > max_int_value - !n then overflow pos
380 else begin
381 n := !n + increment;
382 i := !i mod (out_len + 1);
383
384 (* Validate that n is a valid Unicode scalar value *)
385 if not (Uchar.is_valid !n) then
386 invalid_character pos Uchar.rep
387 else begin
388 (* Insert n at position i *)
389 let new_output = Array.make (out_len + 1) (Uchar.of_int 0) in
390 let new_case = Array.make (out_len + 1) Lowercase in
391
392 for j = 0 to !i - 1 do
393 new_output.(j) <- !output.(j);
394 new_case.(j) <- !case_output.(j)
395 done;
396 new_output.(!i) <- Uchar.of_int !n;
397 (* Case flag from final digit of this delta *)
398 new_case.(!i) <-
399 (if !in_pos > 0 && is_flagged input.[!in_pos - 1] then
400 Uppercase
401 else Lowercase);
402 for j = !i to out_len - 1 do
403 new_output.(j + 1) <- !output.(j);
404 new_case.(j + 1) <- !case_output.(j)
405 done;
406
407 output := new_output;
408 case_output := new_case;
409 incr i
410 end
411 end
412 done;
413
414 (!output, !case_output)
415 end
416
417let decode input = fst (decode_impl input)
418let decode_with_case input = decode_impl input
419
420(* {1 UTF-8 String Operations} *)
421
422let encode_utf8 s =
423 let codepoints = utf8_to_codepoints s in
424 encode codepoints
425
426let decode_utf8 punycode =
427 let codepoints = decode punycode in
428 codepoints_to_utf8 codepoints
429
430(* {1 Domain Label Operations} *)
431
432let encode_label label =
433 if String.length label = 0 then empty_label ()
434 else if is_ascii_string label then begin
435 (* All ASCII - return as-is, but check length *)
436 let len = String.length label in
437 if len > max_label_length then label_too_long len else label
438 end
439 else begin
440 (* Has non-ASCII - encode with Punycode *)
441 let encoded = encode_utf8 label in
442 let result = ace_prefix ^ encoded in
443 let len = String.length result in
444 if len > max_label_length then label_too_long len else result
445 end
446
447let decode_label label =
448 if String.length label = 0 then empty_label ()
449 else if has_ace_prefix label then begin
450 (* Remove ACE prefix and decode *)
451 let punycode = String.sub label 4 (String.length label - 4) in
452 decode_utf8 punycode
453 end
454 else begin
455 (* No ACE prefix - validate and return *)
456 if is_ascii_string label then label
457 else
458 (* Has non-ASCII but no ACE prefix - return as-is *)
459 label
460 end