(* Forked from anil.recoil.org/ocaml-punycode:
   Punycode (RFC 3492) in OCaml. *)
(*---------------------------------------------------------------------------
   Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved.
   SPDX-License-Identifier: ISC
  ---------------------------------------------------------------------------*)

(* RFC 3492 Punycode Implementation *)

(* {1 Bootstring Parameters for Punycode (RFC 3492 Section 5)} *)

(* Numeric Bootstring parameters that RFC 3492 Section 5 fixes for
   Punycode. *)
let base = 36
and tmin = 1
and tmax = 26
and skew = 38
and damp = 700
and initial_bias = 72
and initial_n = 128 (* 0x80: the first non-basic code point *)

(* Separates the copied basic code points from the encoded deltas. *)
let delimiter = '-'

(* Prefix that marks a domain label as Punycode-encoded. *)
let ace_prefix = "xn--"

(* Maximum length of a domain label, in bytes. *)
let max_label_length = 63

(* {1 Position Tracking} *)

(* Where a failure was detected: [byte_offset] counts bytes into the input,
   [char_index] counts decoded characters. *)
type position = { byte_offset : int; char_index : int }

let position_byte_offset { byte_offset; _ } = byte_offset
let position_char_index { char_index; _ } = char_index

(* Render a [position] as "byte B, char C" for error messages. *)
let pp_position ppf { byte_offset; char_index } =
  Fmt.pf ppf "byte %d, char %d" byte_offset char_index

(* {1 Error Types} *)

(* Reasons a Punycode operation can fail. Most carry the [position] at which
   the failure was detected. *)
type error_reason =
  | Overflow of position
  | Invalid_character of position * Uchar.t
  | Invalid_digit of position * char
  | Unexpected_end of position
  | Invalid_utf8 of position
  | Label_too_long of int
  | Empty_label

(* Human-readable description of an [error_reason]. *)
let pp_error_reason ppf reason =
  match reason with
  | Overflow pos -> Fmt.pf ppf "arithmetic overflow at %a" pp_position pos
  | Invalid_character (pos, u) ->
      let cp = Uchar.to_int u in
      Fmt.pf ppf "invalid character U+%04X at %a" cp pp_position pos
  | Invalid_digit (pos, c) ->
      let code = Char.code c in
      Fmt.pf ppf "invalid Punycode digit '%c' (0x%02X) at %a" c code
        pp_position pos
  | Unexpected_end pos ->
      Fmt.pf ppf "unexpected end of input at %a" pp_position pos
  | Invalid_utf8 pos ->
      Fmt.pf ppf "invalid UTF-8 sequence at %a" pp_position pos
  | Label_too_long len ->
      Fmt.pf ppf "label too long: %d bytes (max %d)" len max_label_length
  | Empty_label -> Fmt.pf ppf "empty label"

exception Error of error_reason

(* Teach [Printexc] to print uncaught [Error] exceptions readably. *)
let () =
  let printer exn =
    match exn with
    | Error reason -> Some (Fmt.str "Punycode.Error: %a" pp_error_reason reason)
    | _ -> None
  in
  Printexc.register_printer printer

(* String form of an [error_reason], as produced by [pp_error_reason]. *)
let error_reason_to_string r = Fmt.str "%a" pp_error_reason r

(* {1 Error Constructors} *)

(* Each helper raises [Error] carrying the matching reason. Their result
   type is polymorphic, so they can stand in any expression position. *)
let overflow pos = Error (Overflow pos) |> raise
let invalid_character pos u = Error (Invalid_character (pos, u)) |> raise
let invalid_digit pos c = Error (Invalid_digit (pos, c)) |> raise
let unexpected_end pos = Error (Unexpected_end pos) |> raise
let invalid_utf8 pos = Error (Invalid_utf8 pos) |> raise
let label_too_long len = Error (Label_too_long len) |> raise
let empty_label () = Error Empty_label |> raise

(* {1 Case Flags} *)

