(* Punycode (RFC 3492) in OCaml *)
(* Source listing metadata: branch main, 435 lines, 13 kB *)
(*---------------------------------------------------------------------------
   Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved.
   SPDX-License-Identifier: ISC
  ---------------------------------------------------------------------------*)

(* RFC 3492 Punycode Implementation *)

(* {1 Bootstring Parameters for Punycode (RFC 3492 Section 5)} *)

let base = 36
let tmin = 1
let tmax = 26
let skew = 38
let damp = 700
let initial_bias = 72
let initial_n = 0x80 (* 128 *)
let delimiter = '-'
let ace_prefix = "xn--"
let max_label_length = 63

(* {1 Position Tracking} *)

(* Location attached to errors: a byte offset plus a character (code point)
   index. Which input each field refers to depends on the raising function. *)
type position = { byte_offset : int; char_index : int }

let position_byte_offset pos = pos.byte_offset
let position_char_index pos = pos.char_index

(* Formats a position as [byte <offset>, char <index>]. *)
let pp_position fmt pos =
  Fmt.pf fmt "byte %d, char %d" pos.byte_offset pos.char_index

(* {1 Error Types} *)

type error_reason =
  | Overflow of position
  | Invalid_character of position * Uchar.t
  | Invalid_digit of position * char
  | Unexpected_end of position
  | Invalid_utf8 of position
  | Label_too_long of int
  | Empty_label

(* Human-readable rendering of an error reason. *)
let pp_error_reason fmt = function
  | Overflow pos -> Fmt.pf fmt "arithmetic overflow at %a" pp_position pos
  | Invalid_character (pos, u) ->
      Fmt.pf fmt "invalid character U+%04X at %a" (Uchar.to_int u) pp_position
        pos
  | Invalid_digit (pos, c) ->
      Fmt.pf fmt "invalid Punycode digit '%c' (0x%02X) at %a" c (Char.code c)
        pp_position pos
  | Unexpected_end pos ->
      Fmt.pf fmt "unexpected end of input at %a" pp_position pos
  | Invalid_utf8 pos ->
      Fmt.pf fmt "invalid UTF-8 sequence at %a" pp_position pos
  | Label_too_long len ->
      Fmt.pf fmt "label too long: %d bytes (max %d)" len max_label_length
  | Empty_label -> Fmt.pf fmt "empty label"

(* The single exception raised by this module for encoding/decoding
   failures. *)
exception Error of error_reason

(* Register a printer so uncaught [Error] values are shown with their
   formatted reason. *)
let () =
  Printexc.register_printer (function
    | Error reason -> Some (Fmt.str "Punycode.Error: %a" pp_error_reason reason)
    | _ -> None)

let error_reason_to_string reason = Fmt.str "%a" pp_error_reason reason

(* {1 Error Constructors} *)

(* Each helper raises [Error] with the matching reason and never returns. *)
let overflow pos = raise (Error (Overflow pos))
let invalid_character pos u = raise (Error (Invalid_character (pos, u)))
let invalid_digit pos c = raise (Error (Invalid_digit (pos, c)))
let unexpected_end pos = raise (Error (Unexpected_end pos))
let invalid_utf8 pos = raise (Error (Invalid_utf8 pos))
let label_too_long len = raise (Error (Label_too_long len))
let empty_label () = raise (Error Empty_label)

(* {1 Case Flags} *)

type case_flag = Uppercase | Lowercase

(* {1 Basic Predicates} *)

(* A code point is basic when it is below 0x80. *)
let is_basic u = Uchar.to_int u < 0x80

(* True when every byte of [s] is below 0x80 (also true for the empty
   string). *)
let is_ascii_string s = String.for_all (fun c -> Char.code c < 0x80) s

(* True when [s] starts with the ACE prefix, matching the letters x and n
   case-insensitively. *)
let has_ace_prefix s =
  let len = String.length s in
  len >= 4
  && (s.[0] = 'x' || s.[0] = 'X')
  && (s.[1] = 'n' || s.[1] = 'N')
  && s.[2] = '-'
  && s.[3] = '-'

(* {1 Digit Encoding/Decoding (RFC 3492 Section 5)}

   Digit values:
   - 0-25: a-z (or A-Z)
   - 26-35: 0-9
*)

(* Maps a digit value to its character. The case flag only affects the
   letter range 0-25; values 26 and above map to '0'-'9'. The caller must
   pass a value in 0-35. *)
let encode_digit d case_flag =
  if d < 26 then Char.chr (d + if case_flag = Uppercase then 0x41 else 0x61)
  else Char.chr (d - 26 + 0x30)

(* Inverse of [encode_digit]; returns [None] for any character that is not
   an ASCII letter or digit. *)
