Punycode (RFC3492) in OCaml
at main 460 lines 14 kB view raw
(*---------------------------------------------------------------------------
  Copyright (c) 2025 Anil Madhavapeddy <anil@recoil.org>. All rights reserved.
  SPDX-License-Identifier: ISC
 ---------------------------------------------------------------------------*)

(* RFC 3492 Punycode Implementation *)

(* {1 Bootstring Parameters for Punycode (RFC 3492 Section 5)} *)

let base = 36
let tmin = 1
let tmax = 26
let skew = 38
let damp = 700
let initial_bias = 72
let initial_n = 0x80 (* 128: first non-basic code point *)
let delimiter = '-'
let ace_prefix = "xn--"
let max_label_length = 63

(* {1 Position Tracking} *)

(** Location attached to errors: a byte offset and a code point index. *)
type position = { byte_offset : int; char_index : int }

let position_byte_offset pos = pos.byte_offset
let position_char_index pos = pos.char_index

(** [pp_position fmt pos] prints [pos] as ["byte N, char M"]. *)
let pp_position fmt pos =
  Format.fprintf fmt "byte %d, char %d" pos.byte_offset pos.char_index

(* {1 Error Types} *)

(** Reasons carried by the {!Error} exception. *)
type error_reason =
  | Overflow of position
  | Invalid_character of position * Uchar.t
  | Invalid_digit of position * char
  | Unexpected_end of position
  | Invalid_utf8 of position
  | Label_too_long of int
  | Empty_label

(** [pp_error_reason fmt reason] prints a human-readable description of
    [reason]. *)
let pp_error_reason fmt = function
  | Overflow pos ->
      Format.fprintf fmt "arithmetic overflow at %a" pp_position pos
  | Invalid_character (pos, u) ->
      Format.fprintf fmt "invalid character U+%04X at %a" (Uchar.to_int u)
        pp_position pos
  | Invalid_digit (pos, c) ->
      Format.fprintf fmt "invalid Punycode digit '%c' (0x%02X) at %a" c
        (Char.code c) pp_position pos
  | Unexpected_end pos ->
      Format.fprintf fmt "unexpected end of input at %a" pp_position pos
  | Invalid_utf8 pos ->
      Format.fprintf fmt "invalid UTF-8 sequence at %a" pp_position pos
  | Label_too_long len ->
      Format.fprintf fmt "label too long: %d bytes (max %d)" len
        max_label_length
  | Empty_label -> Format.fprintf fmt "empty label"

(** Raised by every encoding and decoding entry point of this module. *)
exception Error of error_reason

(* Make uncaught [Error] exceptions print their reason. *)
let () =
  Printexc.register_printer (function
    | Error reason ->
        Some (Format.asprintf "Punycode.Error: %a" pp_error_reason reason)
    | _ -> None)

(** [error_reason_to_string reason] is the text printed by
    {!pp_error_reason}. *)
let error_reason_to_string reason = Format.asprintf "%a" pp_error_reason reason

(* {1 Error Constructors}

   Each helper raises [Error] with the corresponding reason and never
   returns. *)

let overflow pos = raise (Error (Overflow pos))
let invalid_character pos u = raise (Error (Invalid_character (pos, u)))
let invalid_digit pos c = raise (Error (Invalid_digit (pos, c)))
let unexpected_end pos = raise (Error (Unexpected_end pos))
let invalid_utf8 pos = raise (Error (Invalid_utf8 pos))
let label_too_long len = raise (Error (Label_too_long len))
let empty_label () = raise (Error Empty_label)

(* {1 Case Flags} *)

type case_flag = Uppercase | Lowercase

(* {1 Basic Predicates} *)

(** [is_basic u] is [true] iff [u] is below U+0080. *)
let is_basic u = Uchar.to_int u < 0x80

(** [is_ascii_string s] is [true] iff every byte of [s] is below 0x80. *)
let is_ascii_string s = String.for_all (fun c -> Char.code c < 0x80) s

(** [has_ace_prefix s] is [true] iff [s] starts with ["xn--"], where the
    two letters are matched case-insensitively. *)
let has_ace_prefix s =
  let len = String.length s in
  len >= 4
  && (s.[0] = 'x' || s.[0] = 'X')
  && (s.[1] = 'n' || s.[1] = 'N')
  && s.[2] = '-'
  && s.[3] = '-'

(* {1 Digit Encoding/Decoding (RFC 3492 Section 5)}

   Digit values:
   - 0-25: a-z (or A-Z)
   - 26-35: 0-9
*)

(** [encode_digit d case_flag] maps digit value [d] to its character.
    Values below 26 become a letter whose case follows [case_flag]; other
    values become a decimal digit character. Callers pass [d] in 0..35. *)