(* Case annotation attached to a code point. When decoding, a code point gets
   [Uppercase] if the Punycode character it came from is an uppercase ASCII
   letter (see [is_flagged]). When encoding, the flag picks the case of the
   character written for that code point. *)
type case_flag = Uppercase | Lowercase

(* {1 Basic Predicates} *)

(* In RFC 3492 terms a code point is "basic" exactly when it is ASCII. *)
let is_basic u = Uchar.to_int u <= 0x7F
let is_ascii_string s = String.for_all (fun c -> c <= '\x7F') s

(* True when [s] starts with the ACE prefix "xn--"; the letters are matched
   case-insensitively. *)
let has_ace_prefix s =
  String.length s >= 4
  && (match (s.[0], s.[1]) with
     | ('x' | 'X'), ('n' | 'N') -> true
     | _ -> false)
  && s.[2] = '-'
  && s.[3] = '-'

(* {1 Digit Encoding/Decoding (RFC 3492 Section 5)}

   Digit values:
   - 0-25: a-z (or A-Z)
   - 26-35: 0-9
*)

(* [encode_digit d case_flag] gives the Punycode character for digit value
   [d] (expected in 0..35). Letter digits use [case_flag] to choose between
   upper and lower case; numeric digits ignore it. *)
let encode_digit d case_flag =
  if d >= 26 then Char.chr (Char.code '0' + (d - 26))
  else
    let first_letter =
      match case_flag with Uppercase -> 'A' | Lowercase -> 'a'
    in
    Char.chr (Char.code first_letter + d)

(* Inverse of [encode_digit]: maps a Punycode character of either case to
   its digit value, or [None] if it is not a Punycode digit. *)
let decode_digit c =
  match c with
  | '0' .. '9' -> Some (Char.code c - Char.code '0' + 26)
  | 'A' .. 'Z' -> Some (Char.code c - Char.code 'A')
  | 'a' .. 'z' -> Some (Char.code c - Char.code 'a')
  | _ -> None

(* An uppercase ASCII letter is "flagged" for case annotation. *)
let is_flagged c = match c with 'A' .. 'Z' -> true | _ -> false

(* {1 Bias Adaptation (RFC 3492 Section 6.1)} *)

(* Compute the next bias from the most recent [delta]. [numpoints] is the
   code-point count callers pass (h + 1 when encoding, output length + 1 when
   decoding). [firsttime] selects the stronger [damp] divisor for the very
   first delta. *)
let adapt ~delta ~numpoints ~firsttime =
  let divisor = if firsttime then damp else 2 in
  let scaled = delta / divisor in
  let scaled = scaled + (scaled / numpoints) in
  let limit = (base - tmin) * tmax / 2 in
  let rec settle d k =
    if d <= limit then k + ((base - tmin + 1) * d / (d + skew))
    else settle (d / (base - tmin)) (k + base)
  in
  settle scaled 0

(* {1 Overflow-Safe Arithmetic}

   RFC 3492 Section 6.4: Use detection to avoid overflow.
   A + B overflows iff B > maxint - A
   A + B*C overflows iff B > (maxint - A) / C
*)

let max_int_value = Stdlib.max_int

(* [safe_mul_add a b c pos] computes [a + b * c]. Instead of wrapping it
   raises [Error (Overflow pos)]; the test assumes non-negative operands.
   When [c] is 0 the result is simply [a]. *)
let safe_mul_add a b c pos =
  if c <> 0 && b > (max_int_value - a) / c then overflow pos
  else a + (b * c)

(* {1 UTF-8 to Code Points Conversion} *)

(* Decode UTF-8 string [s] into an array of code points. Raises
   [Error (Invalid_utf8 pos)] at the first malformed sequence, where [pos]
   gives the byte offset and the number of characters decoded before it. *)