let decode_digit c =
  let code = Char.code c in
  if code >= 0x30 && code <= 0x39 then
    (* '0'-'9' -> 26-35 *)
    Some (code - 0x30 + 26)
  else if code >= 0x41 && code <= 0x5A then
    (* 'A'-'Z' -> 0-25 *)
    Some (code - 0x41)
  else if code >= 0x61 && code <= 0x7A then
    (* 'a'-'z' -> 0-25 *)
    Some (code - 0x61)
  else None

(* Check if a character is "flagged" (uppercase) for case annotation *)
let is_flagged c =
  let code = Char.code c in
  code >= 0x41 && code <= 0x5A (* 'A'-'Z' *)

(* {1 Bias Adaptation (RFC 3492 Section 6.1)} *)

(* Computes the next bias from the delta just processed. [delta] is divided
   by [damp] on the first call and halved afterwards, then increased by
   [delta / numpoints]; [numpoints] must therefore be non-zero. *)
let adapt ~delta ~numpoints ~firsttime =
  let delta = if firsttime then delta / damp else delta / 2 in
  let delta = delta + (delta / numpoints) in
  let threshold = (base - tmin) * tmax / 2 in
  let rec loop delta k =
    if delta > threshold then loop (delta / (base - tmin)) (k + base)
    else k + ((base - tmin + 1) * delta / (delta + skew))
  in
  loop delta 0

(* {1 Overflow-Safe Arithmetic}

   RFC 3492 Section 6.4: Use detection to avoid overflow.
   A + B overflows iff B > maxint - A
   A + B*C overflows iff B > (maxint - A) / C
*)

let max_int_value = max_int

(* Returns [a + b * c], raising [Error (Overflow pos)] when the result would
   exceed [max_int]. The test is the one quoted above and is written for
   non-negative operands. *)
let safe_mul_add a b c pos =
  if c = 0 then a
  else if b > (max_int_value - a) / c then overflow pos
  else a + (b * c)

(* {1 UTF-8 to Code Points Conversion} *)

(* Decodes [s] into an array of code points. Raises [Error (Invalid_utf8 _)]
   at the first malformed sequence, reporting its byte offset and the number
   of code points decoded before it. *)
let utf8_to_codepoints s =
  let len = String.length s in
  let acc = ref [] in
  let byte_offset = ref 0 in
  let char_index = ref 0 in
  while !byte_offset < len do
    let pos = { byte_offset = !byte_offset; char_index = !char_index } in
    let dec = String.get_utf_8_uchar s !byte_offset in
    if Uchar.utf_decode_is_valid dec then begin
      acc := Uchar.utf_decode_uchar dec :: !acc;
      byte_offset := !byte_offset + Uchar.utf_decode_length dec;
      incr char_index
    end
    else invalid_utf8 pos
  done;
  Array.of_list (List.rev !acc)

(* {1 Code Points to UTF-8 Conversion} *)

(* Serialises an array of code points as a UTF-8 string. *)
let codepoints_to_utf8 codepoints =
  let buf = Buffer.create (Array.length codepoints * 2) in
  Array.iter (Buffer.add_utf_8_uchar buf) codepoints;
  Buffer.contents buf

(* {1 Punycode Encoding (RFC 3492 Section 6.3)} *)

(* Appends every basic code point of [codepoints] to [output], in input
   order, and returns how many were written. ASCII letters are forced to the
   case given by the flag at the same index; when [case_flags] is [None]
   every ASCII letter is written in lowercase. *)