let encode_digit d case_flag =
  if d < 26 then Char.chr (d + if case_flag = Uppercase then 0x41 else 0x61)
  else Char.chr (d - 26 + 0x30)

(** [decode_digit c] is the digit value of [c], or [None] when [c] is not
    an ASCII letter or decimal digit. *)
let decode_digit c =
  let code = Char.code c in
  if code >= 0x30 && code <= 0x39 then Some (code - 0x30 + 26)
    (* '0'-'9' -> 26-35 *)
  else if code >= 0x41 && code <= 0x5A then Some (code - 0x41)
    (* 'A'-'Z' -> 0-25 *)
  else if code >= 0x61 && code <= 0x7A then Some (code - 0x61)
    (* 'a'-'z' -> 0-25 *)
  else None

(** [is_flagged c] is [true] iff [c] is an uppercase ASCII letter, which is
    how a case annotation is marked. *)
let is_flagged c =
  let code = Char.code c in
  code >= 0x41 && code <= 0x5A (* 'A'-'Z' *)

(* {1 Bias Adaptation (RFC 3492 Section 6.1)} *)

(** [adapt ~delta ~numpoints ~firsttime] computes the next bias. [delta] is
    first divided by [damp] when [firsttime] is set and by 2 otherwise.
    [numpoints] is used as a divisor, so it must be non-zero. *)
let adapt ~delta ~numpoints ~firsttime =
  let delta = if firsttime then delta / damp else delta / 2 in
  let delta = delta + (delta / numpoints) in
  let threshold = (base - tmin) * tmax / 2 in
  let rec loop delta k =
    if delta > threshold then loop (delta / (base - tmin)) (k + base)
    else k + ((base - tmin + 1) * delta / (delta + skew))
  in
  loop delta 0

(* {1 Overflow-Safe Arithmetic}

   RFC 3492 Section 6.4: Use detection to avoid overflow.
   A + B overflows iff B > maxint - A
   A + B*C overflows iff B > (maxint - A) / C
*)

let max_int_value = max_int

(** [safe_mul_add a b c pos] is [a + b * c].
    The check is the one quoted from RFC 3492 Section 6.4 above; it is only
    meaningful for non-negative operands, which is what the callers in this
    file pass.
    @raise Error with [Overflow pos] when [b > (max_int - a) / c]. *)
let safe_mul_add a b c pos =
  if c = 0 then a
  else if b > (max_int_value - a) / c then overflow pos
  else a + (b * c)

(* {1 UTF-8 to Code Points Conversion} *)

(** [utf8_to_codepoints s] decodes [s] into an array of code points.
    @raise Error with [Invalid_utf8] at the first invalid sequence. *)
let utf8_to_codepoints s =
  let len = String.length s in
  let acc = ref [] in
  let byte_offset = ref 0 in
  let char_index = ref 0 in
  while !byte_offset < len do
    let pos = { byte_offset = !byte_offset; char_index = !char_index } in
    let dec = String.get_utf_8_uchar s !byte_offset in
    if Uchar.utf_decode_is_valid dec then begin
      acc := Uchar.utf_decode_uchar dec :: !acc;
      byte_offset := !byte_offset + Uchar.utf_decode_length dec;
      incr char_index
    end
    else invalid_utf8 pos
  done;
  Array.of_list (List.rev !acc)

(* {1 Code Points to UTF-8 Conversion} *)

(** [codepoints_to_utf8 codepoints] is the UTF-8 encoding of [codepoints]. *)
let codepoints_to_utf8 codepoints =
  let buf = Buffer.create (Array.length codepoints * 2) in
  Array.iter (Buffer.add_utf_8_uchar buf) codepoints;
  Buffer.contents buf

(* {1 Punycode Encoding (RFC 3492 Section 6.3)} *)

(** [encode_impl codepoints case_flags] is the Punycode encoding of
    [codepoints], without any ACE prefix. An empty input yields [""].

    [case_flags], when given, is indexed in parallel with [codepoints]: it
    selects the case of copied ASCII letters and of the final digit of each
    encoded delta. When it is [None], [Lowercase] is used everywhere, so
    uppercase ASCII letters in the input are emitted in lowercase.
    @raise Error with [Overflow] when a delta computation would overflow. *)