let utf8_to_codepoints s =
  let len = String.length s in
  let rec scan byte_offset char_index rev_acc =
    if byte_offset >= len then Array.of_list (List.rev rev_acc)
    else
      let dec = String.get_utf_8_uchar s byte_offset in
      if not (Uchar.utf_decode_is_valid dec) then
        invalid_utf8 { byte_offset; char_index }
      else
        scan
          (byte_offset + Uchar.utf_decode_length dec)
          (char_index + 1)
          (Uchar.utf_decode_uchar dec :: rev_acc)
  in
  scan 0 0 []

(* {1 Code Points to UTF-8 Conversion} *)

(* Encode an array of code points as a UTF-8 string. *)
let codepoints_to_utf8 codepoints =
  let count = Array.length codepoints in
  let buf = Buffer.create (2 * count) in
  for idx = 0 to count - 1 do
    Buffer.add_utf_8_uchar buf codepoints.(idx)
  done;
  Buffer.contents buf

(* {1 Punycode Encoding (RFC 3492 Section 6.3)} *)

(* Append every basic (ASCII) code point of [codepoints] to [output], in
   input order, and return how many were appended.

   ASCII letters are re-cased according to [case_flags.(j)]. With
   [case_flags = None] they are forced to lowercase.
   NOTE(review): RFC 3492 Section 6.3 copies basic code points unchanged
   when no case annotation is used; confirm the lowercasing is intended. *)
let encode_basic_codepoints output codepoints case_flags =
  let recase c case =
    match (case, Char.chr c) with
    | Lowercase, 'A' .. 'Z' -> c + 0x20
    | Uppercase, 'a' .. 'z' -> c - 0x20
    | _ -> c
  in
  let written = ref 0 in
  Array.iteri
    (fun j cp ->
      if is_basic cp then begin
        let case =
          match case_flags with None -> Lowercase | Some flags -> flags.(j)
        in
        Buffer.add_char output (Char.chr (recase (Uchar.to_int cp) case));
        incr written
      end)
    codepoints;
  !written

(* Write [q] to [output] as a generalized variable-length integer (the inner
   loop of RFC 3492 Section 6.3). Every digit except the last is lowercase.
   The last digit takes the case of [case_flags.(j)], or lowercase when
   [case_flags] is [None]. *)
let encode_variable_length_int output ~q ~bias ~case_flags ~j =
  let threshold k =
    if k <= bias then tmin else if k >= bias + tmax then tmax else k - bias
  in
  let rec emit q k =
    let t = threshold k in
    if q < t then begin
      let case =
        match case_flags with None -> Lowercase | Some flags -> flags.(j)
      in
      Buffer.add_char output (encode_digit q case)
    end
    else begin
      let digit = t + ((q - t) mod (base - t)) in
      Buffer.add_char output (encode_digit digit Lowercase);
      emit ((q - t) / (base - t)) (k + base)
    end
  in
  emit q base

(* [encode_impl codepoints case_flags] runs the RFC 3492 Section 6.3
   encoding procedure and returns the Punycode string, without an ACE prefix.

   The output contains the basic (ASCII) code points, then [delimiter] if
   there was at least one, then one variable-length integer per non-basic
   code point. [case_flags], when present, selects the case of the character
   written for each code point. An empty input encodes to "".

   Raises [Error (Overflow _)] if [delta] would exceed [max_int]. *)