let encode_basic_codepoints output codepoints case_flags =
  let input_length = Array.length codepoints in
  let basic_count = ref 0 in
  for j = 0 to input_length - 1 do
    let cp = codepoints.(j) in
    if is_basic cp then begin
      let c = Uchar.to_int cp in
      let case =
        match case_flags with Some flags -> flags.(j) | None -> Lowercase
      in
      let c' =
        if c >= 0x41 && c <= 0x5A then if case = Lowercase then c + 0x20 else c
        else if c >= 0x61 && c <= 0x7A then
          if case = Uppercase then c - 0x20 else c
        else c
      in
      Buffer.add_char output (Char.chr c');
      incr basic_count
    end
  done;
  !basic_count

(* Appends [q] to [output] as a variable-length integer whose per-digit
   thresholds are derived from [bias]. Only the final digit carries the case
   flag of input position [j]; all earlier digits are written lowercase. *)
let encode_variable_length_int output ~q ~bias ~case_flags ~j =
  let q = ref q in
  let k = ref base in
  let done_encoding = ref false in
  while not !done_encoding do
    let t =
      if !k <= bias then tmin else if !k >= bias + tmax then tmax else !k - bias
    in
    if !q < t then begin
      let case =
        match case_flags with Some flags -> flags.(j) | None -> Lowercase
      in
      Buffer.add_char output (encode_digit !q case);
      done_encoding := true
    end
    else begin
      let digit = t + ((!q - t) mod (base - t)) in
      Buffer.add_char output (encode_digit digit Lowercase);
      q := (!q - t) / (base - t);
      k := !k + base
    end
  done

(* Shared encoder behind [encode] and [encode_with_case]. Returns the empty
   string for an empty input. Raises [Error (Overflow _)] when a delta would
   exceed [max_int]; the reported position carries a code point index and a
   byte offset fixed at 0. *)
let encode_impl codepoints case_flags =
  let input_length = Array.length codepoints in
  if input_length = 0 then ""
  else begin
    let output = Buffer.create (input_length * 2) in

    (* [b] basic code points are copied first; [h] counts code points
       handled so far. *)
    let b = encode_basic_codepoints output codepoints case_flags in
    let h = ref b in

    if b > 0 then Buffer.add_char output delimiter;

    let n = ref initial_n in
    let delta = ref 0 in
    let bias = ref initial_bias in

    while !h < input_length do
      (* Smallest code point that is >= n, i.e. the next one to emit. *)
      let m =
        Array.fold_left
          (fun acc cp ->
            let cp_val = Uchar.to_int cp in
            if cp_val >= !n && cp_val < acc then cp_val else acc)
          max_int_value codepoints
      in

      let pos = { byte_offset = 0; char_index = !h } in
      delta := safe_mul_add !delta (m - !n) (!h + 1) pos;
      n := m;

      for j = 0 to input_length - 1 do
        let cp = Uchar.to_int codepoints.(j) in
        let pos = { byte_offset = 0; char_index = j } in

        if cp < !n then begin
          (* OCaml integers wrap from max_int to min_int, never to 0, so the
             overflow test has to happen before the increment. *)
          if !delta = max_int_value then overflow pos;
          incr delta
        end
        else if cp = !n then begin
          encode_variable_length_int output ~q:!delta ~bias:!bias ~case_flags ~j;
          bias := adapt ~delta:!delta ~numpoints:(!h + 1) ~firsttime:(!h = b);
          delta := 0;
          incr h
        end
      done;

      incr delta;
      incr n
    done;

    Buffer.contents output
  end

(* Encodes code points as Punycode with no case annotation; ASCII letters in
   the input are written in lowercase. *)
let encode codepoints = encode_impl codepoints None

(* Like [encode], but takes one case flag per code point.
   Raises [Invalid_argument] when the two arrays differ in length. *)
let encode_with_case codepoints case_flags =
  if Array.length codepoints <> Array.length case_flags then
    invalid_arg "encode_with_case: array lengths must match";
  encode_impl codepoints (Some case_flags)

(* {1 Punycode Decoding (RFC 3492 Section 6.2)} *)

(* Reads the first [b] bytes of [input] as basic code points and returns
   them together with their case flags (Uppercase for 'A'-'Z'). Raises
   [Error (Invalid_character _)] on any byte >= 0x80. *)
let decode_basic_codepoints input b =
  let output = ref [] in
  let case_output = ref [] in
  for j = 0 to b - 1 do
    let c = input.[j] in
    let pos = { byte_offset = j; char_index = j } in
    let code = Char.code c in
    if code >= 0x80 then invalid_character pos (Uchar.of_int code)
    else begin
      output := Uchar.of_int code :: !output;
      case_output :=
        (if is_flagged c then Uppercase else Lowercase) :: !case_output
    end
  done;
  (Array.of_list (List.rev !output), Array.of_list (List.rev !case_output))

(* Consumes one variable-length integer from [input] starting at [!in_pos],
   adding its value to [!i] and advancing [!in_pos] past it. Returns the
   value [!i] had on entry. Raises [Error] with [Unexpected_end] when the
   input stops mid-integer, [Invalid_digit] on a non-digit character, or
   [Overflow] when an intermediate value would exceed [max_int]. *)
let decode_delta_sequence input ~input_length ~in_pos ~i ~bias ~output =
  let oldi = !i in
  let w = ref 1 in
  let k = ref base in
  let done_decoding = ref false in
  while not !done_decoding do
    let pos = { byte_offset = !in_pos; char_index = Array.length output } in
    if !in_pos >= input_length then unexpected_end pos
    else begin
      let c = input.[!in_pos] in
      incr in_pos;
      match decode_digit c with
      | None -> invalid_digit pos c
      | Some digit ->
          i := safe_mul_add !i digit !w pos;
          let t =
            if !k <= bias then tmin
            else if !k >= bias + tmax then tmax
            else !k - bias
          in
          if digit < t then done_decoding := true
          else begin
            let base_minus_t = base - t in
            if !w > max_int_value / base_minus_t then overflow pos
            else begin
              w := !w * base_minus_t;
              k := !k + base
            end
          end
    end
  done;
  oldi

(* Replaces [!output] and [!case_output] with arrays one element longer in
   which code point [n] sits at index [i]. The new case flag is Uppercase
   when the input character just before [in_pos] is an uppercase ASCII
   letter. *)
let insert_codepoint ~output ~case_output ~n ~i ~in_pos input =
  let out_len = Array.length !output in
  let new_output = Array.make (out_len + 1) (Uchar.of_int 0) in
  let new_case = Array.make (out_len + 1) Lowercase in
  for j = 0 to i - 1 do
    new_output.(j) <- !output.(j);
    new_case.(j) <- !case_output.(j)
  done;
  new_output.(i) <- Uchar.of_int n;
  new_case.(i) <-
    (if in_pos > 0 && is_flagged input.[in_pos - 1] then Uppercase
     else Lowercase);
  for j = i to out_len - 1 do
    new_output.(j + 1) <- !output.(j);
    new_case.(j + 1) <- !case_output.(j)
  done;
  output := new_output;
  case_output := new_case

(* Shared decoder behind [decode] and [decode_with_case]. Returns the code
   points and their case flags; an empty input yields two empty arrays.
   Raises [Error] for malformed input (see the helpers above) and
   [Error (Invalid_character (_, Uchar.rep))] when a decoded value is not a
   valid Unicode scalar value. *)
let decode_impl input =
  let input_length = String.length input in
  if input_length = 0 then ([||], [||])
  else begin
    (* Index of the last delimiter, or 0 when there is none. *)
    let b = Option.value ~default:0 (String.rindex_opt input delimiter) in

    let init_output, init_case = decode_basic_codepoints input b in
    let output = ref init_output in
    let case_output = ref init_case in

    let n = ref initial_n in
    let i = ref 0 in
    let bias = ref initial_bias in
    (* Skip the delimiter only when basic code points preceded it. *)
    let in_pos = ref (if b > 0 then b + 1 else 0) in

    while !in_pos < input_length do
      let oldi =
        decode_delta_sequence input ~input_length ~in_pos ~i ~bias:!bias
          ~output:!output
      in

      let out_len = Array.length !output in
      bias :=
        adapt ~delta:(!i - oldi) ~numpoints:(out_len + 1) ~firsttime:(oldi = 0);

      let pos = { byte_offset = !in_pos - 1; char_index = out_len } in

      let increment = !i / (out_len + 1) in
      if increment > max_int_value - !n then overflow pos
      else begin
        n := !n + increment;
        i := !i mod (out_len + 1);

        if not (Uchar.is_valid !n) then invalid_character pos Uchar.rep
        else begin
          insert_codepoint ~output ~case_output ~n:!n ~i:!i ~in_pos:!in_pos
            input;
          incr i
        end
      end
    done;

    (!output, !case_output)
  end

(* Decodes a Punycode string into code points. *)
let decode input = fst (decode_impl input)

(* Decodes a Punycode string into code points and their case flags. *)
let decode_with_case input = decode_impl input

(* {1 UTF-8 String Operations} *)

(* Punycode-encodes a UTF-8 string. Raises [Error (Invalid_utf8 _)] on
   malformed UTF-8. *)
let encode_utf8 s =
  let codepoints = utf8_to_codepoints s in
  encode codepoints

(* Decodes a Punycode string and returns the result as UTF-8. *)
let decode_utf8 punycode =
  let codepoints = decode punycode in
  codepoints_to_utf8 codepoints

(* {1 Domain Label Operations} *)

(* Encodes one domain label. An all-ASCII label is returned unchanged; any
   other label is Punycode-encoded and given the ACE prefix. Raises
   [Error Empty_label] for the empty string and [Error (Label_too_long _)]
   when the result is longer than [max_label_length] bytes. *)
let encode_label label =
  if String.length label = 0 then empty_label ()
  else if is_ascii_string label then begin
    (* All ASCII - return as-is, but check length *)
    let len = String.length label in
    if len > max_label_length then label_too_long len else label
  end
  else begin
    (* Has non-ASCII - encode with Punycode *)
    let encoded = encode_utf8 label in
    let result = ace_prefix ^ encoded in
    let len = String.length result in
    if len > max_label_length then label_too_long len else result
  end

(* Decodes one domain label. A label carrying the ACE prefix has the prefix
   stripped and the remainder decoded to UTF-8; any other label, ASCII or
   not, is returned unchanged. Raises [Error Empty_label] for the empty
   string. *)
let decode_label label =
  if String.length label = 0 then empty_label ()
  else if has_ace_prefix label then begin
    (* Remove ACE prefix and decode *)
    let punycode = String.sub label 4 (String.length label - 4) in
    decode_utf8 punycode
  end
  else
    (* No ACE prefix - returned as-is whether or not it is pure ASCII *)
    label