let encode_impl codepoints case_flags =
  let input_length = Array.length codepoints in
  if input_length = 0 then ""
  else begin
    let output = Buffer.create (input_length * 2) in

    (* Copy basic code points to output *)
    let basic_count = ref 0 in
    for j = 0 to input_length - 1 do
      let cp = codepoints.(j) in
      if is_basic cp then begin
        let c = Uchar.to_int cp in
        let case =
          match case_flags with Some flags -> flags.(j) | None -> Lowercase
        in
        (* Apply the selected case to ASCII letters; other basic code points
           are copied unchanged. *)
        let c' =
          if c >= 0x41 && c <= 0x5A then (* 'A'-'Z' *)
            if case = Lowercase then c + 0x20 else c
          else if c >= 0x61 && c <= 0x7A then (* 'a'-'z' *)
            if case = Uppercase then c - 0x20 else c
          else c
        in
        Buffer.add_char output (Char.chr c');
        incr basic_count
      end
    done;

    let b = !basic_count in
    let h = ref b in

    (* Add delimiter if there were basic code points *)
    if b > 0 then Buffer.add_char output delimiter;

    (* Main encoding loop *)
    let n = ref initial_n in
    let delta = ref 0 in
    let bias = ref initial_bias in

    while !h < input_length do
      (* Find minimum code point >= n *)
      let m =
        Array.fold_left
          (fun acc cp ->
            let cp_val = Uchar.to_int cp in
            if cp_val >= !n && cp_val < acc then cp_val else acc)
          max_int_value codepoints
      in

      (* Increase delta to advance state to <m, 0> *)
      let pos = { byte_offset = 0; char_index = !h } in
      delta := safe_mul_add !delta (m - !n) (!h + 1) pos;
      n := m;

      (* Process each code point *)
      for j = 0 to input_length - 1 do
        let cp = Uchar.to_int codepoints.(j) in
        let pos = { byte_offset = 0; char_index = j } in

        if cp < !n then begin
          (* OCaml ints wrap from max_int to min_int, never to 0, so the
             overflow test has to happen before the increment. *)
          if !delta = max_int_value then overflow pos;
          incr delta
        end
        else if cp = !n then begin
          (* Encode delta as variable-length integer *)
          let q = ref !delta in
          let k = ref base in
          let done_encoding = ref false in

          while not !done_encoding do
            let t =
              if !k <= !bias then tmin
              else if !k >= !bias + tmax then tmax
              else !k - !bias
            in
            if !q < t then begin
              (* Output final digit; it carries the case annotation *)
              let case =
                match case_flags with
                | Some flags -> flags.(j)
                | None -> Lowercase
              in
              Buffer.add_char output (encode_digit !q case);
              done_encoding := true
            end
            else begin
              (* Output intermediate digit and continue *)
              let digit = t + ((!q - t) mod (base - t)) in
              Buffer.add_char output (encode_digit digit Lowercase);
              q := (!q - t) / (base - t);
              k := !k + base
            end
          done;

          bias := adapt ~delta:!delta ~numpoints:(!h + 1) ~firsttime:(!h = b);
          delta := 0;
          incr h
        end
      done;

      incr delta;
      incr n
    done;

    Buffer.contents output
  end

(** [encode codepoints] is [encode_impl codepoints None]. *)
let encode codepoints = encode_impl codepoints None

(** [encode_with_case codepoints case_flags] encodes [codepoints] using one
    case flag per code point.
    @raise Invalid_argument if the two arrays differ in length. *)
let encode_with_case codepoints case_flags =
  if Array.length codepoints <> Array.length case_flags then
    invalid_arg "encode_with_case: array lengths must match";
  encode_impl codepoints (Some case_flags)

(* {1 Punycode Decoding (RFC 3492 Section 6.2)} *)

(** [decode_impl input] decodes the Punycode string [input] (no ACE prefix)
    into a pair of equal-length arrays: the code points and, for each one,
    the case flag read from the input character that produced it. An empty
    input yields two empty arrays.
    @raise Error with [Invalid_character] for a byte at or above 0x80 before
      the last delimiter, or when a decoded value is not a valid Unicode
      scalar value; [Invalid_digit] for a character that is not a Punycode
      digit; [Unexpected_end] when the input stops inside a delta; and
      [Overflow] when an intermediate value would overflow. *)
let decode_impl input =
  let input_length = String.length input in
  if input_length = 0 then ([||], [||])
  else begin
    (* Find last delimiter; a missing delimiter and a delimiter at index 0
       both give b = 0, i.e. no basic code points to copy. *)
    let b = Option.value ~default:0 (String.rindex_opt input delimiter) in

    (* Copy basic code points and extract case flags *)
    let output = ref [] in
    let case_output = ref [] in

    for j = 0 to b - 1 do
      let c = input.[j] in
      let pos = { byte_offset = j; char_index = j } in
      let code = Char.code c in
      if code >= 0x80 then invalid_character pos (Uchar.of_int code)
      else begin
        output := Uchar.of_int code :: !output;
        case_output :=
          (if is_flagged c then Uppercase else Lowercase) :: !case_output
      end
    done;

    let output = ref (Array.of_list (List.rev !output)) in
    let case_output = ref (Array.of_list (List.rev !case_output)) in

    (* Main decoding loop *)
    let n = ref initial_n in
    let i = ref 0 in
    let bias = ref initial_bias in
    (* Start just after the delimiter only when basic code points were
       copied; otherwise start at the beginning of the input. *)
    let in_pos = ref (if b > 0 then b + 1 else 0) in

    while !in_pos < input_length do
      let oldi = !i in
      let w = ref 1 in
      let k = ref base in
      let done_decoding = ref false in

      (* Decode one variable-length integer into i *)
      while not !done_decoding do
        let pos =
          { byte_offset = !in_pos; char_index = Array.length !output }
        in

        if !in_pos >= input_length then unexpected_end pos
        else begin
          let c = input.[!in_pos] in
          incr in_pos;

          match decode_digit c with
          | None -> invalid_digit pos c
          | Some digit ->
              (* i = i + digit * w, with overflow check *)
              i := safe_mul_add !i digit !w pos;

              let t =
                if !k <= !bias then tmin
                else if !k >= !bias + tmax then tmax
                else !k - !bias
              in

              if digit < t then
                (* Final digit of this delta *)
                done_decoding := true
              else begin
                (* w = w * (base - t), with overflow check *)
                let base_minus_t = base - t in
                if !w > max_int_value / base_minus_t then overflow pos
                else begin
                  w := !w * base_minus_t;
                  k := !k + base
                end
              end
        end
      done;

      let out_len = Array.length !output in
      bias :=
        adapt ~delta:(!i - oldi) ~numpoints:(out_len + 1)
          ~firsttime:(oldi = 0);

      let pos = { byte_offset = !in_pos - 1; char_index = out_len } in

      (* n = n + i / (out_len + 1), with overflow check *)
      let increment = !i / (out_len + 1) in
      if increment > max_int_value - !n then overflow pos
      else begin
        n := !n + increment;
        i := !i mod (out_len + 1);

        (* Validate that n is a valid Unicode scalar value *)
        if not (Uchar.is_valid !n) then invalid_character pos Uchar.rep
        else begin
          (* Insert n at position i; i is in 0..out_len after the mod *)
          let new_output = Array.make (out_len + 1) (Uchar.of_int 0) in
          let new_case = Array.make (out_len + 1) Lowercase in

          Array.blit !output 0 new_output 0 !i;
          Array.blit !case_output 0 new_case 0 !i;
          new_output.(!i) <- Uchar.of_int !n;
          (* Case flag from final digit of this delta *)
          new_case.(!i) <-
            (if !in_pos > 0 && is_flagged input.[!in_pos - 1] then Uppercase
             else Lowercase);
          Array.blit !output !i new_output (!i + 1) (out_len - !i);
          Array.blit !case_output !i new_case (!i + 1) (out_len - !i);

          output := new_output;
          case_output := new_case;
          incr i
        end
      end
    done;

    (!output, !case_output)
  end

(** [decode input] is the code point array of [decode_impl input]. *)
let decode input = fst (decode_impl input)

(** [decode_with_case input] is [decode_impl input]. *)
let decode_with_case input = decode_impl input

(* {1 UTF-8 String Operations} *)

(** [encode_utf8 s] is the Punycode encoding of the UTF-8 string [s].
    @raise Error as {!utf8_to_codepoints} and {!encode} do. *)
let encode_utf8 s =
  let codepoints = utf8_to_codepoints s in
  encode codepoints

(** [decode_utf8 punycode] decodes [punycode] and returns the result as a
    UTF-8 string.
    @raise Error as {!decode} does. *)
let decode_utf8 punycode =
  let codepoints = decode punycode in
  codepoints_to_utf8 codepoints

(* {1 Domain Label Operations} *)

(** [encode_label label] returns [label] unchanged when it is all ASCII, and
    otherwise ["xn--"] followed by its Punycode encoding.
    @raise Error with [Empty_label] when [label] is empty, and with
      [Label_too_long] when the result exceeds [max_label_length] bytes. *)
let encode_label label =
  if String.length label = 0 then empty_label ()
  else if is_ascii_string label then begin
    (* All ASCII - return as-is, but check length *)
    let len = String.length label in
    if len > max_label_length then label_too_long len else label
  end
  else begin
    (* Has non-ASCII - encode with Punycode *)
    let encoded = encode_utf8 label in
    let result = ace_prefix ^ encoded in
    let len = String.length result in
    if len > max_label_length then label_too_long len else result
  end

(** [decode_label label] strips a leading ACE prefix and decodes the rest to
    UTF-8. A label without the prefix is returned unchanged, whether or not
    it is ASCII.
    @raise Error with [Empty_label] when [label] is empty, or as
      {!decode_utf8} does. *)
let decode_label label =
  if String.length label = 0 then empty_label ()
  else if has_ace_prefix label then begin
    (* Remove ACE prefix and decode *)
    let prefix_len = String.length ace_prefix in
    decode_utf8 (String.sub label prefix_len (String.length label - prefix_len))
  end
  else label