let encode_impl codepoints case_flags =
  let input_length = Array.length codepoints in
  if input_length = 0 then ""
  else begin
    let output = Buffer.create (input_length * 2) in
    let b = encode_basic_codepoints output codepoints case_flags in
    let h = ref b in
    if b > 0 then Buffer.add_char output delimiter;
    let n = ref initial_n in
    let delta = ref 0 in
    let bias = ref initial_bias in
    (* RFC 3492 Section 6.4 requires overflow detection on every increment of
       [delta]. OCaml's [int] wraps from [max_int] to [min_int], never to 0,
       so comparing the incremented value with 0 (as the RFC's unsigned C
       sample does) detects nothing. Check before incrementing instead. *)
    let incr_delta pos =
      if !delta = max_int_value then overflow pos;
      incr delta
    in
    while !h < input_length do
      (* m = smallest code point >= n; all code points below n are done. *)
      let m =
        Array.fold_left
          (fun acc cp ->
            let cp_val = Uchar.to_int cp in
            if cp_val >= !n && cp_val < acc then cp_val else acc)
          max_int_value codepoints
      in
      let pos = { byte_offset = 0; char_index = !h } in
      delta := safe_mul_add !delta (m - !n) (!h + 1) pos;
      n := m;
      for j = 0 to input_length - 1 do
        let cp = Uchar.to_int codepoints.(j) in
        if cp < !n then incr_delta { byte_offset = 0; char_index = j }
        else if cp = !n then begin
          encode_variable_length_int output ~q:!delta ~bias:!bias ~case_flags
            ~j;
          bias := adapt ~delta:!delta ~numpoints:(!h + 1) ~firsttime:(!h = b);
          delta := 0;
          incr h
        end
      done;
      incr_delta { byte_offset = 0; char_index = !h };
      incr n
    done;
    Buffer.contents output
  end

(* Encode code points to Punycode without case annotation. *)
let encode codepoints = encode_impl codepoints None

(* Encode code points to Punycode, using [case_flags.(j)] to choose the case
   of the character written for [codepoints.(j)].
   Raises [Invalid_argument] if the two arrays differ in length. *)
let encode_with_case codepoints case_flags =
  let same_length = Array.length codepoints = Array.length case_flags in
  if not same_length then
    invalid_arg "encode_with_case: array lengths must match"
  else encode_impl codepoints (Some case_flags)

(* {1 Punycode Decoding (RFC 3492 Section 6.2)} *)

(* Read the first [b] bytes of [input] as basic code points. Returns them
   together with a case flag per code point (uppercase letters are flagged).
   Raises [Error (Invalid_character _)] on the first non-ASCII byte. *)
let decode_basic_codepoints input b =
  let codepoints =
    Array.init b (fun j ->
        let code = Char.code input.[j] in
        if code >= 0x80 then
          invalid_character
            { byte_offset = j; char_index = j }
            (Uchar.of_int code)
        else Uchar.of_int code)
  in
  let case_flags =
    Array.init b (fun j ->
        if is_flagged input.[j] then Uppercase else Lowercase)
  in
  (codepoints, case_flags)

(* Consume one generalized variable-length integer from [input], starting at
   [!in_pos] (inner loop of RFC 3492 Section 6.2). The decoded value is
   added into [!i], and [in_pos] is left just past the last digit read.
   Returns the value [!i] had on entry. [output] is used only to fill in
   [char_index] in error positions.
   Raises [Error] on truncated input, a non-digit character, or overflow. *)
let decode_delta_sequence input ~input_length ~in_pos ~i ~bias ~output =
  let oldi = !i in
  let char_index = Array.length output in
  let threshold k =
    if k <= bias then tmin else if k >= bias + tmax then tmax else k - bias
  in
  let rec read_digits w k =
    let pos = { byte_offset = !in_pos; char_index } in
    if !in_pos >= input_length then unexpected_end pos;
    let c = input.[!in_pos] in
    incr in_pos;
    let digit =
      match decode_digit c with Some d -> d | None -> invalid_digit pos c
    in
    i := safe_mul_add !i digit w pos;
    let t = threshold k in
    if digit >= t then begin
      let base_minus_t = base - t in
      if w > max_int_value / base_minus_t then overflow pos;
      read_digits (w * base_minus_t) (k + base)
    end
  in
  read_digits 1 base;
  oldi

(* Insert code point [n] at index [i] of [!output], shifting later elements
   right, and insert its case flag at the same index of [!case_output]. The
   flag is [Uppercase] when the last Punycode character consumed
   ([input.[in_pos - 1]]) is an uppercase letter. *)
let insert_codepoint ~output ~case_output ~n ~i ~in_pos input =
  let flag =
    if in_pos > 0 && is_flagged input.[in_pos - 1] then Uppercase
    else Lowercase
  in
  let splice arr x =
    Array.init (Array.length arr + 1) (fun j ->
        if j < i then arr.(j) else if j = i then x else arr.(j - 1))
  in
  output := splice !output (Uchar.of_int n);
  case_output := splice !case_output flag

(* Decode a Punycode string (without ACE prefix) following RFC 3492
   Section 6.2. Returns the code points and a case flag for each one. An
   empty input decodes to empty arrays.
   Raises [Error] on malformed input, overflow, or a decoded value that is
   not a valid Unicode scalar. *)
let decode_impl input =
  let input_length = String.length input in
  if input_length = 0 then ([||], [||])
  else begin
    (* Everything before the last delimiter is basic code points. *)
    let b =
      match String.rindex_opt input delimiter with Some idx -> idx | None -> 0
    in
    let basics, flags = decode_basic_codepoints input b in
    let output = ref basics and case_output = ref flags in
    let n = ref initial_n and i = ref 0 and bias = ref initial_bias in
    let in_pos = ref (if b > 0 then b + 1 else 0) in
    while !in_pos < input_length do
      let oldi =
        decode_delta_sequence input ~input_length ~in_pos ~i ~bias:!bias
          ~output:!output
      in
      let count = Array.length !output + 1 in
      bias := adapt ~delta:(!i - oldi) ~numpoints:count ~firsttime:(oldi = 0);
      let pos = { byte_offset = !in_pos - 1; char_index = count - 1 } in
      let increment = !i / count in
      if increment > max_int_value - !n then overflow pos;
      n := !n + increment;
      i := !i mod count;
      if not (Uchar.is_valid !n) then invalid_character pos Uchar.rep;
      insert_codepoint ~output ~case_output ~n:!n ~i:!i ~in_pos:!in_pos input;
      incr i
    done;
    (!output, !case_output)
  end

(* Decode Punycode to code points, discarding case annotation. *)
let decode input =
  let codepoints, _case_flags = decode_impl input in
  codepoints

(* Decode Punycode to code points together with their case flags. *)
let decode_with_case input = decode_impl input

(* {1 UTF-8 String Operations} *)

(* UTF-8 string to Punycode (no ACE prefix). *)
let encode_utf8 s = s |> utf8_to_codepoints |> encode

(* Punycode (no ACE prefix) to UTF-8 string. *)
let decode_utf8 punycode = punycode |> decode |> codepoints_to_utf8

(* {1 Domain Label Operations} *)

(* Convert one domain label to its ASCII form. All-ASCII labels are returned
   unchanged. Any other label is Punycode-encoded and given the ACE prefix.
   Raises [Error Empty_label] for "", and [Error (Label_too_long len)] when
   the result is longer than [max_label_length] bytes. *)
let encode_label label =
  let within_limit s =
    let len = String.length s in
    if len > max_label_length then label_too_long len else s
  in
  if String.equal label "" then empty_label ()
  else if is_ascii_string label then within_limit label
  else within_limit (ace_prefix ^ encode_utf8 label)

(* Convert one domain label to its Unicode form.

   A label that starts with the ACE prefix ("xn--", letters matched
   case-insensitively) has the prefix stripped, and the rest is decoded from
   Punycode to UTF-8. Any other label, ASCII or not, is returned unchanged
   without validation.

   Raises [Error Empty_label] for "". Any [Error] raised while decoding the
   Punycode part propagates. *)
let decode_label label =
  if String.length label = 0 then empty_label ()
  else if has_ace_prefix label then begin
    (* Derive the prefix length from [ace_prefix] rather than a literal 4. *)
    let prefix_len = String.length ace_prefix in
    decode_utf8 (String.sub label prefix_len (String.length label - prefix_len))
  end
  else label