objective categorical abstract machine language personal data server
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

Hermes, an XRPC client for atproto

futur.blue cd8ea5c4 baa45f3d

verified
+5427 -7
+10 -7
README.md
··· 54 54 55 55 This repo contains several libraries in addition to the `pegasus` PDS: 56 56 57 - | library | description | 58 - | -------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | 59 - | frontend | The PDS frontend, containing the admin dashboard and account page. | 60 - | ipld | A mostly [DASL-compliant](https://dasl.ing/) implementation of [CIDs](https://dasl.ing/cid.html), [CAR](https://dasl.ing/car.html), and [DAG-CBOR](https://dasl.ing/drisl.html). | 61 - | kleidos | An atproto-valid interface for secp256k1 and secp256r1 key management, signing/verifying, and encoding/decoding. | 62 - | mist | A [Merkle Search Tree](https://atproto.com/specs/repository#mst-structure) implementation for data repository purposes. | 63 - | pegasus | The PDS implementation. | 57 + | library | description | 58 + | ---------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | 59 + | frontend | The PDS frontend, containing the admin dashboard and account page. | 60 + | ipld | A mostly [DASL-compliant](https://dasl.ing/) implementation of [CIDs](https://dasl.ing/cid.html), [CAR](https://dasl.ing/car.html), and [DAG-CBOR](https://dasl.ing/drisl.html). | 61 + | kleidos | An atproto-valid interface for secp256k1 and secp256r1 key management, signing/verifying, and encoding/decoding. | 62 + | mist | A [Merkle Search Tree](https://atproto.com/specs/repository#mst-structure) implementation for data repository purposes. | 63 + | hermes | An XRPC client for atproto. | 64 + | hermes_ppx | A preprocessor for hermes, making API calls more ergonomic. | 65 + | hermes-cli | A CLI to generate OCaml types from atproto lexicons. | 66 + | pegasus | The PDS implementation. | 64 67 65 68 To start developing, you'll need: 66 69
+37
dune-project
··· 138 138 (multibase (>= 0.1.0)))) 139 139 140 140 (package 141 + (name hermes) 142 + (synopsis "Type-safe XRPC client for ATProto") 143 + (description "XRPC client with PPX extensions for type-safe API calls") 144 + (allow_empty) 145 + (depends 146 + (ocaml (= 5.2.1)) 147 + dune 148 + lwt 149 + (cohttp-lwt-unix (>= 6.1.1)) 150 + (uri (>= 4.4.0)) 151 + (yojson (>= 3.0.0)) 152 + (base64 (>= 3.5.0)) 153 + (lwt_ppx (>= 5.9.1)) 154 + (ppx_deriving_yojson (>= 3.9.1)))) 155 + 156 + (package 157 + (name hermes-cli) 158 + (synopsis "Code generator for Hermes from ATProto lexicons") 159 + (allow_empty) 160 + (depends 161 + (ocaml (= 5.2.1)) 162 + dune 163 + (cmdliner (>= 1.2.0)) 164 + (yojson (>= 3.0.0)) 165 + (fmt (>= 0.9.0)) 166 + (fpath (>= 0.7.3)))) 167 + 168 + (package 169 + (name hermes_ppx) 170 + (synopsis "PPX extension for Hermes XRPC calls") 171 + (allow_empty) 172 + (depends 173 + (ocaml (= 5.2.1)) 174 + dune 175 + (ppxlib (>= 0.32.0)))) 176 + 177 + (package 141 178 (name tailwindcss) 142 179 (allow_empty)) 143 180
+33
hermes-cli.opam
··· 1 + # This file is generated by dune, edit dune-project instead 2 + opam-version: "2.0" 3 + synopsis: "Code generator for Hermes from ATProto lexicons" 4 + maintainer: ["futurGH"] 5 + authors: ["futurGH"] 6 + license: "MPL-2.0" 7 + homepage: "https://github.com/futurGH/pegasus" 8 + bug-reports: "https://github.com/futurGH/pegasus/issues" 9 + depends: [ 10 + "ocaml" {= "5.2.1"} 11 + "dune" {>= "3.20"} 12 + "cmdliner" {>= "1.2.0"} 13 + "yojson" {>= "3.0.0"} 14 + "fmt" {>= "0.9.0"} 15 + "fpath" {>= "0.7.3"} 16 + "odoc" {with-doc} 17 + ] 18 + build: [ 19 + ["dune" "subst"] {dev} 20 + [ 21 + "dune" 22 + "build" 23 + "-p" 24 + name 25 + "-j" 26 + jobs 27 + "@install" 28 + "@runtest" {with-test} 29 + "@doc" {with-doc} 30 + ] 31 + ] 32 + dev-repo: "git+https://github.com/futurGH/pegasus.git" 33 + x-maintenance-intent: ["(latest)"]
+5
hermes-cli/bin/dune
··· 1 + (executable 2 + (name main) 3 + (public_name hermes-cli) 4 + (package hermes-cli) 5 + (libraries hermes_cli cmdliner))
+119
hermes-cli/bin/main.ml
··· 1 + open Hermes_cli 2 + 3 + (* recursively find all json files in a directory *) 4 + let find_json_files dir = 5 + let rec aux acc path = 6 + if Sys.is_directory path then 7 + Sys.readdir path |> Array.to_list 8 + |> List.map (Filename.concat path) 9 + |> List.fold_left aux acc 10 + else if Filename.check_suffix path ".json" then path :: acc 11 + else acc 12 + in 13 + aux [] dir 14 + 15 + (* generate module structure from lexicons *) 16 + let generate ~input_dir ~output_dir ~module_name = 17 + (* create output directory *) 18 + if not (Sys.file_exists output_dir) then Sys.mkdir output_dir 0o755 ; 19 + (* find all lexicon files *) 20 + let files = find_json_files input_dir in 21 + Printf.printf "Found %d lexicon files\n" (List.length files) ; 22 + (* parse all files *) 23 + let lexicons = 24 + List.filter_map 25 + (fun path -> 26 + match Parser.parse_file path with 27 + | Ok doc -> 28 + Printf.printf " Parsed: %s\n" doc.Lexicon_types.id ; 29 + Some doc 30 + | Error e -> 31 + Printf.eprintf " Error parsing %s: %s\n" path e ; 32 + None ) 33 + files 34 + in 35 + Printf.printf "Successfully parsed %d lexicons\n" (List.length lexicons) ; 36 + (* group by namespace, all but last segment *) 37 + let by_namespace = Hashtbl.create 64 in 38 + List.iter 39 + (fun doc -> 40 + let segments = String.split_on_char '.' doc.Lexicon_types.id in 41 + match List.rev segments with 42 + | _last :: rest -> 43 + let ns = String.concat "." (List.rev rest) in 44 + let existing = 45 + try Hashtbl.find by_namespace ns with Not_found -> [] 46 + in 47 + Hashtbl.replace by_namespace ns (doc :: existing) 48 + | [] -> 49 + () ) 50 + lexicons ; 51 + (* generate file for each lexicon *) 52 + List.iter 53 + (fun doc -> 54 + let code = Codegen.gen_lexicon_module doc in 55 + let rel_path = Naming.file_path_of_nsid doc.Lexicon_types.id in 56 + let full_path = Filename.concat output_dir rel_path in 57 + (* write file *) 58 + let oc = open_out full_path in 59 + output_string oc code ; 60 + close_out oc ; 61 + Printf.printf " Generated: %s\n" rel_path ) 62 + lexicons ; 63 + (* generate index file *) 64 + let index_path = 65 + Filename.concat output_dir (String.lowercase_ascii module_name ^ ".ml") 66 + in 67 + let oc = open_out index_path in 68 + Printf.fprintf oc "(* %s - generated from atproto lexicons *)\n\n" module_name ; 69 + (* export each lexicon as a module alias *) 70 + List.iter 71 + (fun doc -> 72 + let flat_module = Naming.flat_module_name_of_nsid doc.Lexicon_types.id in 73 + Printf.fprintf oc "module %s = %s\n" flat_module flat_module ) 74 + lexicons ; 75 + close_out oc ; 76 + Printf.printf "Generated index: %s\n" index_path ; 77 + (* generate dune file *) 78 + let dune_path = Filename.concat output_dir "dune" in 79 + let oc = open_out dune_path in 80 + Printf.fprintf oc "(library\n" ; 81 + Printf.fprintf oc " (name %s)\n" (String.lowercase_ascii module_name) ; 82 + Printf.fprintf oc " (libraries hermes yojson lwt)\n" ; 83 + Printf.fprintf oc " (preprocess (pps ppx_deriving_yojson)))\n" ; 84 + close_out oc ; 85 + Printf.printf "Generated dune file\n" ; 86 + Printf.printf "Done! Generated %d modules\n" (List.length lexicons) 87 + 88 + let input_dir = 89 + let doc = "directory containing lexicon JSON files" in 90 + Cmdliner.Arg.( 91 + required & opt (some dir) None & info ["i"; "input"] ~docv:"DIR" ~doc ) 92 + 93 + let output_dir = 94 + let doc = "output directory for generated code" in 95 + Cmdliner.Arg.( 96 + required & opt (some string) None & info ["o"; "output"] ~docv:"DIR" ~doc ) 97 + 98 + let module_name = 99 + let doc = "name of the generated module" in 100 + Cmdliner.Arg.( 101 + value 102 + & opt string "Hermes_lexicons" 103 + & info ["m"; "module-name"] ~docv:"NAME" ~doc ) 104 + 105 + let generate_cmd = 106 + let doc = "generate ocaml types from atproto lexicons" in 107 + let info = Cmdliner.Cmd.info "generate" ~doc in 108 + let generate' input_dir output_dir module_name = 109 + generate ~input_dir ~output_dir ~module_name 110 + in 111 + Cmdliner.Cmd.v info 112 + Cmdliner.Term.(const generate' $ input_dir $ output_dir $ module_name) 113 + 114 + let main_cmd = 115 + let doc = "hermes - atproto lexicon code generator" in 116 + let info = Cmdliner.Cmd.info "hermes-cli" ~version:"0.1.0" ~doc in 117 + Cmdliner.Cmd.group info [generate_cmd] 118 + 119 + let () = exit (Cmdliner.Cmd.eval main_cmd)
+826
hermes-cli/lib/codegen.ml
··· 1 + open Lexicon_types 2 + 3 + type output = 4 + { mutable imports: string list 5 + ; mutable generated_unions: string list 6 + ; buf: Buffer.t } 7 + 8 + let make_output () = {imports= []; generated_unions= []; buf= Buffer.create 4096} 9 + 10 + let add_import out module_name = 11 + if not (List.mem module_name out.imports) then 12 + out.imports <- module_name :: out.imports 13 + 14 + let mark_union_generated out union_name = 15 + if not (List.mem union_name out.generated_unions) then 16 + out.generated_unions <- union_name :: out.generated_unions 17 + 18 + let is_union_generated out union_name = List.mem union_name out.generated_unions 19 + 20 + let emit out s = Buffer.add_string out.buf s 21 + 22 + let emitln out s = 23 + Buffer.add_string out.buf s ; 24 + Buffer.add_char out.buf '\n' 25 + 26 + let emit_newline out = Buffer.add_char out.buf '\n' 27 + 28 + (* generate ocaml type for a primitive type *) 29 + let rec gen_type_ref nsid out (type_def : type_def) : string = 30 + match type_def with 31 + | String _ -> 32 + "string" 33 + | Integer {maximum; _} -> ( 34 + (* use int64 for large integers *) 35 + match maximum with 36 + | Some m when m > 1073741823 -> 37 + "int64" 38 + | _ -> 39 + "int" ) 40 + | Boolean _ -> 41 + "bool" 42 + | Bytes _ -> 43 + "bytes" 44 + | Blob _ -> 45 + "Hermes.blob" 46 + | CidLink _ -> 47 + "Cid.t" 48 + | Array {items; _} -> 49 + let item_type = gen_type_ref nsid out items in 50 + item_type ^ " list" 51 + | Object _ -> 52 + (* objects should be defined separately *) 53 + "object_todo" 54 + | Ref {ref_; _} -> 55 + gen_ref_type nsid out ref_ 56 + | Union {refs; _} -> 57 + (* generate inline union reference *) 58 + gen_union_type_name refs 59 + | Token _ -> 60 + "string" 61 + | Unknown _ -> 62 + "Yojson.Safe.t" 63 + | Query _ | Procedure _ | Subscription _ | Record _ -> 64 + "unit (* primary type *)" 65 + 66 + (* generate reference to another type *) 67 + and gen_ref_type _nsid out ref_str : string = 68 + if String.length ref_str > 0 && ref_str.[0] = '#' then begin 69 + (* local ref: #someDef -> someDef *) 70 + let def_name = String.sub ref_str 1 (String.length ref_str - 1) in 71 + Naming.type_name def_name 72 + end 73 + else begin 74 + (* external ref: com.example.defs#someDef *) 75 + match String.split_on_char '#' ref_str with 76 + | [ext_nsid; def_name] -> 77 + (* use flat module names for include_subdirs unqualified *) 78 + let flat_module = Naming.flat_module_name_of_nsid ext_nsid in 79 + add_import out flat_module ; 80 + flat_module ^ "." ^ Naming.type_name def_name 81 + | [ext_nsid] -> 82 + (* just nsid, refers to main def *) 83 + let flat_module = Naming.flat_module_name_of_nsid ext_nsid in 84 + add_import out flat_module ; flat_module ^ ".main" 85 + | _ -> 86 + "invalid_ref" 87 + end 88 + 89 + and gen_union_type_name refs = Naming.union_type_name refs 90 + 91 + (* generate full type uri for a ref *) 92 + let gen_type_uri nsid ref_str = 93 + if String.length ref_str > 0 && ref_str.[0] = '#' then 94 + (* local ref *) 95 + nsid ^ ref_str 96 + else 97 + (* external ref, use as-is *) 98 + ref_str 99 + 100 + (* collect inline union specs from object properties *) 101 + let rec collect_inline_unions acc type_def = 102 + match type_def with 103 + | Union spec -> 104 + (spec.refs, spec) :: acc 105 + | Array {items; _} -> 106 + collect_inline_unions acc items 107 + | _ -> 108 + acc 109 + 110 + let collect_inline_unions_from_properties properties = 111 + List.fold_left 112 + (fun acc (_, (prop : property)) -> collect_inline_unions acc prop.type_def) 113 + [] properties 114 + 115 + (* generate inline union types that appear in object properties *) 116 + let gen_inline_unions nsid out properties = 117 + let inline_unions = collect_inline_unions_from_properties properties in 118 + List.iter 119 + (fun (refs, spec) -> 120 + let type_name = Naming.union_type_name refs in 121 + (* skip if already generated *) 122 + if not (is_union_generated out type_name) then begin 123 + mark_union_generated out type_name ; 124 + let is_closed = Option.value spec.closed ~default:false in 125 + emitln out (Printf.sprintf "type %s =" type_name) ; 126 + List.iter 127 + (fun ref_str -> 128 + let variant_name = Naming.variant_name_of_ref ref_str in 129 + let payload_type = gen_ref_type nsid out ref_str in 130 + emitln out (Printf.sprintf " | %s of %s" variant_name payload_type) ) 131 + refs ; 132 + if not is_closed then emitln out " | Unknown of Yojson.Safe.t" ; 133 + emit_newline out ; 134 + (* generate of_yojson function *) 135 + emitln out (Printf.sprintf "let %s_of_yojson json =" type_name) ; 136 + emitln out " let open Yojson.Safe.Util in" ; 137 + emitln out " try" ; 138 + emitln out " match json |> member \"$type\" |> to_string with" ; 139 + List.iter 140 + (fun ref_str -> 141 + let variant_name = Naming.variant_name_of_ref ref_str in 142 + let full_type_uri = gen_type_uri nsid ref_str in 143 + let payload_type = gen_ref_type nsid out ref_str in 144 + emitln out (Printf.sprintf " | \"%s\" ->" full_type_uri) ; 145 + emitln out 146 + (Printf.sprintf " (match %s_of_yojson json with" 147 + payload_type ) ; 148 + emitln out 149 + (Printf.sprintf " | Ok v -> Ok (%s v)" variant_name) ; 150 + emitln out " | Error e -> Error e)" ) 151 + refs ; 152 + if is_closed then 153 + emitln out " | t -> Error (\"unknown union type: \" ^ t)" 154 + else emitln out " | _ -> Ok (Unknown json)" ; 155 + emitln out " with _ -> Error \"failed to parse union\"" ; 156 + emit_newline out ; 157 + (* generate to_yojson function *) 158 + emitln out (Printf.sprintf "let %s_to_yojson = function" type_name) ; 159 + List.iter 160 + (fun ref_str -> 161 + let variant_name = Naming.variant_name_of_ref ref_str in 162 + let payload_type = gen_ref_type nsid out ref_str in 163 + emitln out 164 + (Printf.sprintf " | %s v -> %s_to_yojson v" variant_name 165 + payload_type ) ) 166 + refs ; 167 + if not is_closed then emitln out " | Unknown j -> j" ; 168 + emit_newline out 169 + end ) 170 + inline_unions 171 + 172 + (* generate object type definition *) 173 + let gen_object_type nsid out name (spec : object_spec) = 174 + let required = Option.value spec.required ~default:[] in 175 + let nullable = Option.value spec.nullable ~default:[] in 176 + (* handle empty objects as unit *) 177 + if spec.properties = [] then begin 178 + emitln out (Printf.sprintf "type %s = unit" (Naming.type_name name)) ; 179 + emitln out 180 + (Printf.sprintf "let %s_of_yojson _ = Ok ()" (Naming.type_name name)) ; 181 + emitln out 182 + (Printf.sprintf "let %s_to_yojson () = `Assoc []" (Naming.type_name name)) ; 183 + emit_newline out 184 + end 185 + else begin 186 + (* generate inline union types first *) 187 + gen_inline_unions nsid out spec.properties ; 188 + emitln out (Printf.sprintf "type %s =" (Naming.type_name name)) ; 189 + emitln out " {" ; 190 + List.iter 191 + (fun (prop_name, (prop : property)) -> 192 + let ocaml_name = Naming.field_name prop_name in 193 + let base_type = gen_type_ref nsid out prop.type_def in 194 + let is_required = List.mem prop_name required in 195 + let is_nullable = List.mem prop_name nullable in 196 + let type_str = 197 + if is_required && not is_nullable then base_type 198 + else base_type ^ " option" 199 + in 200 + let key_attr = Naming.key_annotation prop_name ocaml_name in 201 + let default_attr = 202 + if is_required && not is_nullable then "" else " [@default None]" 203 + in 204 + emitln out 205 + (Printf.sprintf " %s: %s%s%s;" ocaml_name type_str key_attr 206 + default_attr ) ) 207 + spec.properties ; 208 + emitln out " }" ; 209 + emitln out "[@@deriving yojson {strict= false}]" ; 210 + emit_newline out 211 + end 212 + 213 + (* generate union type definition *) 214 + let gen_union_type nsid out name (spec : union_spec) = 215 + let type_name = Naming.type_name name in 216 + let is_closed = Option.value spec.closed ~default:false in 217 + emitln out (Printf.sprintf "type %s =" type_name) ; 218 + List.iter 219 + (fun ref_str -> 220 + let variant_name = Naming.variant_name_of_ref ref_str in 221 + let payload_type = gen_ref_type nsid out ref_str in 222 + emitln out (Printf.sprintf " | %s of %s" variant_name payload_type) ) 223 + spec.refs ; 224 + if not is_closed then emitln out " | Unknown of Yojson.Safe.t" ; 225 + emit_newline out ; 226 + (* generate of_yojson function *) 227 + emitln out (Printf.sprintf "let %s_of_yojson json =" type_name) ; 228 + emitln out " let open Yojson.Safe.Util in" ; 229 + emitln out " try" ; 230 + emitln out " match json |> member \"$type\" |> to_string with" ; 231 + List.iter 232 + (fun ref_str -> 233 + let variant_name = Naming.variant_name_of_ref ref_str in 234 + let full_type_uri = gen_type_uri nsid ref_str in 235 + let payload_type = gen_ref_type nsid out ref_str in 236 + emitln out (Printf.sprintf " | \"%s\" ->" full_type_uri) ; 237 + emitln out 238 + (Printf.sprintf " (match %s_of_yojson json with" payload_type) ; 239 + emitln out (Printf.sprintf " | Ok v -> Ok (%s v)" variant_name) ; 240 + emitln out " | Error e -> Error e)" ) 241 + spec.refs ; 242 + if is_closed then emitln out " | t -> Error (\"unknown union type: \" ^ t)" 243 + else emitln out " | _ -> Ok (Unknown json)" ; 244 + emitln out " with _ -> Error \"failed to parse union\"" ; 245 + emit_newline out ; 246 + (* generate to_yojson function *) 247 + emitln out (Printf.sprintf "let %s_to_yojson = function" type_name) ; 248 + List.iter 249 + (fun ref_str -> 250 + let variant_name = Naming.variant_name_of_ref ref_str in 251 + let payload_type = gen_ref_type nsid out ref_str in 252 + emitln out 253 + (Printf.sprintf " | %s v -> %s_to_yojson v" variant_name payload_type) ) 254 + spec.refs ; 255 + if not is_closed then emitln out " | Unknown j -> j" ; 256 + emit_newline out 257 + 258 + let is_json_encoding encoding = encoding = "application/json" || encoding = "" 259 + 260 + let is_bytes_encoding encoding = 261 + encoding <> "" && encoding <> "application/json" 262 + 263 + (* generate params type for query/procedure *) 264 + let gen_params_type nsid out (spec : params_spec) = 265 + let required = Option.value spec.required ~default:[] in 266 + emitln out "type params =" ; 267 + emitln out " {" ; 268 + List.iter 269 + (fun (prop_name, (prop : property)) -> 270 + let ocaml_name = Naming.field_name prop_name in 271 + let base_type = gen_type_ref nsid out prop.type_def in 272 + let is_required = List.mem prop_name required in 273 + let type_str = if is_required then base_type else base_type ^ " option" in 274 + let key_attr = Naming.key_annotation prop_name ocaml_name in 275 + let default_attr = if is_required then "" else " [@default None]" in 276 + emitln out 277 + (Printf.sprintf " %s: %s%s%s;" ocaml_name type_str key_attr 278 + default_attr ) ) 279 + spec.properties ; 280 + emitln out " }" ; 281 + emitln out "[@@deriving yojson {strict= false}]" ; 282 + emit_newline out 283 + 284 + (* generate output type for query/procedure *) 285 + let gen_output_type nsid out (body : body_def) = 286 + match body.schema with 287 + | Some (Object spec) -> 288 + (* handle empty objects as unit *) 289 + if spec.properties = [] then begin 290 + emitln out "type output = unit" ; 291 + emitln out "let output_of_yojson _ = Ok ()" ; 292 + emitln out "let output_to_yojson () = `Assoc []" ; 293 + emit_newline out 294 + end 295 + else begin 296 + (* generate inline union types first *) 297 + gen_inline_unions nsid out spec.properties ; 298 + let required = Option.value spec.required ~default:[] in 299 + let nullable = Option.value spec.nullable ~default:[] in 300 + emitln out "type output =" ; 301 + emitln out " {" ; 302 + List.iter 303 + (fun (prop_name, (prop : property)) -> 304 + let ocaml_name = Naming.field_name prop_name in 305 + let base_type = gen_type_ref nsid out prop.type_def in 306 + let is_required = List.mem prop_name required in 307 + let is_nullable = List.mem prop_name nullable in 308 + let type_str = 309 + if is_required && not is_nullable then base_type 310 + else base_type ^ " option" 311 + in 312 + let key_attr = Naming.key_annotation prop_name ocaml_name in 313 + let default_attr = 314 + if is_required && not is_nullable then "" else " [@default None]" 315 + in 316 + emitln out 317 + (Printf.sprintf " %s: %s%s%s;" ocaml_name type_str key_attr 318 + default_attr ) ) 319 + spec.properties ; 320 + emitln out " }" ; 321 + emitln out "[@@deriving yojson {strict= false}]" ; 322 + emit_newline out 323 + end 324 + | Some other_type -> 325 + let type_str = gen_type_ref nsid out other_type in 326 + emitln out (Printf.sprintf "type output = %s" type_str) ; 327 + emitln out "[@@deriving yojson {strict= false}]" ; 328 + emit_newline out 329 + | None -> 330 + emitln out "type output = unit" ; 331 + emitln out "let output_of_yojson _ = Ok ()" ; 332 + emitln out "let output_to_yojson () = `Null" ; 333 + emit_newline out 334 + 335 + (* generate query module *) 336 + let gen_query nsid out name (spec : query_spec) = 337 + (* check if output is bytes *) 338 + let output_is_bytes = 339 + match spec.output with 340 + | Some body -> 341 + is_bytes_encoding body.encoding 342 + | None -> 343 + false 344 + in 345 + emitln out 346 + (Printf.sprintf "(** %s *)" (Option.value spec.description ~default:name)) ; 347 + emitln out (Printf.sprintf "module %s = struct" (Naming.def_module_name name)) ; 348 + emitln out (Printf.sprintf " let nsid = \"%s\"" nsid) ; 349 + emit_newline out ; 350 + (* generate params type *) 351 + ( match spec.parameters with 352 + | Some params when params.properties <> [] -> 353 + emit out " " ; 354 + gen_params_type nsid out params 355 + | _ -> 356 + emitln out " type params = unit" ; 357 + emitln out " let params_to_yojson () = `Assoc []" ; 358 + emit_newline out ) ; 359 + (* generate output type *) 360 + ( if output_is_bytes then begin 361 + emitln out " (** Raw bytes output with content type *)" ; 362 + emitln out " type output = string * string" ; 363 + emit_newline out 364 + end 365 + else 366 + match spec.output with 367 + | Some body -> 368 + emit out " " ; 369 + gen_output_type nsid out body 370 + | None -> 371 + emitln out " type output = unit" ; 372 + emitln out " let output_of_yojson _ = Ok ()" ; 373 + emit_newline out ) ; 374 + (* generate call function *) 375 + emitln out " let call" ; 376 + ( match spec.parameters with 377 + | Some params when params.properties <> [] -> 378 + let required = Option.value params.required ~default:[] in 379 + List.iter 380 + (fun (prop_name, _) -> 381 + let ocaml_name = Naming.field_name prop_name in 382 + let is_required = List.mem prop_name required in 383 + if is_required then emitln out (Printf.sprintf " ~%s" ocaml_name) 384 + else emitln out (Printf.sprintf " ?%s" ocaml_name) ) 385 + params.properties 386 + | _ -> 387 + () ) ; 388 + emitln out " (client : Hermes.client) : output Lwt.t =" ; 389 + ( match spec.parameters with 390 + | Some params when params.properties <> [] -> 391 + emit out " let params : params = {" ; 392 + let fields = 393 + List.map 394 + (fun (prop_name, _) -> Naming.field_name prop_name) 395 + params.properties 396 + in 397 + emit out (String.concat "; " fields) ; 398 + emitln out "} in" ; 399 + if output_is_bytes then 400 + emitln out 401 + " Hermes.query_bytes client nsid (params_to_yojson params)" 402 + else 403 + emitln out 404 + " Hermes.query client nsid (params_to_yojson params) \ 405 + output_of_yojson" 406 + | _ -> 407 + if output_is_bytes then 408 + emitln out " Hermes.query_bytes client nsid (`Assoc [])" 409 + else 410 + emitln out " Hermes.query client nsid (`Assoc []) output_of_yojson" 411 + ) ; 412 + emitln out "end" ; emit_newline out 413 + 414 + (* generate procedure module *) 415 + let gen_procedure nsid out name (spec : procedure_spec) = 416 + (* check if input/output are bytes *) 417 + let input_is_bytes = 418 + match spec.input with 419 + | Some body -> 420 + is_bytes_encoding body.encoding 421 + | None -> 422 + false 423 + in 424 + let output_is_bytes = 425 + match spec.output with 426 + | Some body -> 427 + is_bytes_encoding body.encoding 428 + | None -> 429 + false 430 + in 431 + let input_content_type = 432 + match spec.input with 433 + | Some body when is_bytes_encoding body.encoding -> 434 + body.encoding 435 + | _ -> 436 + "application/json" 437 + in 438 + emitln out 439 + (Printf.sprintf "(** %s *)" (Option.value spec.description ~default:name)) ; 440 + emitln out (Printf.sprintf "module %s = struct" (Naming.def_module_name name)) ; 441 + emitln out (Printf.sprintf " let nsid = \"%s\"" nsid) ; 442 + emit_newline out ; 443 + (* generate params type *) 444 + ( match spec.parameters with 445 + | Some params when params.properties <> [] -> 446 + emit out " " ; 447 + gen_params_type nsid out params 448 + | _ -> 449 + emitln out " type params = unit" ; 450 + emitln out " let params_to_yojson () = `Assoc []" ; 451 + emit_newline out ) ; 452 + (* generate input type; only for json input with schema *) 453 + ( if not input_is_bytes then 454 + match spec.input with 455 + | Some body when body.schema <> None -> 456 + emit out " " ; 457 + ( match body.schema with 458 + | Some (Object spec) -> 459 + (* generate inline union types first *) 460 + gen_inline_unions nsid out spec.properties ; 461 + let required = Option.value spec.required ~default:[] in 462 + emitln out "type input =" ; 463 + emitln out " {" ; 464 + List.iter 465 + (fun (prop_name, (prop : property)) -> 466 + let ocaml_name = Naming.field_name prop_name in 467 + let base_type = gen_type_ref nsid out prop.type_def in 468 + let is_required = List.mem prop_name required in 469 + let type_str = 470 + if is_required then base_type else base_type ^ " option" 471 + in 472 + let key_attr = Naming.key_annotation prop_name ocaml_name in 473 + let default_attr = 474 + if is_required then "" else " [@default None]" 475 + in 476 + emitln out 477 + (Printf.sprintf " %s: %s%s%s;" ocaml_name type_str 478 + key_attr default_attr ) ) 479 + spec.properties ; 480 + emitln out " }" ; 481 + emitln out " [@@deriving yojson {strict= false}]" 482 + | Some other_type -> 483 + emitln out 484 + (Printf.sprintf "type input = %s" 485 + (gen_type_ref nsid out other_type) ) ; 486 + emitln out " [@@deriving yojson {strict= false}]" 487 + | None -> 488 + () ) ; 489 + emit_newline out 490 + | _ -> 491 + () ) ; 492 + (* generate output type *) 493 + ( if output_is_bytes then begin 494 + emitln out " (** Raw bytes output with content type *)" ; 495 + emitln out " type output = (string * string) option" ; 496 + emit_newline out 497 + end 498 + else 499 + match spec.output with 500 + | Some body -> 501 + emit out " " ; 502 + gen_output_type nsid out body 503 + | None -> 504 + emitln out " type output = unit" ; 505 + emitln out " let output_of_yojson _ = Ok ()" ; 506 + emit_newline out ) ; 507 + (* generate call function *) 508 + emitln out " let call" ; 509 + (* add labeled arguments for parameters *) 510 + ( match spec.parameters with 511 + | Some params when params.properties <> [] -> 512 + let required = Option.value params.required ~default:[] in 513 + List.iter 514 + (fun (prop_name, _) -> 515 + let ocaml_name = Naming.field_name prop_name in 516 + let is_required = List.mem prop_name required in 517 + if is_required then emitln out (Printf.sprintf " ~%s" ocaml_name) 518 + else emitln out (Printf.sprintf " ?%s" ocaml_name) ) 519 + params.properties 520 + | _ -> 521 + () ) ; 522 + (* add labeled arguments for input *) 523 + ( if input_is_bytes then 524 + (* for bytes input, take raw string *) 525 + emitln out " ?input" 526 + else 527 + match spec.input with 528 + | Some body -> ( 529 + match body.schema with 530 + | Some (Object obj_spec) -> 531 + let required = Option.value obj_spec.required ~default:[] in 532 + List.iter 533 + (fun (prop_name, _) -> 534 + let ocaml_name = Naming.field_name prop_name in 535 + let is_required = List.mem prop_name required in 536 + if is_required then 537 + emitln out (Printf.sprintf " ~%s" ocaml_name) 538 + else emitln out (Printf.sprintf " ?%s" ocaml_name) ) 539 + obj_spec.properties 540 + | Some _ -> 541 + (* non-object input, take as single argument *) 542 + emitln out " ~input" 543 + | None -> 544 + () ) 545 + | None -> 546 + () ) ; 547 + emitln out " (client : Hermes.client) : output Lwt.t =" ; 548 + (* build params record *) 549 + ( match spec.parameters with 550 + | Some params when params.properties <> [] -> 551 + emit out " let params = {" ; 552 + let fields = 553 + List.map 554 + (fun (prop_name, _) -> Naming.field_name prop_name) 555 + params.properties 556 + in 557 + emit out (String.concat "; " fields) ; 558 + emitln out "} in" 559 + | _ -> 560 + emitln out " let params = () in" ) ; 561 + (* generate the call based on input/output types *) 562 + if input_is_bytes then begin 563 + (* bytes input - choose between procedure_blob and procedure_bytes *) 564 + if output_is_bytes then 565 + (* bytes-in, bytes-out: use procedure_bytes *) 566 + emitln out 567 + (Printf.sprintf 568 + " Hermes.procedure_bytes client nsid (params_to_yojson params) \ 569 + input ~content_type:\"%s\"" 570 + input_content_type ) 571 + else if spec.output = None then 572 + (* bytes-in, no output: use procedure_bytes and map to unit *) 573 + emitln out 574 + (Printf.sprintf 575 + " let open Lwt.Syntax in\n\ 576 + \ let* _ = Hermes.procedure_bytes client nsid (params_to_yojson \ 577 + params) input ~content_type:\"%s\" in\n\ 578 + \ Lwt.return ()" 579 + input_content_type ) 580 + else 581 + (* bytes-in, json-out: use procedure_blob *) 582 + emitln out 583 + (Printf.sprintf 584 + " Hermes.procedure_blob client nsid (params_to_yojson params) \ 585 + (Bytes.of_string (Option.value input ~default:\"\")) \ 586 + ~content_type:\"%s\" output_of_yojson" 587 + input_content_type ) 588 + end 589 + else begin 590 + (* json input - build input and use procedure *) 591 + ( match spec.input with 592 + | Some body -> ( 593 + match body.schema with 594 + | Some (Object obj_spec) -> 595 + emit out " let input = Some ({" ; 596 + let fields = 597 + List.map 598 + (fun (prop_name, _) -> Naming.field_name prop_name) 599 + obj_spec.properties 600 + in 601 + emit out (String.concat "; " fields) ; 602 + emitln out "} |> input_to_yojson) in" 603 + | Some _ -> 604 + emitln out " let input = Some (input_to_yojson input) in" 605 + | None -> 606 + emitln out " let input = None in" ) 607 + | None -> 608 + emitln out " let input = None in" ) ; 609 + emitln out 610 + " Hermes.procedure client nsid (params_to_yojson params) input \ 611 + output_of_yojson" 612 + end ; 613 + emitln out "end" ; 614 + emit_newline out 615 + 616 + (* generate token constant *) 617 + let gen_token nsid out name (spec : token_spec) = 618 + let full_uri = nsid ^ "#" ^ name in 619 + emitln out 620 + (Printf.sprintf "(** %s *)" (Option.value spec.description ~default:name)) ; 621 + emitln out (Printf.sprintf "let %s = \"%s\"" (Naming.type_name name) full_uri) ; 622 + emit_newline out 623 + 624 + (* generate string type alias (for strings with knownValues) *) 625 + let gen_string_type _nsid out name (spec : string_spec) = 626 + let type_name = Naming.type_name name in 627 + emitln out 628 + (Printf.sprintf "(** String type with known values%s *)" 629 + (match spec.description with Some d -> ": " ^ d | None -> "") ) ; 630 + emitln out (Printf.sprintf "type %s = string" type_name) ; 631 + emitln out (Printf.sprintf "let %s_of_yojson = function" type_name) ; 632 + emitln out " | `String s -> Ok s" ; 633 + emitln out (Printf.sprintf " | _ -> Error \"%s: expected string\"" type_name) ; 634 + emitln out (Printf.sprintf "let %s_to_yojson s = `String s" type_name) ; 635 + emit_newline out 636 + 637 + (* collect local refs from a type definition *) 638 + let rec collect_local_refs acc = function 639 + | Array {items; _} -> 640 + collect_local_refs acc items 641 + | Ref {ref_; _} -> 642 + if String.length ref_ > 0 && ref_.[0] = '#' then 643 + let def_name = String.sub ref_ 1 (String.length ref_ - 1) in 644 + def_name :: acc 645 + else acc 646 + | Union {refs; _} -> 647 + List.fold_left 648 + (fun a r -> 649 + if String.length r > 0 && r.[0] = '#' then 650 + let def_name = String.sub r 1 (String.length r - 1) in 651 + def_name :: a 652 + else a ) 653 + acc refs 654 + | Object {properties; _} -> 655 + List.fold_left 656 + (fun a (_, (prop : property)) -> collect_local_refs a prop.type_def) 657 + acc properties 658 + | Record {record; _} -> 659 + List.fold_left 660 + (fun a (_, (prop : property)) -> collect_local_refs a prop.type_def) 661 + acc record.properties 662 + | Query {parameters; output; _} -> ( 663 + let acc = 664 + match parameters with 665 + | Some params -> 666 + List.fold_left 667 + (fun a (_, (prop : property)) -> 668 + collect_local_refs a prop.type_def ) 669 + acc params.properties 670 + | None -> 671 + acc 672 + in 673 + match output with 674 + | Some body -> 675 + Option.fold ~none:acc ~some:(collect_local_refs acc) body.schema 676 + | None -> 677 + acc ) 678 + | Procedure {parameters; input; output; _} -> ( 679 + let acc = 680 + match parameters with 681 + | Some params -> 682 + List.fold_left 683 + (fun a (_, (prop : property)) -> 684 + collect_local_refs a prop.type_def ) 685 + acc params.properties 686 + | None -> 687 + acc 688 + in 689 + let acc = 690 + match input with 691 + | Some body -> 692 + Option.fold ~none:acc ~some:(collect_local_refs acc) body.schema 693 + | None -> 694 + acc 695 + in 696 + match output with 697 + | Some body -> 698 + Option.fold ~none:acc ~some:(collect_local_refs acc) body.schema 699 + | None -> 700 + acc ) 701 + | _ -> 702 + acc 703 + 704 + (* sort definitions so dependencies come first *) 705 + let sort_definitions (defs : def_entry list) : def_entry list = 706 + (* build dependency map: name -> list of dependencies *) 707 + let deps = 708 + List.map (fun def -> (def.name, collect_local_refs [] def.type_def)) defs 709 + in 710 + (* create name -> def map *) 711 + let def_map = List.fold_left (fun m def -> (def.name, def) :: m) [] defs in 712 + (* topological sort *) 713 + let rec visit visited sorted name = 714 + if List.mem name visited then (visited, sorted) 715 + else 716 + let visited = name :: visited in 717 + let dep_names = try List.assoc name deps with Not_found -> [] in 718 + let visited, sorted = 719 + List.fold_left (fun (v, s) d -> visit v s d) (visited, sorted) dep_names 720 + in 721 + let sorted = 722 + match List.assoc_opt name def_map with 723 + | Some def -> 724 + def :: sorted 725 + | None -> 726 + sorted 727 + in 728 + (visited, sorted) 729 + in 730 + let _, sorted = 731 + List.fold_left (fun (v, s) def -> visit v s def.name) ([], []) defs 732 + in 733 + (* sorted is in reverse order, reverse it *) 734 + List.rev sorted 735 + 736 + (* generate complete lexicon module *) 737 + let gen_lexicon_module (doc : lexicon_doc) : string = 738 + let out = make_output () in 739 + let nsid = doc.id in 740 + (* header *) 741 + emitln out (Printf.sprintf "(* generated from %s *)" nsid) ; 742 + emit_newline out ; 743 + (* sort definitions by dependencies *) 744 + let sorted_defs = sort_definitions doc.defs in 745 + (* generate each definition *) 746 + List.iter 747 + (fun def -> 748 + match def.type_def with 749 + | Object spec -> 750 + gen_object_type nsid out def.name spec 751 + | Union spec -> 752 + gen_union_type nsid out def.name spec 753 + | Token spec -> 754 + gen_token nsid out def.name spec 755 + | Query spec -> 756 + gen_query nsid out def.name spec 757 + | Procedure spec -> 758 + gen_procedure nsid out def.name spec 759 + | Record spec -> 760 + (* generate record as object type *) 761 + gen_object_type nsid out def.name spec.record 762 + | String spec when spec.known_values <> None -> 763 + (* generate type alias for strings with known values *) 764 + gen_string_type nsid out def.name spec 765 + | String _ 766 + | Integer _ 767 + | Boolean _ 768 + | Bytes _ 769 + | Blob _ 770 + | CidLink _ 771 + | Array _ 772 + | Ref _ 773 + | Unknown _ 774 + | Subscription _ -> 775 + (* these are typically not standalone definitions *) 776 + () ) 777 + sorted_defs ; 778 + Buffer.contents out.buf 779 + 780 + (* get all imports needed for a lexicon *) 781 + let get_imports (doc : lexicon_doc) : string list = 782 + let out = make_output () in 783 + let nsid = doc.id in 784 + (* traverse all definitions to collect imports *) 785 + let rec collect_from_type = function 786 + | Array {items; _} -> 787 + collect_from_type items 788 + | Ref {ref_; _} -> 789 + let _ = gen_ref_type nsid out ref_ in 790 + () 791 + | Union {refs; _} -> 792 + List.iter 793 + (fun r -> 794 + let _ = gen_ref_type nsid out r in 795 + () ) 796 + refs 797 + | Object {properties; _} -> 798 + List.iter 799 + (fun (_, (prop : property)) -> collect_from_type prop.type_def) 800 + properties 801 + | Query {parameters; output; _} -> 802 + Option.iter 803 + (fun p -> 804 + List.iter 805 + (fun (_, (prop : property)) -> collect_from_type prop.type_def) 806 + p.properties ) 807 + parameters ; 808 + Option.iter (fun o -> Option.iter collect_from_type o.schema) output 809 + | Procedure {parameters; input; output; _} -> 810 + Option.iter 811 + (fun p -> 812 + List.iter 813 + (fun (_, (prop : property)) -> collect_from_type prop.type_def) 814 + p.properties ) 815 + parameters ; 816 + Option.iter (fun i -> Option.iter collect_from_type i.schema) input ; 817 + Option.iter (fun o -> Option.iter collect_from_type o.schema) output 818 + | Record {record; _} -> 819 + List.iter 820 + (fun (_, (prop : property)) -> collect_from_type prop.type_def) 821 + record.properties 822 + | _ -> 823 + () 824 + in 825 + List.iter (fun def -> collect_from_type def.type_def) doc.defs ; 826 + out.imports
+3
hermes-cli/lib/dune
··· 1 + (library 2 + (name hermes_cli) 3 + (libraries yojson fmt fpath str))
+118
hermes-cli/lib/lexicon_types.ml
··· 1 + type string_spec = 2 + { format: string option 3 + ; min_length: int option 4 + ; max_length: int option 5 + ; min_graphemes: int option 6 + ; max_graphemes: int option 7 + ; known_values: string list option 8 + ; enum: string list option 9 + ; const: string option 10 + ; default: string option 11 + ; description: string option } 12 + 13 + type integer_spec = 14 + { minimum: int option 15 + ; maximum: int option 16 + ; enum: int list option 17 + ; const: int option 18 + ; default: int option 19 + ; description: string option } 20 + 21 + type boolean_spec = 22 + {const: bool option; default: bool option; description: string option} 23 + 24 + type bytes_spec = 25 + {min_length: int option; max_length: int option; description: string option} 26 + 27 + type blob_spec = 28 + {accept: string list option; max_size: int option; description: string option} 29 + 30 + type cid_link_spec = {description: string option} 31 + 32 + type array_spec = 33 + { items: type_def 34 + ; min_length: int option 35 + ; max_length: int option 36 + ; description: string option } 37 + 38 + and property = {type_def: type_def; description: string option} 39 + 40 + and object_spec = 41 + { properties: (string * property) list 42 + ; required: string list option 43 + ; nullable: string list option 44 + ; description: string option } 45 + 46 + and ref_spec = 47 + { ref_: string (* e.g., "#localDef" or "com.example.defs#someDef" *) 48 + ; description: string option } 49 + 50 + and union_spec = 51 + {refs: string list; closed: bool option; description: string option} 52 + 53 + and token_spec = {description: string option} 54 + 55 + and unknown_spec = {description: string option} 56 + 57 + and params_spec = 58 + { properties: (string * property) list 59 + ; required: string list option 60 + ; description: string option } 61 + 62 + and body_def = 63 + {encoding: string; schema: type_def option; description: string option} 64 + 65 + and error_def = {name: string; description: string option} 66 + 67 + and query_spec = 68 + { parameters: params_spec option 69 + ; output: body_def option 70 + ; errors: error_def list option 71 + ; description: string option } 72 + 73 + and procedure_spec = 74 + { parameters: params_spec option 75 + ; input: body_def option 76 + ; output: body_def option 77 + ; errors: error_def list option 78 + ; description: string option } 79 + 80 + and subscription_spec = 81 + { parameters: params_spec option 82 + ; message: body_def option 83 + ; errors: error_def list option 84 + ; description: string option } 85 + 86 + and record_spec = 87 + { key: string (* "tid", "nsid", etc. *) 88 + ; record: object_spec 89 + ; description: string option } 90 + 91 + and type_def = 92 + | String of string_spec 93 + | Integer of integer_spec 94 + | Boolean of boolean_spec 95 + | Bytes of bytes_spec 96 + | Blob of blob_spec 97 + | CidLink of cid_link_spec 98 + | Array of array_spec 99 + | Object of object_spec 100 + | Ref of ref_spec 101 + | Union of union_spec 102 + | Token of token_spec 103 + | Unknown of unknown_spec 104 + | Query of query_spec 105 + | Procedure of procedure_spec 106 + | Subscription of subscription_spec 107 + | Record of record_spec 108 + 109 + type def_entry = {name: string; type_def: type_def} 110 + 111 + type lexicon_doc = 112 + { lexicon: int (* always 1 *) 113 + ; id: string (* nsid *) 114 + ; revision: int option 115 + ; description: string option 116 + ; defs: def_entry list } 117 + 118 + type parse_result = (lexicon_doc, string) result
+147
hermes-cli/lib/naming.ml
··· 1 + (* ocaml reserved keywords that need escaping *) 2 + let reserved_keywords = 3 + [ "and" 4 + ; "as" 5 + ; "assert" 6 + ; "asr" 7 + ; "begin" 8 + ; "class" 9 + ; "constraint" 10 + ; "do" 11 + ; "done" 12 + ; "downto" 13 + ; "else" 14 + ; "end" 15 + ; "exception" 16 + ; "external" 17 + ; "false" 18 + ; "for" 19 + ; "fun" 20 + ; "function" 21 + ; "functor" 22 + ; "if" 23 + ; "in" 24 + ; "include" 25 + ; "inherit" 26 + ; "initializer" 27 + ; "land" 28 + ; "lazy" 29 + ; "let" 30 + ; "lor" 31 + ; "lsl" 32 + ; "lsr" 33 + ; "lxor" 34 + ; "match" 35 + ; "method" 36 + ; "mod" 37 + ; "module" 38 + ; "mutable" 39 + ; "new" 40 + ; "nonrec" 41 + ; "object" 42 + ; "of" 43 + ; "open" 44 + ; "or" 45 + ; "private" 46 + ; "rec" 47 + ; "sig" 48 + ; "struct" 49 + ; "then" 50 + ; "to" 51 + ; "true" 52 + ; "try" 53 + ; "type" 54 + ; "val" 55 + ; "virtual" 56 + ; "when" 57 + ; "while" 58 + ; "with" ] 59 + 60 + let is_reserved name = List.mem (String.lowercase_ascii name) reserved_keywords 61 + 62 + (* convert camelCase to snake_case *) 63 + let camel_to_snake s = 64 + let buf = Buffer.create (String.length s * 2) in 65 + String.iteri 66 + (fun i c -> 67 + if Char.uppercase_ascii c = c && c <> Char.lowercase_ascii c then begin 68 + if i > 0 then Buffer.add_char buf '_' ; 69 + Buffer.add_char buf (Char.lowercase_ascii c) 70 + end 71 + else Buffer.add_char buf c ) 72 + s ; 73 + Buffer.contents buf 74 + 75 + let escape_keyword name = if is_reserved name then name ^ "_" else name 76 + 77 + let field_name name = escape_keyword (camel_to_snake name) 78 + 79 + let module_name_of_segment segment = 80 + if String.length segment = 0 then segment else String.capitalize_ascii segment 81 + 82 + let module_path_of_nsid nsid = 83 + String.split_on_char '.' nsid |> List.map module_name_of_segment 84 + 85 + let type_name_of_nsid nsid = 86 + let segments = String.split_on_char '.' nsid in 87 + match List.rev segments with 88 + | last :: _ -> 89 + camel_to_snake last 90 + | [] -> 91 + "unknown" 92 + 93 + let type_name name = escape_keyword (camel_to_snake name) 94 + 95 + let def_module_name name = String.capitalize_ascii name 96 + 97 + (* generate variant constructor name from ref *) 98 + let variant_name_of_ref ref_str = 99 + (* "#localDef" -> "LocalDef", "com.example.defs#someDef" -> "SomeDef" *) 100 + let name = 101 + match String.split_on_char '#' ref_str with 102 + | [_; def] -> 103 + def 104 + | [def] -> ( 105 + (* just nsid, use last segment *) 106 + match List.rev (String.split_on_char '.' def) with 107 + | last :: _ -> 108 + last 109 + | [] -> 110 + "Unknown" ) 111 + | _ -> 112 + "Unknown" 113 + in 114 + String.capitalize_ascii name 115 + 116 + let union_type_name refs = 117 + match refs with 118 + | [] -> 119 + "unknown_union" 120 + | [r] -> 121 + type_name (variant_name_of_ref r) 122 + | _ -> ( 123 + (* use first two refs to generate a name *) 124 + let names = List.map variant_name_of_ref refs in 125 + let sorted = List.sort String.compare names in 126 + match sorted with 127 + | a :: b :: _ -> 128 + camel_to_snake a ^ "_or_" ^ camel_to_snake b 129 + | [a] -> 130 + camel_to_snake a 131 + | [] -> 132 + "unknown_union" ) 133 + 134 + (* convert nsid to flat file path and module name *) 135 + let flat_name_of_nsid nsid = String.split_on_char '.' nsid |> String.concat "_" 136 + 137 + let file_path_of_nsid nsid = flat_name_of_nsid nsid ^ ".ml" 138 + 139 + let flat_module_name_of_nsid nsid = 140 + String.capitalize_ascii (flat_name_of_nsid nsid) 141 + 142 + let needs_key_annotation original_name ocaml_name = original_name <> ocaml_name 143 + 144 + let key_annotation original_name ocaml_name = 145 + if needs_key_annotation original_name ocaml_name then 146 + Printf.sprintf " [@key \"%s\"]" original_name 147 + else ""
+363
hermes-cli/lib/parser.ml
··· 1 + (* parse lexicon json files into lexicon_types *) 2 + 3 + open Lexicon_types 4 + 5 + let get_string_opt key json = 6 + match json with 7 + | `Assoc pairs -> ( 8 + match List.assoc_opt key pairs with Some (`String s) -> Some s | _ -> None ) 9 + | _ -> 10 + None 11 + 12 + let get_string key json = 13 + match get_string_opt key json with 14 + | Some s -> 15 + s 16 + | None -> 17 + failwith ("missing required string field: " ^ key) 18 + 19 + let get_int_opt key json = 20 + match json with 21 + | `Assoc pairs -> ( 22 + match List.assoc_opt key pairs with Some (`Int i) -> Some i | _ -> None ) 23 + | _ -> 24 + None 25 + 26 + let get_int key json = 27 + match get_int_opt key json with 28 + | Some i -> 29 + i 30 + | None -> 31 + failwith ("missing required int field: " ^ key) 32 + 33 + let get_bool_opt key json = 34 + match json with 35 + | `Assoc pairs -> ( 36 + match List.assoc_opt key pairs with Some (`Bool b) -> Some b | _ -> None ) 37 + | _ -> 38 + None 39 + 40 + let get_list_opt key json = 41 + match json with 42 + | `Assoc pairs -> ( 43 + match List.assoc_opt key pairs with Some (`List l) -> Some l | _ -> None ) 44 + | _ -> 45 + None 46 + 47 + let get_string_list_opt key json = 48 + match get_list_opt key json with 49 + | Some l -> 50 + Some (List.filter_map (function `String s -> Some s | _ -> None) l) 51 + | None -> 52 + None 53 + 54 + let get_int_list_opt key json = 55 + match get_list_opt key json with 56 + | Some l -> 57 + Some (List.filter_map (function `Int i -> Some i | _ -> None) l) 58 + | None -> 59 + None 60 + 61 + let get_assoc key json = 62 + match json with 63 + | `Assoc pairs -> ( 64 + match List.assoc_opt key pairs with 65 + | Some (`Assoc _ as a) -> 66 + Some a 67 + | _ -> 68 + None ) 69 + | _ -> 70 + None 71 + 72 + (* parse type definition from json *) 73 + let rec parse_type_def json : type_def = 74 + let type_str = get_string "type" json in 75 + match type_str with 76 + | "string" -> 77 + String 78 + { format= get_string_opt "format" json 79 + ; min_length= get_int_opt "minLength" json 80 + ; max_length= get_int_opt "maxLength" json 81 + ; min_graphemes= get_int_opt "minGraphemes" json 82 + ; max_graphemes= get_int_opt "maxGraphemes" json 83 + ; known_values= get_string_list_opt "knownValues" json 84 + ; enum= get_string_list_opt "enum" json 85 + ; const= get_string_opt "const" json 86 + ; default= get_string_opt "default" json 87 + ; description= get_string_opt "description" json } 88 + | "integer" -> 89 + Integer 90 + { minimum= get_int_opt "minimum" json 91 + ; maximum= get_int_opt "maximum" json 92 + ; enum= get_int_list_opt "enum" json 93 + ; const= get_int_opt "const" json 94 + ; default= get_int_opt "default" json 95 + ; description= get_string_opt "description" json } 96 + | "boolean" -> 97 + Boolean 98 + { const= get_bool_opt "const" json 99 + ; default= get_bool_opt "default" json 100 + ; description= get_string_opt "description" json } 101 + | "bytes" -> 102 + Bytes 103 + { min_length= get_int_opt "minLength" json 104 + ; max_length= get_int_opt "maxLength" json 105 + ; description= get_string_opt "description" json } 106 + | "blob" -> 107 + Blob 108 + { accept= get_string_list_opt "accept" json 109 + ; max_size= get_int_opt "maxSize" json 110 + ; description= get_string_opt "description" json } 111 + | "cid-link" -> 112 + CidLink {description= get_string_opt "description" json} 113 + | "array" -> 114 + let items_json = 115 + match get_assoc "items" json with 116 + | Some j -> 117 + j 118 + | None -> 119 + failwith "array type missing items" 120 + in 121 + Array 122 + { items= parse_type_def items_json 123 + ; min_length= get_int_opt "minLength" json 124 + ; max_length= get_int_opt "maxLength" json 125 + ; description= get_string_opt "description" json } 126 + | "object" -> 127 + Object (parse_object_spec json) 128 + | "ref" -> 129 + Ref 130 + { ref_= get_string "ref" json 131 + ; description= get_string_opt "description" json } 132 + | "union" -> 133 + Union 134 + { refs= 135 + ( match get_string_list_opt "refs" json with 136 + | Some l -> 137 + l 138 + | None -> 139 + [] ) 140 + ; closed= get_bool_opt "closed" json 141 + ; description= get_string_opt "description" json } 142 + | "token" -> 143 + Token {description= get_string_opt "description" json} 144 + | "unknown" -> 145 + Unknown {description= get_string_opt "description" json} 146 + | "query" -> 147 + Query (parse_query_spec json) 148 + | "procedure" -> 149 + Procedure (parse_procedure_spec json) 150 + | "subscription" -> 151 + Subscription (parse_subscription_spec json) 152 + | "record" -> 153 + Record (parse_record_spec json) 154 + | t -> 155 + failwith ("unknown type: " ^ t) 156 + 157 + and parse_object_spec json : object_spec = 158 + let properties = 159 + match get_assoc "properties" json with 160 + | Some (`Assoc pairs) -> 161 + List.map 162 + (fun (name, prop_json) -> 163 + let type_def = parse_type_def prop_json in 164 + let description = get_string_opt "description" prop_json in 165 + (name, {type_def; description}) ) 166 + pairs 167 + | _ -> 168 + [] 169 + in 170 + { properties 171 + ; required= get_string_list_opt "required" json 172 + ; nullable= get_string_list_opt "nullable" json 173 + ; description= get_string_opt "description" json } 174 + 175 + and parse_params_spec json : params_spec = 176 + let properties = 177 + match get_assoc "properties" json with 178 + | Some (`Assoc pairs) -> 179 + List.map 180 + (fun (name, prop_json) -> 181 + let type_def = parse_type_def prop_json in 182 + let description = get_string_opt "description" prop_json in 183 + (name, {type_def; description}) ) 184 + pairs 185 + | _ -> 186 + [] 187 + in 188 + { properties 189 + ; required= get_string_list_opt "required" json 190 + ; description= get_string_opt "description" json } 191 + 192 + and parse_body_def json : body_def = 193 + { encoding= get_string "encoding" json 194 + ; schema= 195 + ( match get_assoc "schema" json with 196 + | Some j -> 197 + Some (parse_type_def j) 198 + | None -> 199 + None ) 200 + ; description= get_string_opt "description" json } 201 + 202 + and parse_error_def json : error_def = 203 + {name= get_string "name" json; description= get_string_opt "description" json} 204 + 205 + and parse_query_spec json : query_spec = 206 + let parameters = 207 + match get_assoc "parameters" json with 208 + | Some j -> 209 + Some (parse_params_spec j) 210 + | None -> 211 + None 212 + in 213 + let output = 214 + match get_assoc "output" json with 215 + | Some j -> 216 + Some (parse_body_def j) 217 + | None -> 218 + None 219 + in 220 + let errors = 221 + match get_list_opt "errors" json with 222 + | Some l -> 223 + Some 224 + (List.map 225 + (function 226 + | `Assoc _ as j -> 227 + parse_error_def j 228 + | _ -> 229 + failwith "invalid error def" ) 230 + l ) 231 + | None -> 232 + None 233 + in 234 + {parameters; output; errors; description= get_string_opt "description" json} 235 + 236 + and parse_procedure_spec json : procedure_spec = 237 + let parameters = 238 + match get_assoc "parameters" json with 239 + | Some j -> 240 + Some (parse_params_spec j) 241 + | None -> 242 + None 243 + in 244 + let input = 245 + match get_assoc "input" json with 246 + | Some j -> 247 + Some (parse_body_def j) 248 + | None -> 249 + None 250 + in 251 + let output = 252 + match get_assoc "output" json with 253 + | Some j -> 254 + Some (parse_body_def j) 255 + | None -> 256 + None 257 + in 258 + let errors = 259 + match get_list_opt "errors" json with 260 + | Some l -> 261 + Some 262 + (List.map 263 + (function 264 + | `Assoc _ as j -> 265 + parse_error_def j 266 + | _ -> 267 + failwith "invalid error def" ) 268 + l ) 269 + | None -> 270 + None 271 + in 272 + { parameters 273 + ; input 274 + ; output 275 + ; errors 276 + ; description= get_string_opt "description" json } 277 + 278 + and parse_subscription_spec json : subscription_spec = 279 + let parameters = 280 + match get_assoc "parameters" json with 281 + | Some j -> 282 + Some (parse_params_spec j) 283 + | None -> 284 + None 285 + in 286 + let message = 287 + match get_assoc "message" json with 288 + | Some j -> 289 + Some (parse_body_def j) 290 + | None -> 291 + None 292 + in 293 + let errors = 294 + match get_list_opt "errors" json with 295 + | Some l -> 296 + Some 297 + (List.map 298 + (function 299 + | `Assoc _ as j -> 300 + parse_error_def j 301 + | _ -> 302 + failwith "invalid error def" ) 303 + l ) 304 + | None -> 305 + None 306 + in 307 + {parameters; message; errors; description= get_string_opt "description" json} 308 + 309 + and parse_record_spec json : record_spec = 310 + let key = get_string "key" json in 311 + let record_json = 312 + match get_assoc "record" json with 313 + | Some j -> 314 + j 315 + | None -> 316 + failwith "record type missing record field" 317 + in 318 + { key 319 + ; record= parse_object_spec record_json 320 + ; description= get_string_opt "description" json } 321 + 322 + (* parse complete lexicon document *) 323 + let parse_lexicon_doc json : lexicon_doc = 324 + let lexicon = get_int "lexicon" json in 325 + let id = get_string "id" json in 326 + let revision = get_int_opt "revision" json in 327 + let description = get_string_opt "description" json in 328 + let defs = 329 + match get_assoc "defs" json with 330 + | Some (`Assoc pairs) -> 331 + List.map 332 + (fun (name, def_json) -> {name; type_def= parse_type_def def_json}) 333 + pairs 334 + | _ -> 335 + [] 336 + in 337 + {lexicon; id; revision; description; defs} 338 + 339 + (* parse lexicon file *) 340 + let parse_file path : parse_result = 341 + try 342 + let json = Yojson.Safe.from_file path in 343 + Ok (parse_lexicon_doc json) 344 + with 345 + | Yojson.Json_error e -> 346 + Error ("JSON parse error: " ^ e) 347 + | Failure e -> 348 + Error ("Parse error: " ^ e) 349 + | e -> 350 + Error ("Unexpected error: " ^ Printexc.to_string e) 351 + 352 + (* parse json string *) 353 + let parse_string content : parse_result = 354 + try 355 + let json = Yojson.Safe.from_string content in 356 + Ok (parse_lexicon_doc json) 357 + with 358 + | Yojson.Json_error e -> 359 + Error ("JSON parse error: " ^ e) 360 + | Failure e -> 361 + Error ("Parse error: " ^ e) 362 + | e -> 363 + Error ("Unexpected error: " ^ Printexc.to_string e)
+11
hermes-cli/test/dune
··· 1 + (test 2 + (name test_naming) 3 + (libraries alcotest hermes_cli)) 4 + 5 + (test 6 + (name test_parser) 7 + (libraries alcotest hermes_cli yojson)) 8 + 9 + (test 10 + (name test_codegen) 11 + (libraries alcotest hermes_cli))
+342
hermes-cli/test/test_codegen.ml
··· 1 + open Alcotest 2 + open Hermes_cli 3 + 4 + let contains s1 s2 = 5 + try 6 + let len = String.length s2 in 7 + for i = 0 to String.length s1 - len do 8 + if String.sub s1 i len = s2 then raise Exit 9 + done ; 10 + false 11 + with Exit -> true 12 + 13 + (* create a simple lexicon doc for testing *) 14 + let make_lexicon id defs = 15 + {Lexicon_types.lexicon= 1; id; revision= None; description= None; defs} 16 + 17 + let make_def name type_def = {Lexicon_types.name; type_def} 18 + 19 + let make_object_spec properties required = 20 + { Lexicon_types.properties 21 + ; required= Some required 22 + ; nullable= None 23 + ; description= None } 24 + 25 + let make_property type_def = {Lexicon_types.type_def; description= None} 26 + 27 + let string_type = 28 + Lexicon_types.String 29 + { format= None 30 + ; min_length= None 31 + ; max_length= None 32 + ; min_graphemes= None 33 + ; max_graphemes= None 34 + ; known_values= None 35 + ; enum= None 36 + ; const= None 37 + ; default= None 38 + ; description= None } 39 + 40 + let int_type = 41 + Lexicon_types.Integer 42 + { minimum= None 43 + ; maximum= None 44 + ; enum= None 45 + ; const= None 46 + ; default= None 47 + ; description= None } 48 + 49 + let[@warning "-32"] _bool_type = 50 + Lexicon_types.Boolean {const= None; default= None; description= None} 51 + 52 + (* test generating a simple object type *) 53 + let test_gen_simple_object () = 54 + let obj_spec = 55 + make_object_spec 56 + [("name", make_property string_type); ("age", make_property int_type)] 57 + ["name"; "age"] 58 + in 59 + let doc = 60 + make_lexicon "com.example.test" 61 + [make_def "main" (Lexicon_types.Object obj_spec)] 62 + in 63 + let code = Codegen.gen_lexicon_module doc in 64 + check bool "contains type main" true (contains code "type main =") ; 65 + check bool "contains name field" true (contains code "name: string") ; 66 + check bool "contains age field" true (contains code "age: int") ; 67 + check bool "contains deriving" true (contains code "[@@deriving yojson") 68 + 69 + (* test generating object with optional fields *) 70 + let test_gen_optional_fields () = 71 + let obj_spec = 72 + make_object_spec 73 + [ ("required_field", make_property string_type) 74 + ; ("optional_field", make_property string_type) ] 75 + ["required_field"] 76 + (* only required_field is required *) 77 + in 78 + let doc = 79 + make_lexicon "com.example.optional" 80 + [make_def "main" (Lexicon_types.Object obj_spec)] 81 + in 82 + let code = Codegen.gen_lexicon_module doc in 83 + check bool "required not option" true 84 + (contains code "required_field: string;") ; 85 + check bool "optional is option" true 86 + (contains code "optional_field: string option") 87 + 88 + (* test generating with key annotation *) 89 + let test_gen_key_annotation () = 90 + let obj_spec = 91 + make_object_spec [("firstName", make_property string_type)] ["firstName"] 92 + in 93 + let doc = 94 + make_lexicon "com.example.key" 95 + [make_def "main" (Lexicon_types.Object obj_spec)] 96 + in 97 + let code = Codegen.gen_lexicon_module doc in 98 + check bool "has snake_case field" true (contains code "first_name:") ; 99 + check bool "has key annotation" true (contains code "[@key \"firstName\"]") 100 + 101 + (* test generating union type *) 102 + let test_gen_union_type () = 103 + let union_spec = 104 + { Lexicon_types.refs= ["#typeA"; "#typeB"] 105 + ; closed= Some false 106 + ; description= None } 107 + in 108 + let doc = 109 + make_lexicon "com.example.union" 110 + [make_def "result" (Lexicon_types.Union union_spec)] 111 + in 112 + let code = Codegen.gen_lexicon_module doc in 113 + check bool "contains type result" true (contains code "type result =") ; 114 + check bool "contains TypeA variant" true (contains code "| TypeA of") ; 115 + check bool "contains TypeB variant" true (contains code "| TypeB of") ; 116 + check bool "contains Unknown (open)" true 117 + (contains code "| Unknown of Yojson.Safe.t") 118 + 119 + (* test generating closed union *) 120 + let test_gen_closed_union () = 121 + let union_spec = 122 + { Lexicon_types.refs= ["#typeA"; "#typeB"] 123 + ; closed= Some true 124 + ; description= None } 125 + in 126 + let doc = 127 + make_lexicon "com.example.closed" 128 + [make_def "result" (Lexicon_types.Union union_spec)] 129 + in 130 + let code = Codegen.gen_lexicon_module doc in 131 + check bool "no Unknown variant" false (contains code "| Unknown of") 132 + 133 + (* test generating query module *) 134 + let test_gen_query_module () = 135 + let params_spec = 136 + { Lexicon_types.properties= [("userId", make_property string_type)] 137 + ; required= Some ["userId"] 138 + ; description= None } 139 + in 140 + let output_schema = 141 + Lexicon_types.Object 142 + (make_object_spec [("name", make_property string_type)] ["name"]) 143 + in 144 + let output_body = 145 + { Lexicon_types.encoding= "application/json" 146 + ; schema= Some output_schema 147 + ; description= None } 148 + in 149 + let query_spec = 150 + { Lexicon_types.parameters= Some params_spec 151 + ; output= Some output_body 152 + ; errors= None 153 + ; description= Some "Get user by ID" } 154 + in 155 + let doc = 156 + make_lexicon "com.example.getUser" 157 + [make_def "main" (Lexicon_types.Query query_spec)] 158 + in 159 + let code = Codegen.gen_lexicon_module doc in 160 + check bool "contains module Main" true (contains code "module Main = struct") ; 161 + check bool "contains nsid" true 162 + (contains code "let nsid = \"com.example.getUser\"") ; 163 + check bool "contains type params" true (contains code "type params =") ; 164 + check bool "contains type output" true (contains code "type output =") ; 165 + check bool "contains call function" true (contains code "let call") ; 166 + check bool "contains ~user_id param" true (contains code "~user_id") ; 167 + check bool "calls Hermes.query" true (contains code "Hermes.query") 168 + 169 + (* test generating procedure module *) 170 + let test_gen_procedure_module () = 171 + let input_schema = 172 + Lexicon_types.Object 173 + (make_object_spec 174 + [ ("name", make_property string_type) 175 + ; ("email", make_property string_type) ] 176 + ["name"; "email"] ) 177 + in 178 + let input_body = 179 + { Lexicon_types.encoding= "application/json" 180 + ; schema= Some input_schema 181 + ; description= None } 182 + in 183 + let output_schema = 184 + Lexicon_types.Object 185 + (make_object_spec [("id", make_property string_type)] ["id"]) 186 + in 187 + let output_body = 188 + { Lexicon_types.encoding= "application/json" 189 + ; schema= Some output_schema 190 + ; description= None } 191 + in 192 + let proc_spec = 193 + { Lexicon_types.parameters= None 194 + ; input= Some input_body 195 + ; output= Some output_body 196 + ; errors= None 197 + ; description= Some "Create user" } 198 + in 199 + let doc = 200 + make_lexicon "com.example.createUser" 201 + [make_def "main" (Lexicon_types.Procedure proc_spec)] 202 + in 203 + let code = Codegen.gen_lexicon_module doc in 204 + check bool "contains module Main" true (contains code "module Main = struct") ; 205 + check bool "contains type input" true (contains code "type input =") ; 206 + check bool "contains type output" true (contains code "type output =") ; 207 + check bool "contains call function" true (contains code "let call") ; 208 + check bool "contains ~name param" true (contains code "~name") ; 209 + check bool "contains ~email param" true (contains code "~email") ; 210 + check bool "calls Hermes.procedure" true (contains code "Hermes.procedure") 211 + 212 + (* test type ordering with dependencies *) 213 + let test_type_ordering () = 214 + (* create types where typeB depends on typeA *) 215 + let type_a_spec = 216 + make_object_spec [("value", make_property string_type)] ["value"] 217 + in 218 + let type_b_spec = 219 + make_object_spec 220 + [ ( "a" 221 + , make_property (Lexicon_types.Ref {ref_= "#typeA"; description= None}) 222 + ) ] 223 + ["a"] 224 + in 225 + let doc = 226 + make_lexicon "com.example.order" 227 + [ make_def "typeB" (Lexicon_types.Object type_b_spec) 228 + ; make_def "typeA" (Lexicon_types.Object type_a_spec) ] 229 + in 230 + let code = Codegen.gen_lexicon_module doc in 231 + (* typeA should appear before typeB in the generated code *) 232 + let pos_a = 233 + try Some (Str.search_forward (Str.regexp "type type_a") code 0) 234 + with Not_found -> None 235 + in 236 + let pos_b = 237 + try Some (Str.search_forward (Str.regexp "type type_b") code 0) 238 + with Not_found -> None 239 + in 240 + match (pos_a, pos_b) with 241 + | Some a, Some b -> 242 + check bool "typeA before typeB" true (a < b) 243 + | _ -> 244 + fail "both types should be present" 245 + 246 + (* test generating token *) 247 + let test_gen_token () = 248 + let token_spec : Lexicon_types.token_spec = 249 + {description= Some "A token value"} 250 + in 251 + let doc = 252 + make_lexicon "com.example.tokens" 253 + [make_def "myToken" (Lexicon_types.Token token_spec)] 254 + in 255 + let code = Codegen.gen_lexicon_module doc in 256 + check bool "contains let my_token" true (contains code "let my_token =") ; 257 + check bool "contains full URI" true 258 + (contains code "com.example.tokens#myToken") 259 + 260 + (* test generating query with bytes output (like getBlob) *) 261 + let test_gen_query_bytes_output () = 262 + let params_spec = 263 + { Lexicon_types.properties= 264 + [("did", make_property string_type); ("cid", make_property string_type)] 265 + ; required= Some ["did"; "cid"] 266 + ; description= None } 267 + in 268 + let output_body = 269 + { Lexicon_types.encoding= "*/*" (* bytes output *) 270 + ; schema= None 271 + ; description= None } 272 + in 273 + let query_spec = 274 + { Lexicon_types.parameters= Some params_spec 275 + ; output= Some output_body 276 + ; errors= None 277 + ; description= Some "Get a blob" } 278 + in 279 + let doc = 280 + make_lexicon "com.atproto.sync.getBlob" 281 + [make_def "main" (Lexicon_types.Query query_spec)] 282 + in 283 + let code = Codegen.gen_lexicon_module doc in 284 + check bool "contains module Main" true (contains code "module Main = struct") ; 285 + check bool "output is string * string tuple" true 286 + (contains code "type output = string * string") ; 287 + check bool "calls Hermes.query_bytes" true 288 + (contains code "Hermes.query_bytes") 289 + 290 + (* test generating procedure with bytes input (like importRepo) *) 291 + let test_gen_procedure_bytes_input () = 292 + let input_body = 293 + { Lexicon_types.encoding= "application/vnd.ipld.car" (* bytes input *) 294 + ; schema= None 295 + ; description= None } 296 + in 297 + let proc_spec = 298 + { Lexicon_types.parameters= None 299 + ; input= Some input_body 300 + ; output= None 301 + ; errors= None 302 + ; description= Some "Import a repo" } 303 + in 304 + let doc = 305 + make_lexicon "com.atproto.repo.importRepo" 306 + [make_def "main" (Lexicon_types.Procedure proc_spec)] 307 + in 308 + let code = Codegen.gen_lexicon_module doc in 309 + check bool "contains module Main" true (contains code "module Main = struct") ; 310 + check bool "has ?input param" true (contains code "?input") ; 311 + check bool "calls Hermes.procedure_bytes" true 312 + (contains code "Hermes.procedure_bytes") ; 313 + check bool "has content_type" true (contains code "application/vnd.ipld.car") 314 + 315 + (** tests *) 316 + 317 + let object_tests = 318 + [ ("simple object", `Quick, test_gen_simple_object) 319 + ; ("optional fields", `Quick, test_gen_optional_fields) 320 + ; ("key annotation", `Quick, test_gen_key_annotation) ] 321 + 322 + let union_tests = 323 + [ ("open union", `Quick, test_gen_union_type) 324 + ; ("closed union", `Quick, test_gen_closed_union) ] 325 + 326 + let xrpc_tests = 327 + [ ("query module", `Quick, test_gen_query_module) 328 + ; ("procedure module", `Quick, test_gen_procedure_module) 329 + ; ("query with bytes output", `Quick, test_gen_query_bytes_output) 330 + ; ("procedure with bytes input", `Quick, test_gen_procedure_bytes_input) ] 331 + 332 + let ordering_tests = [("type ordering", `Quick, test_type_ordering)] 333 + 334 + let token_tests = [("token generation", `Quick, test_gen_token)] 335 + 336 + let () = 337 + run "Codegen" 338 + [ ("objects", object_tests) 339 + ; ("unions", union_tests) 340 + ; ("xrpc", xrpc_tests) 341 + ; ("ordering", ordering_tests) 342 + ; ("tokens", token_tests) ]
+191
hermes-cli/test/test_naming.ml
··· 1 + open Alcotest 2 + open Hermes_cli.Naming 3 + 4 + (** helpers *) 5 + let test_string = testable Fmt.string String.equal 6 + 7 + let test_camel_to_snake_simple () = 8 + check test_string "simple camelCase" "first_name" (camel_to_snake "firstName") 9 + 10 + let test_camel_to_snake_single () = 11 + check test_string "single word" "name" (camel_to_snake "name") 12 + 13 + let test_camel_to_snake_already_snake () = 14 + check test_string "already snake_case" "first_name" 15 + (camel_to_snake "first_name") 16 + 17 + let test_camel_to_snake_multiple_caps () = 18 + check test_string "multiple caps" "auth_factor_token" 19 + (camel_to_snake "authFactorToken") 20 + 21 + let test_camel_to_snake_leading_cap () = 22 + check test_string "leading capital" "name" (camel_to_snake "Name") 23 + 24 + let test_camel_to_snake_all_caps () = 25 + check test_string "all caps sequence" "d_i_d" (camel_to_snake "DID") 26 + 27 + let test_camel_to_snake_empty () = 28 + check test_string "empty string" "" (camel_to_snake "") 29 + 30 + let test_is_reserved_type () = 31 + check bool "type is reserved" true (is_reserved "type") 32 + 33 + let test_is_reserved_module () = 34 + check bool "module is reserved" true (is_reserved "module") 35 + 36 + let test_is_reserved_and () = 37 + check bool "and is reserved" true (is_reserved "and") 38 + 39 + let test_is_reserved_user () = 40 + check bool "user is not reserved" false (is_reserved "user") 41 + 42 + let test_is_reserved_case_insensitive () = 43 + check bool "TYPE is reserved (case insensitive)" true (is_reserved "TYPE") 44 + 45 + let test_escape_keyword_reserved () = 46 + check test_string "escapes type" "type_" (escape_keyword "type") 47 + 48 + let test_escape_keyword_not_reserved () = 49 + check test_string "does not escape user" "user" (escape_keyword "user") 50 + 51 + let test_escape_keyword_module () = 52 + check test_string "escapes module" "module_" (escape_keyword "module") 53 + 54 + let test_field_name_camel () = 55 + check test_string "converts camelCase" "first_name" (field_name "firstName") 56 + 57 + let test_field_name_reserved () = 58 + check test_string "escapes reserved" "type_" (field_name "type") 59 + 60 + let test_field_name_camel_reserved () = 61 + check test_string "converts and escapes" "to_" (field_name "to") 62 + 63 + let test_module_name_simple () = 64 + check test_string "capitalizes" "App" (module_name_of_segment "app") 65 + 66 + let test_module_name_already_cap () = 67 + check test_string "already capitalized" "App" (module_name_of_segment "App") 68 + 69 + let test_module_name_empty () = 70 + check test_string "empty string" "" (module_name_of_segment "") 71 + 72 + let test_module_path_simple () = 73 + check (list test_string) "simple path" ["App"; "Bsky"; "Graph"] 74 + (module_path_of_nsid "app.bsky.graph") 75 + 76 + let test_module_path_single () = 77 + check (list test_string) "single segment" ["App"] (module_path_of_nsid "app") 78 + 79 + let test_module_path_full () = 80 + check (list test_string) "full nsid" 81 + ["Com"; "Atproto"; "Server"; "CreateSession"] 82 + (module_path_of_nsid "com.atproto.server.createSession") 83 + 84 + let test_type_name_of_nsid () = 85 + check test_string "extracts last segment" "get_profile" 86 + (type_name_of_nsid "app.bsky.actor.getProfile") 87 + 88 + let test_type_name_simple () = 89 + check test_string "converts name" "invite_code" (type_name "inviteCode") 90 + 91 + let test_type_name_reserved () = 92 + check test_string "escapes reserved" "type_" (type_name "type") 93 + 94 + let test_def_module_name () = 95 + check test_string "capitalizes" "InviteCode" (def_module_name "inviteCode") 96 + 97 + let test_variant_local_ref () = 98 + check test_string "local ref" "Relationship" 99 + (variant_name_of_ref "#relationship") 100 + 101 + let test_variant_external_ref () = 102 + check test_string "external ref" "SomeDef" 103 + (variant_name_of_ref "com.example.defs#someDef") 104 + 105 + let test_variant_just_nsid () = 106 + check test_string "just nsid" "Defs" (variant_name_of_ref "com.example.defs") 107 + 108 + let test_flat_module_name () = 109 + check test_string "flat name" "Com_atproto_server_defs" 110 + (flat_module_name_of_nsid "com.atproto.server.defs") 111 + 112 + let test_flat_module_name_short () = 113 + check test_string "short nsid" "App_bsky" 114 + (flat_module_name_of_nsid "app.bsky") 115 + 116 + let test_file_path () = 117 + check test_string "file path" "com_atproto_server_defs.ml" 118 + (file_path_of_nsid "com.atproto.server.defs") 119 + 120 + let test_key_annotation_needed () = 121 + check test_string "annotation needed" " [@key \"firstName\"]" 122 + (key_annotation "firstName" "first_name") 123 + 124 + let test_key_annotation_not_needed () = 125 + check test_string "annotation not needed" "" (key_annotation "name" "name") 126 + 127 + let camel_to_snake_tests = 128 + [ ("simple camelCase", `Quick, test_camel_to_snake_simple) 129 + ; ("single word", `Quick, test_camel_to_snake_single) 130 + ; ("already snake_case", `Quick, test_camel_to_snake_already_snake) 131 + ; ("multiple caps", `Quick, test_camel_to_snake_multiple_caps) 132 + ; ("leading capital", `Quick, test_camel_to_snake_leading_cap) 133 + ; ("all caps", `Quick, test_camel_to_snake_all_caps) 134 + ; ("empty string", `Quick, test_camel_to_snake_empty) ] 135 + 136 + let is_reserved_tests = 137 + [ ("type is reserved", `Quick, test_is_reserved_type) 138 + ; ("module is reserved", `Quick, test_is_reserved_module) 139 + ; ("and is reserved", `Quick, test_is_reserved_and) 140 + ; ("user is not reserved", `Quick, test_is_reserved_user) 141 + ; ("case insensitive", `Quick, test_is_reserved_case_insensitive) ] 142 + 143 + let escape_keyword_tests = 144 + [ ("escapes reserved", `Quick, test_escape_keyword_reserved) 145 + ; ("does not escape non-reserved", `Quick, test_escape_keyword_not_reserved) 146 + ; ("escapes module", `Quick, test_escape_keyword_module) ] 147 + 148 + let field_name_tests = 149 + [ ("converts camelCase", `Quick, test_field_name_camel) 150 + ; ("escapes reserved", `Quick, test_field_name_reserved) 151 + ; ("converts and escapes", `Quick, test_field_name_camel_reserved) ] 152 + 153 + let module_name_tests = 154 + [ ("capitalizes segment", `Quick, test_module_name_simple) 155 + ; ("already capitalized", `Quick, test_module_name_already_cap) 156 + ; ("empty string", `Quick, test_module_name_empty) 157 + ; ("module path simple", `Quick, test_module_path_simple) 158 + ; ("module path single", `Quick, test_module_path_single) 159 + ; ("module path full", `Quick, test_module_path_full) ] 160 + 161 + let type_name_tests = 162 + [ ("type_name_of_nsid", `Quick, test_type_name_of_nsid) 163 + ; ("type_name simple", `Quick, test_type_name_simple) 164 + ; ("type_name reserved", `Quick, test_type_name_reserved) 165 + ; ("def_module_name", `Quick, test_def_module_name) ] 166 + 167 + let variant_name_tests = 168 + [ ("local ref", `Quick, test_variant_local_ref) 169 + ; ("external ref", `Quick, test_variant_external_ref) 170 + ; ("just nsid", `Quick, test_variant_just_nsid) ] 171 + 172 + let flat_module_tests = 173 + [ ("flat module name", `Quick, test_flat_module_name) 174 + ; ("flat module short", `Quick, test_flat_module_name_short) 175 + ; ("file path", `Quick, test_file_path) ] 176 + 177 + let annotation_tests = 178 + [ ("annotation needed", `Quick, test_key_annotation_needed) 179 + ; ("annotation not needed", `Quick, test_key_annotation_not_needed) ] 180 + 181 + let () = 182 + run "Naming" 183 + [ ("camel_to_snake", camel_to_snake_tests) 184 + ; ("is_reserved", is_reserved_tests) 185 + ; ("escape_keyword", escape_keyword_tests) 186 + ; ("field_name", field_name_tests) 187 + ; ("module_name", module_name_tests) 188 + ; ("type_name", type_name_tests) 189 + ; ("variant_name", variant_name_tests) 190 + ; ("flat_module", flat_module_tests) 191 + ; ("annotations", annotation_tests) ]
+325
hermes-cli/test/test_parser.ml
··· 1 + open Alcotest 2 + open Hermes_cli 3 + 4 + (** helpers *) 5 + let test_string = testable Fmt.string String.equal 6 + 7 + (* parsing a simple object type *) 8 + let test_parse_simple_object () = 9 + let json = 10 + {|{ 11 + "lexicon": 1, 12 + "id": "com.example.test", 13 + "defs": { 14 + "main": { 15 + "type": "object", 16 + "properties": { 17 + "name": {"type": "string"}, 18 + "count": {"type": "integer"} 19 + }, 20 + "required": ["name"] 21 + } 22 + } 23 + }|} 24 + in 25 + match Parser.parse_string json with 26 + | Ok doc -> 27 + check test_string "id matches" "com.example.test" doc.id ; 28 + check int "lexicon version" 1 doc.lexicon ; 29 + check int "one definition" 1 (List.length doc.defs) 30 + | Error e -> 31 + fail ("parse failed: " ^ e) 32 + 33 + (* parsing string type with constraints *) 34 + let test_parse_string_type () = 35 + let json = 36 + {|{ 37 + "lexicon": 1, 38 + "id": "com.example.string", 39 + "defs": { 40 + "main": { 41 + "type": "object", 42 + "properties": { 43 + "handle": { 44 + "type": "string", 45 + "format": "handle", 46 + "minLength": 3, 47 + "maxLength": 50 48 + } 49 + } 50 + } 51 + } 52 + }|} 53 + in 54 + match Parser.parse_string json with 55 + | Ok doc -> ( 56 + check int "one definition" 1 (List.length doc.defs) ; 57 + let def = List.hd doc.defs in 58 + match def.type_def with 59 + | Lexicon_types.Object spec -> ( 60 + check int "one property" 1 (List.length spec.properties) ; 61 + let _, prop = List.hd spec.properties in 62 + match prop.type_def with 63 + | Lexicon_types.String s -> 64 + check (option test_string) "format" (Some "handle") s.format ; 65 + check (option int) "minLength" (Some 3) s.min_length ; 66 + check (option int) "maxLength" (Some 50) s.max_length 67 + | _ -> 68 + fail "expected string type" ) 69 + | _ -> 70 + fail "expected object type" ) 71 + | Error e -> 72 + fail ("parse failed: " ^ e) 73 + 74 + (* parsing array type *) 75 + let test_parse_array_type () = 76 + let json = 77 + {|{ 78 + "lexicon": 1, 79 + "id": "com.example.array", 80 + "defs": { 81 + "main": { 82 + "type": "object", 83 + "properties": { 84 + "items": { 85 + "type": "array", 86 + "items": {"type": "string"}, 87 + "maxLength": 100 88 + } 89 + } 90 + } 91 + } 92 + }|} 93 + in 94 + match Parser.parse_string json with 95 + | Ok doc -> ( 96 + let def = List.hd doc.defs in 97 + match def.type_def with 98 + | Lexicon_types.Object spec -> ( 99 + let _, prop = List.hd spec.properties in 100 + match prop.type_def with 101 + | Lexicon_types.Array arr -> ( 102 + check (option int) "maxLength" (Some 100) arr.max_length ; 103 + match arr.items with 104 + | Lexicon_types.String _ -> 105 + () 106 + | _ -> 107 + fail "expected string items" ) 108 + | _ -> 109 + fail "expected array type" ) 110 + | _ -> 111 + fail "expected object type" ) 112 + | Error e -> 113 + fail ("parse failed: " ^ e) 114 + 115 + (* parsing ref type *) 116 + let test_parse_ref_type () = 117 + let json = 118 + {|{ 119 + "lexicon": 1, 120 + "id": "com.example.ref", 121 + "defs": { 122 + "main": { 123 + "type": "object", 124 + "properties": { 125 + "user": { 126 + "type": "ref", 127 + "ref": "com.example.defs#user" 128 + } 129 + } 130 + } 131 + } 132 + }|} 133 + in 134 + match Parser.parse_string json with 135 + | Ok doc -> ( 136 + let def = List.hd doc.defs in 137 + match def.type_def with 138 + | Lexicon_types.Object spec -> ( 139 + let _, prop = List.hd spec.properties in 140 + match prop.type_def with 141 + | Lexicon_types.Ref r -> 142 + check test_string "ref value" "com.example.defs#user" r.ref_ 143 + | _ -> 144 + fail "expected ref type" ) 145 + | _ -> 146 + fail "expected object type" ) 147 + | Error e -> 148 + fail ("parse failed: " ^ e) 149 + 150 + (* parsing union type *) 151 + let test_parse_union_type () = 152 + let json = 153 + {|{ 154 + "lexicon": 1, 155 + "id": "com.example.union", 156 + "defs": { 157 + "main": { 158 + "type": "union", 159 + "refs": ["#typeA", "#typeB"], 160 + "closed": true 161 + } 162 + } 163 + }|} 164 + in 165 + match Parser.parse_string json with 166 + | Ok doc -> ( 167 + let def = List.hd doc.defs in 168 + match def.type_def with 169 + | Lexicon_types.Union u -> 170 + check int "two refs" 2 (List.length u.refs) ; 171 + check (option bool) "closed" (Some true) u.closed 172 + | _ -> 173 + fail "expected union type" ) 174 + | Error e -> 175 + fail ("parse failed: " ^ e) 176 + 177 + (* parsing query type *) 178 + let test_parse_query_type () = 179 + let json = 180 + {|{ 181 + "lexicon": 1, 182 + "id": "com.example.getUser", 183 + "defs": { 184 + "main": { 185 + "type": "query", 186 + "description": "Get a user", 187 + "parameters": { 188 + "type": "params", 189 + "properties": { 190 + "userId": {"type": "string"} 191 + }, 192 + "required": ["userId"] 193 + }, 194 + "output": { 195 + "encoding": "application/json", 196 + "schema": { 197 + "type": "object", 198 + "properties": { 199 + "name": {"type": "string"} 200 + } 201 + } 202 + } 203 + } 204 + } 205 + }|} 206 + in 207 + match Parser.parse_string json with 208 + | Ok doc -> ( 209 + let def = List.hd doc.defs in 210 + match def.type_def with 211 + | Lexicon_types.Query q -> ( 212 + check (option test_string) "description" (Some "Get a user") 213 + q.description ; 214 + ( match q.parameters with 215 + | Some params -> 216 + check int "one param" 1 (List.length params.properties) 217 + | None -> 218 + fail "expected parameters" ) ; 219 + match q.output with 220 + | Some output -> 221 + check test_string "encoding" "application/json" output.encoding 222 + | None -> 223 + fail "expected output" ) 224 + | _ -> 225 + fail "expected query type" ) 226 + | Error e -> 227 + fail ("parse failed: " ^ e) 228 + 229 + (* parsing procedure type *) 230 + let test_parse_procedure_type () = 231 + let json = 232 + {|{ 233 + "lexicon": 1, 234 + "id": "com.example.createUser", 235 + "defs": { 236 + "main": { 237 + "type": "procedure", 238 + "input": { 239 + "encoding": "application/json", 240 + "schema": { 241 + "type": "object", 242 + "properties": { 243 + "name": {"type": "string"} 244 + }, 245 + "required": ["name"] 246 + } 247 + }, 248 + "output": { 249 + "encoding": "application/json", 250 + "schema": { 251 + "type": "object", 252 + "properties": { 253 + "id": {"type": "string"} 254 + } 255 + } 256 + } 257 + } 258 + } 259 + }|} 260 + in 261 + match Parser.parse_string json with 262 + | Ok doc -> ( 263 + let def = List.hd doc.defs in 264 + match def.type_def with 265 + | Lexicon_types.Procedure p -> ( 266 + ( match p.input with 267 + | Some input -> 268 + check test_string "input encoding" "application/json" 269 + input.encoding 270 + | None -> 271 + fail "expected input" ) ; 272 + match p.output with 273 + | Some output -> 274 + check test_string "output encoding" "application/json" 275 + output.encoding 276 + | None -> 277 + fail "expected output" ) 278 + | _ -> 279 + fail "expected procedure type" ) 280 + | Error e -> 281 + fail ("parse failed: " ^ e) 282 + 283 + (* parsing invalid JSON *) 284 + let test_parse_invalid_json () = 285 + let json = {|{ invalid json }|} in 286 + match Parser.parse_string json with 287 + | Ok _ -> 288 + fail "should have failed" 289 + | Error e -> 290 + check bool "has error message" true (String.length e > 0) 291 + 292 + (* parsing missing required field *) 293 + let test_parse_missing_field () = 294 + let json = {|{ 295 + "lexicon": 1, 296 + "defs": {} 297 + }|} in 298 + match Parser.parse_string json with 299 + | Ok _ -> 300 + fail "should have failed (missing id)" 301 + | Error _ -> 302 + () 303 + 304 + (** tests *) 305 + 306 + let object_tests = 307 + [ ("simple object", `Quick, test_parse_simple_object) 308 + ; ("string with constraints", `Quick, test_parse_string_type) 309 + ; ("array type", `Quick, test_parse_array_type) 310 + ; ("ref type", `Quick, test_parse_ref_type) ] 311 + 312 + let complex_type_tests = 313 + [ ("union type", `Quick, test_parse_union_type) 314 + ; ("query type", `Quick, test_parse_query_type) 315 + ; ("procedure type", `Quick, test_parse_procedure_type) ] 316 + 317 + let error_tests = 318 + [ ("invalid json", `Quick, test_parse_invalid_json) 319 + ; ("missing field", `Quick, test_parse_missing_field) ] 320 + 321 + let () = 322 + run "Parser" 323 + [ ("objects", object_tests) 324 + ; ("complex_types", complex_type_tests) 325 + ; ("errors", error_tests) ]
+37
hermes.opam
··· 1 + # This file is generated by dune, edit dune-project instead 2 + opam-version: "2.0" 3 + synopsis: "Type-safe XRPC client for ATProto" 4 + description: "XRPC client with PPX extensions for type-safe API calls" 5 + maintainer: ["futurGH"] 6 + authors: ["futurGH"] 7 + license: "MPL-2.0" 8 + homepage: "https://github.com/futurGH/pegasus" 9 + bug-reports: "https://github.com/futurGH/pegasus/issues" 10 + depends: [ 11 + "ocaml" {= "5.2.1"} 12 + "dune" {>= "3.20"} 13 + "lwt" 14 + "cohttp-lwt-unix" {>= "6.1.1"} 15 + "uri" {>= "4.4.0"} 16 + "yojson" {>= "3.0.0"} 17 + "base64" {>= "3.5.0"} 18 + "lwt_ppx" {>= "5.9.1"} 19 + "ppx_deriving_yojson" {>= "3.9.1"} 20 + "odoc" {with-doc} 21 + ] 22 + build: [ 23 + ["dune" "subst"] {dev} 24 + [ 25 + "dune" 26 + "build" 27 + "-p" 28 + name 29 + "-j" 30 + jobs 31 + "@install" 32 + "@runtest" {with-test} 33 + "@doc" {with-doc} 34 + ] 35 + ] 36 + dev-repo: "git+https://github.com/futurGH/pegasus.git" 37 + x-maintenance-intent: ["(latest)"]
+330
hermes/README.md
··· 1 + # hermes 2 + 3 + is a type-safe XRPC client for atproto. 4 + 5 + Hermes provides three components: 6 + 7 + - **hermes** - Core library for making XRPC calls 8 + - **hermes-cli** - Code generator for atproto lexicons 9 + - **hermes_ppx** - PPX extension for ergonomic API calls 10 + 11 + - [Quick Start](#quick-start) 12 + - [Complete Example](#complete-example) 13 + - [Installation](#installation) 14 + - [hermes](#hermes-lib) 15 + - [Session Management](#session-management) 16 + - [Making XRPC Calls](#making-xrpc-calls) 17 + - [Error Handling](#error-handling) 18 + - [hermes_ppx](#hermes-ppx) 19 + - [Setup](#setup) 20 + - [Usage](#ppx-usage) 21 + - [hermes-cli](#hermes-cli) 22 + - [Usage](#usage) 23 + - [Options](#options) 24 + - [Generated Code Structure](#generated-code-structure) 25 + - [Type Mappings](#type-mappings) 26 + - [Bytes Encoding](#bytes-encoding) 27 + - [Union Types](#union-types) 28 + 29 + ## quick start 30 + 31 + ```ocaml 32 + open Hermes_lexicons (* generate lexicons using hermes-cli! *) 33 + open Lwt.Syntax 34 + 35 + let () = Lwt_main.run begin 36 + (* Create an unauthenticated client *) 37 + let client = Hermes.make_client ~service:"https://public.api.bsky.app" () in 38 + 39 + (* Make a query using the generated module *) 40 + let* profile = App_bsky_actor_getProfile.call ~actor:"bsky.app" client in 41 + print_endline profile.display_name; 42 + Lwt.return_unit 43 + end 44 + ``` 45 + 46 + ## complete example 47 + 48 + ```ocaml 49 + open Hermes_lexicons (* generate lexicons using hermes-cli! *) 50 + open Lwt.Syntax 51 + 52 + let main () = 53 + (* Set up credential manager with persistence *) 54 + let manager = Hermes.make_credential_manager ~service:"https://pegasus.example" () in 55 + 56 + Hermes.on_session_update manager (fun session -> 57 + let json = Hermes.session_to_yojson session in 58 + Yojson.Safe.to_file "session.json" json; 59 + Lwt.return_unit 60 + ); 61 + 62 + (* Log in or resume session *) 63 + let* client = 64 + if Sys.file_exists "session.json" then 65 + let json = Yojson.Safe.from_file "session.json" in 66 + match Hermes.session_of_yojson json with 67 + | Ok session -> Hermes.resume manager ~session () 68 + | Error _ -> failwith "Invalid session file" 69 + else 70 + Hermes.login manager 71 + ~identifier:"you.bsky.social" 72 + ~password:"your-app-password" 73 + () 74 + in 75 + 76 + (* Fetch your profile *) 77 + let session = Hermes.get_session client |> Option.get in 78 + let* profile = 79 + [%xrpc get "app.bsky.actor.getProfile"] 80 + ~actor:session.did 81 + client 82 + in 83 + Printf.printf "Logged in as %s\n" profile.handle; 84 + 85 + (* Create a post *) 86 + let* _ = 87 + [%xrpc post "com.atproto.repo.createRecord"] 88 + ~repo:session.did 89 + ~collection:"app.bsky.feed.post" 90 + ~record:(`Assoc [ 91 + ("$type", `String "app.bsky.feed.post"); 92 + ("text", `String "Hello from Hermes!"); 93 + ("createdAt", `String (Ptime.to_rfc3339 (Ptime_clock.now ()))); 94 + ]) 95 + client 96 + in 97 + print_endline "Post created!"; 98 + Lwt.return_unit 99 + 100 + let () = Lwt_main.run (main ()) 101 + ``` 102 + 103 + ## installation 104 + 105 + Add to your `dune-project`: 106 + 107 + ```lisp 108 + (depends 109 + hermes 110 + hermes_ppx) 111 + ``` 112 + 113 + <h2 id="hermes-lib">hermes</h2> 114 + 115 + ### session management 116 + 117 + ```ocaml 118 + (* Unauthenticated client for public endpoints *) 119 + let client = Hermes.make_client ~service:"https://public.api.bsky.app" () 120 + 121 + (* Authenticated client with credential manager *) 122 + let manager = Hermes.make_credential_manager ~service:"https://bsky.social" () 123 + 124 + let%lwt client = Hermes.login manager 125 + ~identifier:"user.bsky.social" 126 + ~password:"app-password-here" 127 + () 128 + 129 + (* Get current session for persistence *) 130 + let session = Hermes.get_session client 131 + 132 + (* Save session to JSON *) 133 + let json = Hermes.session_to_yojson session 134 + 135 + (* Resume from saved session *) 136 + let%lwt client = Hermes.resume manager ~session () 137 + 138 + (* Auto-save session to disk *) 139 + let () = Hermes.on_session_update manager (fun session -> 140 + save_to_disk (Hermes.session_to_yojson session); 141 + Lwt.return_unit 142 + ) 143 + 144 + (* Listen for session expiration *) 145 + let () = Hermes.on_session_expired manager (fun () -> 146 + print_endline "session expired, log in again!"; 147 + Lwt.return_unit 148 + ) 149 + ``` 150 + 151 + ### making XRPC calls 152 + 153 + ```ocaml 154 + (* GET request *) 155 + let%lwt result = Hermes.query client 156 + "app.bsky.actor.getProfile" 157 + (`Assoc [("actor", `String "bsky.app")]) 158 + decode_profile 159 + 160 + (* GET request returning raw bytes *) 161 + let%lwt (data, content_type) = Hermes.query_bytes client 162 + "com.atproto.sync.getBlob" 163 + (`Assoc [("did", `String did); ("cid", `String cid)]) 164 + 165 + (* POST request *) 166 + let%lwt result = Hermes.procedure client 167 + "com.atproto.repo.createRecord" 168 + (`Assoc []) (* query params *) 169 + (Some input_json) 170 + decode_response 171 + 172 + (* POST request with raw bytes as input *) 173 + let%lwt response = Hermes.procedure_bytes client 174 + "com.atproto.repo.importRepo" 175 + (`Assoc []) 176 + (Some car_data) 177 + ~content_type:"application/vnd.ipld.car" 178 + 179 + (* upload bytes, get a blob back *) 180 + let%lwt blob = Hermes.procedure_blob client 181 + "com.atproto.repo.uploadBlob" 182 + (`Assoc []) 183 + image_bytes 184 + ~content_type:"image/jpeg" 185 + decode_blob 186 + ``` 187 + 188 + ### error handling 189 + 190 + ```ocaml 191 + try%lwt 192 + let%lwt _ = some_xrpc_call client in 193 + Lwt.return_unit 194 + with Hermes.Xrpc_error { status; error; message } -> 195 + Printf.printf "Error %d: %s (%s)\n" 196 + status error (Option.value message ~default:"no message"); 197 + Lwt.return_unit 198 + ``` 199 + 200 + <h2 id="hermes-cli">hermes-cli (codegen)</h2> 201 + 202 + generates type-safe OCaml modules from atproto lexicon files. 203 + 204 + ### usage 205 + 206 + ```bash 207 + # Generate from lexicons directory 208 + hermes-cli generate --input ./lexicons --output ./lib/generated 209 + 210 + # With custom root module name 211 + hermes-cli generate -i ./lexicons -o ./lib/generated --module-name Bsky_api 212 + ``` 213 + 214 + ### options 215 + 216 + | Option | Short | Description | 217 + | --------------- | ----- | --------------------------------------- | 218 + | `--input` | `-i` | Directory containing lexicon JSON files | 219 + | `--output` | `-o` | Output directory for generated OCaml | 220 + | `--module-name` | `-m` | Root module name (default: Lexicons) | 221 + 222 + ### generated code structure 223 + 224 + For a lexicon like `app.bsky.actor.getProfile`, the generator creates: 225 + 226 + ``` 227 + lib/generated/ 228 + ├── dune 229 + ├── lexicons.ml # Re-exports all modules 230 + └── app/ 231 + └── bsky/ 232 + └── actor/ 233 + └── getProfile.ml 234 + ``` 235 + 236 + Each endpoint module contains: 237 + 238 + ```ocaml 239 + module GetProfile = struct 240 + type params = { 241 + actor: string; 242 + } [@@deriving yojson] 243 + 244 + type output = { 245 + did: string; 246 + handle: string; 247 + display_name: string option; 248 + (* ... *) 249 + } [@@deriving yojson] 250 + 251 + let nsid = "app.bsky.actor.getProfile" 252 + 253 + let call ~actor (client : Hermes.client) : output Lwt.t = 254 + let params = { actor } in 255 + Hermes.query client nsid (params_to_yojson params) output_of_yojson 256 + end 257 + ``` 258 + 259 + ### type mappings 260 + 261 + | Lexicon Type | OCaml Type | 262 + | ------------ | --------------- | 263 + | `boolean` | `bool` | 264 + | `integer` | `int` | 265 + | `string` | `string` | 266 + | `bytes` | `string` | 267 + | `blob` | `Hermes.blob` | 268 + | `cid-link` | `Cid.t` | 269 + | `array` | `list` | 270 + | `object` | record type | 271 + | `union` | variant type | 272 + | `unknown` | `Yojson.Safe.t` | 273 + 274 + ### bytes encoding 275 + 276 + Endpoints with non-JSON encoding are automatically detected and handled: 277 + 278 + - **Queries with bytes output** (e.g., `com.atproto.sync.getBlob` with `encoding: "*/*"`): 279 + - Output type is `string * string` (data, content_type) 280 + - Generated code uses `Hermes.query_bytes` 281 + 282 + - **Procedures with bytes input**: 283 + - Input is `?input:string` (optional raw bytes) 284 + - Generated code uses `Hermes.procedure_bytes` 285 + 286 + ### union types 287 + 288 + Unions generate variant types with a discriminator: 289 + 290 + ```ocaml 291 + type relationship_union = 292 + | Relationship of relationship 293 + | NotFoundActor of not_found_actor 294 + | Unknown of Yojson.Safe.t (* for open unions *) 295 + ``` 296 + 297 + <h2 id="hermes-ppx">hermes_ppx (PPX extension)</h2> 298 + 299 + Transforms `[%xrpc ...]` into generated module calls. 300 + 301 + ### setup 302 + 303 + ```lisp 304 + (library 305 + (name my_app) 306 + (libraries hermes hermes_ppx lwt) 307 + (preprocess (pps hermes_ppx))) 308 + ``` 309 + 310 + <h3 id="ppx-usage">usage</h3> 311 + 312 + ```ocaml 313 + let get_followers ~actor ~limit client = 314 + [%xrpc get "app.bsky.graph.getFollowers"] 315 + ~actor 316 + ?limit 317 + client 318 + 319 + let create_post ~text client = 320 + let session = Hermes.get_session client |> Option.get in 321 + [%xrpc post "com.atproto.repo.createRecord"] 322 + ~repo:session.did 323 + ~collection:"app.bsky.feed.post" 324 + ~record:(`Assoc [ 325 + ("$type", `String "app.bsky.feed.post"); 326 + ("text", `String text); 327 + ("createdAt", `String (Ptime.to_rfc3339 (Ptime_clock.now ()))); 328 + ]) 329 + client 330 + ```
+375
hermes/lib/client.ml
··· 1 + open Lwt.Syntax 2 + 3 + type t = 4 + { service: Uri.t 5 + ; mutable headers: (string * string) list 6 + ; mutable session: Types.session option 7 + ; on_request: (t -> unit Lwt.t) option 8 + (* called before each request for token refresh *) } 9 + 10 + module type S = sig 11 + val make : service:string -> unit -> t 12 + 13 + val make_with_interceptor : 14 + service:string -> on_request:(t -> unit Lwt.t) -> unit -> t 15 + 16 + val set_session : t -> Types.session -> unit 17 + 18 + val clear_session : t -> unit 19 + 20 + val get_session : t -> Types.session option 21 + 22 + val get_service : t -> Uri.t 23 + 24 + val query : 25 + t 26 + -> string 27 + -> Yojson.Safe.t 28 + -> (Yojson.Safe.t -> ('a, string) result) 29 + -> 'a Lwt.t 30 + 31 + val procedure : 32 + t 33 + -> string 34 + -> Yojson.Safe.t 35 + -> Yojson.Safe.t option 36 + -> (Yojson.Safe.t -> ('a, string) result) 37 + -> 'a Lwt.t 38 + 39 + val query_bytes : t -> string -> Yojson.Safe.t -> (string * string) Lwt.t 40 + 41 + val procedure_bytes : 42 + t 43 + -> string 44 + -> Yojson.Safe.t 45 + -> string option 46 + -> content_type:string 47 + -> (string * string) option Lwt.t 48 + 49 + val procedure_blob : 50 + t 51 + -> string 52 + -> Yojson.Safe.t 53 + -> bytes 54 + -> content_type:string 55 + -> (Yojson.Safe.t -> ('a, string) result) 56 + -> 'a Lwt.t 57 + end 58 + 59 + module Make (Http : Http_backend.S) : S = struct 60 + let make ~service () = 61 + let service = Uri.of_string service in 62 + {service; headers= []; session= None; on_request= None} 63 + 64 + let make_with_interceptor ~service ~on_request () = 65 + let service = Uri.of_string service in 66 + {service; headers= []; session= None; on_request= Some on_request} 67 + 68 + let set_session t session = 69 + t.session <- Some session ; 70 + t.headers <- 71 + List.filter (fun (k, _) -> k <> "Authorization") t.headers 72 + @ [("Authorization", "Bearer " ^ session.Types.access_jwt)] 73 + 74 + let clear_session t = 75 + t.session <- None ; 76 + t.headers <- List.filter (fun (k, _) -> k <> "Authorization") t.headers 77 + 78 + let get_session t = t.session 79 + 80 + let get_service t = t.service 81 + 82 + (* build query string from json params *) 83 + let params_to_query (params : Yojson.Safe.t) : (string * string list) list = 84 + match params with 85 + | `Assoc pairs -> 86 + List.filter_map 87 + (fun (k, v) -> 88 + match v with 89 + | `Null -> 90 + None 91 + | `Bool b -> 92 + Some (k, [string_of_bool b]) 93 + | `Int i -> 94 + Some (k, [string_of_int i]) 95 + | `Float f -> 96 + Some (k, [string_of_float f]) 97 + | `String s -> 98 + Some (k, [s]) 99 + | `List items -> 100 + let strs = 101 + List.filter_map 102 + (function 103 + | `String s -> 104 + Some s 105 + | `Int i -> 106 + Some (string_of_int i) 107 + | `Bool b -> 108 + Some (string_of_bool b) 109 + | _ -> 110 + None ) 111 + items 112 + in 113 + if strs = [] then None else Some (k, strs) 114 + | _ -> 115 + None ) 116 + pairs 117 + | _ -> 118 + [] 119 + 120 + let make_headers ?(extra = []) ?(accept = "application/json") t = 121 + Cohttp.Header.of_list 122 + ([("User-Agent", "hermes/1.0"); ("Accept", accept)] @ t.headers @ extra) 123 + 124 + let query (t : t) (nsid : string) (params : Yojson.Safe.t) 125 + (of_yojson : Yojson.Safe.t -> ('a, string) result) : 'a Lwt.t = 126 + (* call interceptor if present for token refresh *) 127 + let* () = 128 + match t.on_request with Some f -> f t | None -> Lwt.return_unit 129 + in 130 + let query = params_to_query params in 131 + let uri = 132 + Uri.with_path t.service ("/xrpc/" ^ nsid) 133 + |> fun u -> Uri.with_query u query 134 + in 135 + let headers = make_headers t in 136 + let* resp, body = 137 + Lwt.catch 138 + (fun () -> Lwt_unix.with_timeout 30.0 (fun () -> Http.get ~headers uri)) 139 + (fun exn -> 140 + Types.raise_xrpc_error_raw ~status:0 ~error:"NetworkError" 141 + ~message:(Printexc.to_string exn) () ) 142 + in 143 + let status = Cohttp.Response.status resp |> Cohttp.Code.code_of_status in 144 + let* body_str = Cohttp_lwt.Body.to_string body in 145 + if status >= 200 && status < 300 then 146 + if String.length body_str = 0 then 147 + (* empty response, try parsing empty object *) 148 + match of_yojson (`Assoc []) with 149 + | Ok v -> 150 + Lwt.return v 151 + | Error e -> 152 + Types.raise_xrpc_error_raw ~status ~error:"ParseError" ~message:e () 153 + else 154 + let json = Yojson.Safe.from_string body_str in 155 + match of_yojson json with 156 + | Ok v -> 157 + Lwt.return v 158 + | Error e -> 159 + Types.raise_xrpc_error_raw ~status ~error:"ParseError" ~message:e () 160 + else 161 + let payload = 162 + try 163 + let json = Yojson.Safe.from_string body_str in 164 + match Types.xrpc_error_payload_of_yojson json with 165 + | Ok p -> 166 + p 167 + | Error _ -> 168 + {error= "UnknownError"; message= Some body_str} 169 + with _ -> {error= "UnknownError"; message= Some body_str} 170 + in 171 + Types.raise_xrpc_error ~status payload 172 + 173 + let procedure (t : t) (nsid : string) (params : Yojson.Safe.t) 174 + (input : Yojson.Safe.t option) 175 + (of_yojson : Yojson.Safe.t -> ('a, string) result) : 'a Lwt.t = 176 + (* call interceptor if present for token refresh *) 177 + let* () = 178 + match t.on_request with Some f -> f t | None -> Lwt.return_unit 179 + in 180 + let query = params_to_query params in 181 + let uri = 182 + Uri.with_path t.service ("/xrpc/" ^ nsid) 183 + |> fun u -> Uri.with_query u query 184 + in 185 + let body, content_type = 186 + match input with 187 + | Some json -> 188 + ( Cohttp_lwt.Body.of_string (Yojson.Safe.to_string json) 189 + , "application/json" ) 190 + | None -> 191 + (Cohttp_lwt.Body.empty, "application/json") 192 + in 193 + let headers = make_headers ~extra:[("Content-Type", content_type)] t in 194 + let* resp, resp_body = 195 + Lwt.catch 196 + (fun () -> 197 + Lwt_unix.with_timeout 30.0 (fun () -> Http.post ~headers ~body uri) ) 198 + (fun exn -> 199 + Types.raise_xrpc_error_raw ~status:0 ~error:"NetworkError" 200 + ~message:(Printexc.to_string exn) () ) 201 + in 202 + let status = Cohttp.Response.status resp |> Cohttp.Code.code_of_status in 203 + let* body_str = Cohttp_lwt.Body.to_string resp_body in 204 + if status >= 200 && status < 300 then 205 + if String.length body_str = 0 then 206 + match of_yojson (`Assoc []) with 207 + | Ok v -> 208 + Lwt.return v 209 + | Error e -> 210 + Types.raise_xrpc_error_raw ~status ~error:"ParseError" ~message:e () 211 + else 212 + let json = Yojson.Safe.from_string body_str in 213 + match of_yojson json with 214 + | Ok v -> 215 + Lwt.return v 216 + | Error e -> 217 + Types.raise_xrpc_error_raw ~status ~error:"ParseError" ~message:e () 218 + else 219 + let payload = 220 + try 221 + let json = Yojson.Safe.from_string body_str in 222 + match Types.xrpc_error_payload_of_yojson json with 223 + | Ok p -> 224 + p 225 + | Error _ -> 226 + {error= "UnknownError"; message= Some body_str} 227 + with _ -> {error= "UnknownError"; message= Some body_str} 228 + in 229 + Types.raise_xrpc_error ~status payload 230 + 231 + let query_bytes (t : t) (nsid : string) (params : Yojson.Safe.t) : 232 + (string * string) Lwt.t = 233 + (* call interceptor if present for token refresh *) 234 + let* () = 235 + match t.on_request with Some f -> f t | None -> Lwt.return_unit 236 + in 237 + let query = params_to_query params in 238 + let uri = 239 + Uri.with_path t.service ("/xrpc/" ^ nsid) 240 + |> fun u -> Uri.with_query u query 241 + in 242 + let headers = make_headers ~accept:"*/*" t in 243 + let* resp, body = 244 + Lwt.catch 245 + (fun () -> Lwt_unix.with_timeout 120.0 (fun () -> Http.get ~headers uri)) 246 + (fun exn -> 247 + Types.raise_xrpc_error_raw ~status:0 ~error:"NetworkError" 248 + ~message:(Printexc.to_string exn) () ) 249 + in 250 + let status = Cohttp.Response.status resp |> Cohttp.Code.code_of_status in 251 + let* body_str = Cohttp_lwt.Body.to_string body in 252 + if status >= 200 && status < 300 then 253 + let content_type = 254 + Cohttp.Response.headers resp 255 + |> fun h -> 256 + Cohttp.Header.get h "content-type" 257 + |> Option.value ~default:"application/octet-stream" 258 + in 259 + Lwt.return (body_str, content_type) 260 + else 261 + let payload = 262 + try 263 + let json = Yojson.Safe.from_string body_str in 264 + match Types.xrpc_error_payload_of_yojson json with 265 + | Ok p -> 266 + p 267 + | Error _ -> 268 + {error= "UnknownError"; message= Some body_str} 269 + with _ -> {error= "UnknownError"; message= Some body_str} 270 + in 271 + Types.raise_xrpc_error ~status payload 272 + 273 + (* execute procedure with raw bytes input, returns raw bytes or none if no output *) 274 + let procedure_bytes (t : t) (nsid : string) (params : Yojson.Safe.t) 275 + (input : string option) ~(content_type : string) : 276 + (string * string) option Lwt.t = 277 + (* call interceptor if present for token refresh *) 278 + let* () = 279 + match t.on_request with Some f -> f t | None -> Lwt.return_unit 280 + in 281 + let query = params_to_query params in 282 + let uri = 283 + Uri.with_path t.service ("/xrpc/" ^ nsid) 284 + |> fun u -> Uri.with_query u query 285 + in 286 + let body = 287 + match input with 288 + | Some data -> 289 + Cohttp_lwt.Body.of_string data 290 + | None -> 291 + Cohttp_lwt.Body.empty 292 + in 293 + let headers = 294 + make_headers ~extra:[("Content-Type", content_type)] ~accept:"*/*" t 295 + in 296 + let* resp, resp_body = 297 + Lwt.catch 298 + (fun () -> 299 + Lwt_unix.with_timeout 120.0 (fun () -> Http.post ~headers ~body uri) ) 300 + (fun exn -> 301 + Types.raise_xrpc_error_raw ~status:0 ~error:"NetworkError" 302 + ~message:(Printexc.to_string exn) () ) 303 + in 304 + let status = Cohttp.Response.status resp |> Cohttp.Code.code_of_status in 305 + let* body_str = Cohttp_lwt.Body.to_string resp_body in 306 + if status >= 200 && status < 300 then 307 + if String.length body_str = 0 then Lwt.return None 308 + else 309 + let resp_content_type = 310 + Cohttp.Response.headers resp 311 + |> fun h -> 312 + Cohttp.Header.get h "content-type" 313 + |> Option.value ~default:"application/octet-stream" 314 + in 315 + Lwt.return (Some (body_str, resp_content_type)) 316 + else 317 + let payload = 318 + try 319 + let json = Yojson.Safe.from_string body_str in 320 + match Types.xrpc_error_payload_of_yojson json with 321 + | Ok p -> 322 + p 323 + | Error _ -> 324 + {error= "UnknownError"; message= Some body_str} 325 + with _ -> {error= "UnknownError"; message= Some body_str} 326 + in 327 + Types.raise_xrpc_error ~status payload 328 + 329 + let procedure_blob (t : t) (nsid : string) (params : Yojson.Safe.t) 330 + (blob_data : bytes) ~(content_type : string) 331 + (of_yojson : Yojson.Safe.t -> ('a, string) result) : 'a Lwt.t = 332 + (* call interceptor if present for token refresh *) 333 + let* () = 334 + match t.on_request with Some f -> f t | None -> Lwt.return_unit 335 + in 336 + let query = params_to_query params in 337 + let uri = 338 + Uri.with_path t.service ("/xrpc/" ^ nsid) 339 + |> fun u -> Uri.with_query u query 340 + in 341 + let body = Cohttp_lwt.Body.of_string (Bytes.to_string blob_data) in 342 + let headers = make_headers ~extra:[("Content-Type", content_type)] t in 343 + let* resp, resp_body = 344 + Lwt.catch 345 + (fun () -> 346 + Lwt_unix.with_timeout 120.0 (fun () -> Http.post ~headers ~body uri) ) 347 + (fun exn -> 348 + Types.raise_xrpc_error_raw ~status:0 ~error:"NetworkError" 349 + ~message:(Printexc.to_string exn) () ) 350 + in 351 + let status = Cohttp.Response.status resp |> Cohttp.Code.code_of_status in 352 + let* body_str = Cohttp_lwt.Body.to_string resp_body in 353 + if status >= 200 && status < 300 then 354 + let json = Yojson.Safe.from_string body_str in 355 + match of_yojson json with 356 + | Ok v -> 357 + Lwt.return v 358 + | Error e -> 359 + Types.raise_xrpc_error_raw ~status ~error:"ParseError" ~message:e () 360 + else 361 + let payload = 362 + try 363 + let json = Yojson.Safe.from_string body_str in 364 + match Types.xrpc_error_payload_of_yojson json with 365 + | Ok p -> 366 + p 367 + | Error _ -> 368 + {error= "UnknownError"; message= Some body_str} 369 + with _ -> {error= "UnknownError"; message= Some body_str} 370 + in 371 + Types.raise_xrpc_error ~status payload 372 + end 373 + 374 + (* default client using real http backend *) 375 + include Make (Http_backend.Default)
+180
hermes/lib/credential_manager.ml
··· 1 + open Lwt.Syntax 2 + 3 + type t = 4 + { service: Uri.t 5 + ; mutable session: Types.session option 6 + ; mutable on_session_update: (Types.session -> unit Lwt.t) option 7 + ; mutable on_session_expired: (unit -> unit Lwt.t) option 8 + ; refresh_mutex: Lwt_mutex.t 9 + ; mutable refresh_promise: unit Lwt.t option } 10 + 11 + module type S = sig 12 + val make : service:string -> unit -> t 13 + 14 + val on_session_update : t -> (Types.session -> unit Lwt.t) -> unit 15 + 16 + val on_session_expired : t -> (unit -> unit Lwt.t) -> unit 17 + 18 + val get_session : t -> Types.session option 19 + 20 + val login : 21 + t 22 + -> identifier:string 23 + -> password:string 24 + -> ?auth_factor_token:string 25 + -> unit 26 + -> Client.t Lwt.t 27 + 28 + val resume : t -> session:Types.session -> unit -> Client.t Lwt.t 29 + 30 + val logout : t -> unit Lwt.t 31 + end 32 + 33 + module Make (C : Client.S) : S = struct 34 + let make ~service () = 35 + { service= Uri.of_string service 36 + ; session= None 37 + ; on_session_update= None 38 + ; on_session_expired= None 39 + ; refresh_mutex= Lwt_mutex.create () 40 + ; refresh_promise= None } 41 + 42 + let on_session_update t callback = t.on_session_update <- Some callback 43 + 44 + let on_session_expired t callback = t.on_session_expired <- Some callback 45 + 46 + let get_session t = t.session 47 + 48 + (* update session and notify *) 49 + let update_session t session = 50 + t.session <- Some session ; 51 + match t.on_session_update with 52 + | Some callback -> 53 + callback session 54 + | None -> 55 + Lwt.return_unit 56 + 57 + (* clear session and notify *) 58 + let clear_session t = 59 + t.session <- None ; 60 + match t.on_session_expired with 61 + | Some callback -> 62 + callback () 63 + | None -> 64 + Lwt.return_unit 65 + 66 + (* create raw client for auth operations *) 67 + let make_raw_client t = C.make ~service:(Uri.to_string t.service) () 68 + 69 + let rec login t ~identifier ~password ?auth_factor_token () = 70 + let client = make_raw_client t in 71 + let input = 72 + Types.login_request_to_yojson 73 + {Types.identifier; password; auth_factor_token} 74 + in 75 + let* session = 76 + C.procedure client "com.atproto.server.createSession" (`Assoc []) 77 + (Some input) Types.session_of_yojson 78 + in 79 + let* () = update_session t session in 80 + (* create client with request interceptor for auto-refresh *) 81 + let authed_client = 82 + C.make_with_interceptor ~service:(Uri.to_string t.service) 83 + ~on_request:(fun c -> check_and_refresh t c) 84 + () 85 + in 86 + C.set_session authed_client session ; 87 + Lwt.return authed_client 88 + 89 + and resume t ~session () = 90 + let* () = update_session t session in 91 + let authed_client = 92 + C.make_with_interceptor ~service:(Uri.to_string t.service) 93 + ~on_request:(fun c -> check_and_refresh t c) 94 + () 95 + in 96 + C.set_session authed_client session ; 97 + Lwt.return authed_client 98 + 99 + (* refresh the session *) 100 + and refresh_session t = 101 + match t.session with 102 + | None -> 103 + Types.raise_xrpc_error_raw ~status:401 ~error:"AuthRequired" 104 + ~message:"No session to refresh" () 105 + | Some session -> 106 + let client = make_raw_client t in 107 + (* use refresh token for auth *) 108 + C.set_session client {session with access_jwt= session.refresh_jwt} ; 109 + Lwt.catch 110 + (fun () -> 111 + let* new_session = 112 + C.procedure client "com.atproto.server.refreshSession" (`Assoc []) 113 + None Types.session_of_yojson 114 + in 115 + let* () = update_session t new_session in 116 + Lwt.return (Some new_session) ) 117 + (fun exn -> 118 + match exn with 119 + | Types.Xrpc_error {error= "ExpiredToken"; _} 120 + | Types.Xrpc_error {error= "InvalidToken"; _} -> 121 + let* () = clear_session t in 122 + Lwt.return None 123 + | _ -> 124 + Lwt.reraise exn ) 125 + 126 + (* check token expiry and refresh if needed *) 127 + and check_and_refresh t client = 128 + match t.session with 129 + | None -> 130 + Lwt.return_unit 131 + | Some session -> 132 + if Jwt.is_expired ~buffer_seconds:300 session.access_jwt then 133 + (* token expired or about to expire, need to refresh *) 134 + Lwt_mutex.with_lock t.refresh_mutex (fun () -> 135 + (* check again in case another request already refreshed *) 136 + match t.session with 137 + | None -> 138 + Lwt.return_unit 139 + | Some current_session -> 140 + if 141 + Jwt.is_expired ~buffer_seconds:300 142 + current_session.access_jwt 143 + then ( 144 + let* new_session = refresh_session t in 145 + match new_session with 146 + | Some s -> 147 + C.set_session client s ; Lwt.return_unit 148 + | None -> 149 + C.clear_session client ; 150 + Types.raise_xrpc_error_raw ~status:401 151 + ~error:"SessionExpired" 152 + ~message:"Failed to refresh session" () ) 153 + else ( 154 + (* another request already refreshed, just update our client *) 155 + C.set_session client current_session ; 156 + Lwt.return_unit ) ) 157 + else Lwt.return_unit 158 + 159 + let logout t = 160 + match t.session with 161 + | None -> 162 + Lwt.return_unit 163 + | Some session -> 164 + let client = make_raw_client t in 165 + C.set_session client session ; 166 + Lwt.catch 167 + (fun () -> 168 + let* (_ : Yojson.Safe.t) = 169 + C.procedure client "com.atproto.server.deleteSession" (`Assoc []) 170 + None (fun j -> Ok j ) 171 + in 172 + let* () = clear_session t in 173 + Lwt.return_unit ) 174 + (fun _ -> 175 + (* even if server fails, clear local session *) 176 + let* () = clear_session t in 177 + Lwt.return_unit ) 178 + end 179 + 180 + include Make (Client)
+5
hermes/lib/dune
··· 1 + (library 2 + (name hermes) 3 + (libraries cohttp-lwt-unix lwt lwt.unix yojson uri base64 ipld mist) 4 + (preprocess 5 + (pps lwt_ppx ppx_deriving_yojson)))
+62
hermes/lib/hermes.ml
··· 1 + type blob = Types.blob = {ref_: Cid.t; mime_type: string; size: int64} 2 + 3 + exception Xrpc_error = Types.Xrpc_error 4 + 5 + type session = Types.session = 6 + { access_jwt: string 7 + ; refresh_jwt: string 8 + ; did: string 9 + ; handle: string 10 + ; pds_uri: string option 11 + ; email: string option 12 + ; email_confirmed: bool option 13 + ; email_auth_factor: bool option 14 + ; active: bool option 15 + ; status: string option } 16 + 17 + type client = Client.t 18 + 19 + type credential_manager = Credential_manager.t 20 + 21 + let make_client = Client.make 22 + 23 + let make_credential_manager = Credential_manager.make 24 + 25 + let login = Credential_manager.login 26 + 27 + let resume = Credential_manager.resume 28 + 29 + let logout = Credential_manager.logout 30 + 31 + let get_manager_session = Credential_manager.get_session 32 + 33 + let on_session_update = Credential_manager.on_session_update 34 + 35 + let on_session_expired = Credential_manager.on_session_expired 36 + 37 + let get_session = Client.get_session 38 + 39 + let get_service = Client.get_service 40 + 41 + let query = Client.query 42 + 43 + let procedure = Client.procedure 44 + 45 + let procedure_blob = Client.procedure_blob 46 + 47 + let query_bytes = Client.query_bytes 48 + 49 + let procedure_bytes = Client.procedure_bytes 50 + 51 + let session_to_yojson = Types.session_to_yojson 52 + 53 + let session_of_yojson = Types.session_of_yojson 54 + 55 + let blob_to_yojson = Types.blob_to_yojson 56 + 57 + let blob_of_yojson = Types.blob_of_yojson 58 + 59 + module Jwt = Jwt 60 + module Http_backend = Http_backend 61 + module Client = Client 62 + module Credential_manager = Credential_manager
+201
hermes/lib/hermes.mli
··· 1 + type blob = {ref_: Cid.t; mime_type: string; size: int64} 2 + 3 + exception Xrpc_error of {status: int; error: string; message: string option} 4 + 5 + type session = 6 + { access_jwt: string 7 + ; refresh_jwt: string 8 + ; did: string 9 + ; handle: string 10 + ; pds_uri: string option 11 + ; email: string option 12 + ; email_confirmed: bool option 13 + ; email_auth_factor: bool option 14 + ; active: bool option 15 + ; status: string option } 16 + 17 + type client 18 + 19 + type credential_manager 20 + 21 + val make_client : service:string -> unit -> client 22 + 23 + val make_credential_manager : service:string -> unit -> credential_manager 24 + 25 + val login : 26 + credential_manager 27 + -> identifier:string 28 + -> password:string 29 + -> ?auth_factor_token:string 30 + -> unit 31 + -> client Lwt.t 32 + 33 + val resume : credential_manager -> session:session -> unit -> client Lwt.t 34 + 35 + val logout : credential_manager -> unit Lwt.t 36 + 37 + val get_manager_session : credential_manager -> session option 38 + 39 + val on_session_update : credential_manager -> (session -> unit Lwt.t) -> unit 40 + 41 + val on_session_expired : credential_manager -> (unit -> unit Lwt.t) -> unit 42 + 43 + val get_session : client -> session option 44 + 45 + val get_service : client -> Uri.t 46 + 47 + val query : 48 + client 49 + -> string 50 + -> Yojson.Safe.t 51 + -> (Yojson.Safe.t -> ('a, string) result) 52 + -> 'a Lwt.t 53 + 54 + val procedure : 55 + client 56 + -> string 57 + -> Yojson.Safe.t 58 + -> Yojson.Safe.t option 59 + -> (Yojson.Safe.t -> ('a, string) result) 60 + -> 'a Lwt.t 61 + 62 + val procedure_blob : 63 + client 64 + -> string 65 + -> Yojson.Safe.t 66 + -> bytes 67 + -> content_type:string 68 + -> (Yojson.Safe.t -> ('a, string) result) 69 + -> 'a Lwt.t 70 + 71 + val query_bytes : client -> string -> Yojson.Safe.t -> (string * string) Lwt.t 72 + 73 + val procedure_bytes : 74 + client 75 + -> string 76 + -> Yojson.Safe.t 77 + -> string option 78 + -> content_type:string 79 + -> (string * string) option Lwt.t 80 + 81 + val session_to_yojson : session -> Yojson.Safe.t 82 + 83 + val session_of_yojson : Yojson.Safe.t -> (session, string) result 84 + 85 + val blob_to_yojson : blob -> Yojson.Safe.t 86 + 87 + val blob_of_yojson : Yojson.Safe.t -> (blob, string) result 88 + 89 + module Jwt : sig 90 + type payload = 91 + { exp: int option 92 + ; iat: int option 93 + ; sub: string option 94 + ; aud: string option 95 + ; iss: string option } 96 + 97 + val decode_payload : string -> (payload, string) result 98 + 99 + val is_expired : ?buffer_seconds:int -> string -> bool 100 + 101 + val get_expiration : string -> int option 102 + end 103 + 104 + module Http_backend : sig 105 + type response = Cohttp.Response.t * Cohttp_lwt.Body.t 106 + 107 + module type S = sig 108 + val get : headers:Cohttp.Header.t -> Uri.t -> response Lwt.t 109 + 110 + val post : 111 + headers:Cohttp.Header.t 112 + -> body:Cohttp_lwt.Body.t 113 + -> Uri.t 114 + -> response Lwt.t 115 + end 116 + 117 + module Default : S 118 + end 119 + 120 + module Client : sig 121 + type t = client 122 + 123 + module type S = sig 124 + val make : service:string -> unit -> t 125 + 126 + val make_with_interceptor : 127 + service:string -> on_request:(t -> unit Lwt.t) -> unit -> t 128 + 129 + val set_session : t -> session -> unit 130 + 131 + val clear_session : t -> unit 132 + 133 + val get_session : t -> session option 134 + 135 + val get_service : t -> Uri.t 136 + 137 + val query : 138 + t 139 + -> string 140 + -> Yojson.Safe.t 141 + -> (Yojson.Safe.t -> ('a, string) result) 142 + -> 'a Lwt.t 143 + 144 + val procedure : 145 + t 146 + -> string 147 + -> Yojson.Safe.t 148 + -> Yojson.Safe.t option 149 + -> (Yojson.Safe.t -> ('a, string) result) 150 + -> 'a Lwt.t 151 + 152 + val query_bytes : t -> string -> Yojson.Safe.t -> (string * string) Lwt.t 153 + 154 + val procedure_bytes : 155 + t 156 + -> string 157 + -> Yojson.Safe.t 158 + -> string option 159 + -> content_type:string 160 + -> (string * string) option Lwt.t 161 + 162 + val procedure_blob : 163 + t 164 + -> string 165 + -> Yojson.Safe.t 166 + -> bytes 167 + -> content_type:string 168 + -> (Yojson.Safe.t -> ('a, string) result) 169 + -> 'a Lwt.t 170 + end 171 + 172 + module Make (_ : Http_backend.S) : S 173 + end 174 + 175 + module Credential_manager : sig 176 + type t = credential_manager 177 + 178 + module type S = sig 179 + val make : service:string -> unit -> t 180 + 181 + val on_session_update : t -> (session -> unit Lwt.t) -> unit 182 + 183 + val on_session_expired : t -> (unit -> unit Lwt.t) -> unit 184 + 185 + val get_session : t -> session option 186 + 187 + val login : 188 + t 189 + -> identifier:string 190 + -> password:string 191 + -> ?auth_factor_token:string 192 + -> unit 193 + -> Client.t Lwt.t 194 + 195 + val resume : t -> session:session -> unit -> Client.t Lwt.t 196 + 197 + val logout : t -> unit Lwt.t 198 + end 199 + 200 + module Make (_ : Client.S) : S 201 + end
+17
hermes/lib/http_backend.ml
··· 1 + (* abstract http backend for dependency injection *) 2 + 3 + type response = Cohttp.Response.t * Cohttp_lwt.Body.t 4 + 5 + module type S = sig 6 + val get : headers:Cohttp.Header.t -> Uri.t -> response Lwt.t 7 + 8 + val post : 9 + headers:Cohttp.Header.t -> body:Cohttp_lwt.Body.t -> Uri.t -> response Lwt.t 10 + end 11 + 12 + (* default implementation using cohttp-lwt-unix *) 13 + module Default : S = struct 14 + let get ~headers uri = Cohttp_lwt_unix.Client.get ~headers uri 15 + 16 + let post ~headers ~body uri = Cohttp_lwt_unix.Client.post ~headers ~body uri 17 + end
+45
hermes/lib/jwt.ml
··· 1 + type payload = 2 + { exp: int option [@default None] 3 + ; iat: int option [@default None] 4 + ; sub: string option [@default None] 5 + ; aud: string option [@default None] 6 + ; iss: string option [@default None] } 7 + [@@deriving yojson {strict= false}] 8 + 9 + (* decode jwt payload without signature verification *) 10 + let decode_payload (jwt : string) : (payload, string) result = 11 + try 12 + match String.split_on_char '.' jwt with 13 + | [_header; payload_str; _signature] -> ( 14 + match 15 + Base64.decode ~pad:false ~alphabet:Base64.uri_safe_alphabet payload_str 16 + with 17 + | Ok decoded -> ( 18 + let json = Yojson.Safe.from_string decoded in 19 + match payload_of_yojson json with Ok p -> Ok p | Error e -> Error e ) 20 + | Error (`Msg e) -> 21 + Error ("invalid base64 in JWT: " ^ e) ) 22 + | _ -> 23 + Error "invalid JWT format" 24 + with 25 + | Yojson.Json_error e -> 26 + Error ("invalid JSON in JWT payload: " ^ e) 27 + | e -> 28 + Error (Printexc.to_string e) 29 + 30 + (* check if jwt is expired with buffer in seconds *) 31 + let is_expired ?(buffer_seconds = 60) (jwt : string) : bool = 32 + match decode_payload jwt with 33 + | Ok {exp= Some exp; _} -> 34 + let now = int_of_float (Unix.time ()) in 35 + exp - buffer_seconds <= now 36 + | Ok {exp= None; _} -> 37 + (* no expiration, assume not expired *) 38 + false 39 + | Error _ -> 40 + (* can't decode, assume expired to be safe *) 41 + true 42 + 43 + (* get expiration time from jwt *) 44 + let get_expiration (jwt : string) : int option = 45 + match decode_payload jwt with Ok {exp; _} -> exp | Error _ -> None
+37
hermes/lib/types.ml
··· 1 + (* core types for xrpc client *) 2 + 3 + type blob = 4 + {ref_: Cid.t [@key "$link"]; mime_type: string [@key "mimeType"]; size: int64} 5 + [@@deriving yojson] 6 + 7 + type xrpc_error_payload = {error: string; message: string option [@default None]} 8 + [@@deriving yojson {strict= false}] 9 + 10 + exception Xrpc_error of {status: int; error: string; message: string option} 11 + 12 + type session = 13 + { access_jwt: string [@key "accessJwt"] 14 + ; refresh_jwt: string [@key "refreshJwt"] 15 + ; did: string 16 + ; handle: string 17 + ; pds_uri: string option [@key "pdsUri"] [@default None] 18 + ; email: string option [@default None] 19 + ; email_confirmed: bool option [@key "emailConfirmed"] [@default None] 20 + ; email_auth_factor: bool option [@key "emailAuthFactor"] [@default None] 21 + ; active: bool option [@default None] 22 + ; status: string option [@default None] } 23 + [@@deriving yojson {strict= false}] 24 + 25 + type login_request = 26 + { identifier: string 27 + ; password: string 28 + ; auth_factor_token: string option [@key "authFactorToken"] [@default None] } 29 + [@@deriving yojson] 30 + 31 + type refresh_request = unit [@@deriving yojson] 32 + 33 + let raise_xrpc_error ~status (payload : xrpc_error_payload) = 34 + raise (Xrpc_error {status; error= payload.error; message= payload.message}) 35 + 36 + let raise_xrpc_error_raw ~status ~error ?message () = 37 + raise (Xrpc_error {status; error; message})
+28
hermes/test/dune
··· 1 + (library 2 + (name test_support) 3 + (modules mock_http test_utils) 4 + (libraries hermes base64 http cohttp cohttp-lwt lwt yojson uri)) 5 + 6 + (test 7 + (name test_types) 8 + (modules test_types) 9 + (libraries alcotest hermes ipld yojson)) 10 + 11 + (test 12 + (name test_jwt) 13 + (modules test_jwt) 14 + (libraries alcotest hermes base64)) 15 + 16 + (test 17 + (name test_client) 18 + (modules test_client) 19 + (libraries alcotest hermes test_support lwt lwt.unix cohttp yojson) 20 + (preprocess 21 + (pps lwt_ppx))) 22 + 23 + (test 24 + (name test_credential_manager) 25 + (modules test_credential_manager) 26 + (libraries alcotest hermes test_support lwt lwt.unix cohttp yojson) 27 + (preprocess 28 + (pps lwt_ppx)))
+129
hermes/test/mock_http.ml
··· 1 + (** mock HTTP backend for testing *) 2 + 3 + open Lwt.Syntax 4 + 5 + type request = 6 + { meth: [`GET | `POST] 7 + ; uri: Uri.t 8 + ; headers: Cohttp.Header.t 9 + ; body: string option } 10 + 11 + type response = 12 + { status: Cohttp.Code.status_code 13 + ; headers: (string * string) list 14 + ; body: string } 15 + 16 + type handler = request -> response Lwt.t 17 + 18 + let make_cohttp_response (r : response) : Hermes.Http_backend.response = 19 + let resp = 20 + Cohttp.Response.make ~status:r.status 21 + ~headers:(Cohttp.Header.of_list r.headers) 22 + () 23 + in 24 + let body = Cohttp_lwt.Body.of_string r.body in 25 + (resp, body) 26 + 27 + (** create a mock HTTP backend with a handler *) 28 + module Make (Config : sig 29 + val handler : handler ref 30 + end) : Hermes.Http_backend.S = struct 31 + let get ~headers uri = 32 + let req = {meth= `GET; uri; headers; body= None} in 33 + let* r = !Config.handler req in 34 + Lwt.return (make_cohttp_response r) 35 + 36 + let post ~headers ~body uri = 37 + let* body_str = Cohttp_lwt.Body.to_string body in 38 + let req = {meth= `POST; uri; headers; body= Some body_str} in 39 + let* r = !Config.handler req in 40 + Lwt.return (make_cohttp_response r) 41 + end 42 + 43 + (** simple response builders *) 44 + 45 + let json_response ?(status = `OK) ?(headers = []) json = 46 + { status 47 + ; headers= ("content-type", "application/json") :: headers 48 + ; body= Yojson.Safe.to_string json } 49 + 50 + let bytes_response ?(status = `OK) ~content_type body = 51 + {status; headers= [("content-type", content_type)]; body} 52 + 53 + let error_response ~status ~error ?message () = 54 + let msg_field = 55 + match message with Some m -> [("message", `String m)] | None -> [] 56 + in 57 + json_response ~status (`Assoc ([("error", `String error)] @ msg_field)) 58 + 59 + let empty_response ?(status = `OK) () = {status; headers= []; body= ""} 60 + 61 + (** queue-based mock, returns responses in order *) 62 + module Queue = struct 63 + type t = {mutable responses: response list; mutable requests: request list} 64 + 65 + let create responses = {responses; requests= []} 66 + 67 + let handler q req = 68 + q.requests <- q.requests @ [req] ; 69 + match q.responses with 70 + | [] -> 71 + failwith "Mock_http.Queue: no more responses" 72 + | r :: rest -> 73 + q.responses <- rest ; 74 + Lwt.return r 75 + 76 + let get_requests q = q.requests 77 + 78 + let clear q = 79 + q.responses <- [] ; 80 + q.requests <- [] 81 + end 82 + 83 + (** pattern-matching mock, selects responses based on request *) 84 + module Pattern = struct 85 + type rule = 86 + { nsid: string option (** match NSID in path *) 87 + ; meth: [`GET | `POST] option (** match method *) 88 + ; response: response } 89 + 90 + type t = 91 + {rules: rule list; mutable requests: request list; default: response option} 92 + 93 + let create ?(default = None) rules = {rules; requests= []; default} 94 + 95 + let extract_nsid uri = 96 + let path = Uri.path uri in 97 + if String.length path > 6 && String.sub path 0 6 = "/xrpc/" then 98 + Some (String.sub path 6 (String.length path - 6)) 99 + else None 100 + 101 + let matches rule req = 102 + let nsid_matches = 103 + match rule.nsid with 104 + | None -> 105 + true 106 + | Some nsid -> 107 + extract_nsid req.uri = Some nsid 108 + in 109 + let meth_matches = 110 + match rule.meth with None -> true | Some m -> req.meth = m 111 + in 112 + nsid_matches && meth_matches 113 + 114 + let handler t req = 115 + t.requests <- t.requests @ [req] ; 116 + match List.find_opt (fun r -> matches r req) t.rules with 117 + | Some rule -> 118 + Lwt.return rule.response 119 + | None -> ( 120 + match t.default with 121 + | Some r -> 122 + Lwt.return r 123 + | None -> 124 + failwith 125 + ("Mock_http.Pattern: no matching rule for " ^ Uri.to_string req.uri) 126 + ) 127 + 128 + let get_requests t = t.requests 129 + end
+371
hermes/test/test_client.ml
··· 1 + open Alcotest 2 + open Lwt.Syntax 3 + open Test_support 4 + 5 + let run_lwt f = Lwt_main.run (f ()) 6 + 7 + (* helpers *) 8 + let test_string = testable Fmt.string String.equal 9 + 10 + (** query tests *) 11 + 12 + let test_query_success () = 13 + run_lwt 14 + @@ fun () -> 15 + let response = 16 + Mock_http.json_response 17 + (`Assoc 18 + [("did", `String "did:plc:123"); ("handle", `String "test.bsky.social")] 19 + ) 20 + in 21 + let* result, requests = 22 + Test_utils.with_mock_responses [response] (fun (module C) client -> 23 + C.query client "com.atproto.identity.resolveHandle" 24 + (`Assoc [("handle", `String "test.bsky.social")]) 25 + (fun json -> 26 + let open Yojson.Safe.Util in 27 + Ok (json |> member "did" |> to_string) ) ) 28 + in 29 + check test_string "result" "did:plc:123" result ; 30 + check int "request count" 1 (List.length requests) ; 31 + let req = List.hd requests in 32 + Test_utils.assert_request_path "/xrpc/com.atproto.identity.resolveHandle" req ; 33 + Test_utils.assert_request_method `GET req ; 34 + Test_utils.assert_request_query_param "handle" "test.bsky.social" req ; 35 + Lwt.return_unit 36 + 37 + let test_query_with_multiple_params () = 38 + run_lwt 39 + @@ fun () -> 40 + let response = Mock_http.json_response (`Assoc [("followers", `List [])]) in 41 + let* _, requests = 42 + Test_utils.with_mock_responses [response] (fun (module C) client -> 43 + C.query client "app.bsky.graph.getFollowers" 44 + (`Assoc 45 + [ ("actor", `String "did:plc:123") 46 + ; ("limit", `Int 50) 47 + ; ("cursor", `String "abc123") ] ) 48 + (fun _ -> Ok ()) ) 49 + in 50 + let req = List.hd requests in 51 + Test_utils.assert_request_query_param "actor" "did:plc:123" req ; 52 + Test_utils.assert_request_query_param "limit" "50" req ; 53 + Test_utils.assert_request_query_param "cursor" "abc123" req ; 54 + Lwt.return_unit 55 + 56 + let test_query_error_response () = 57 + run_lwt 58 + @@ fun () -> 59 + let response = 60 + Mock_http.error_response ~status:`Bad_request ~error:"InvalidHandle" 61 + ~message:"Handle not found" () 62 + in 63 + let* () = 64 + Test_utils.with_mock_responses [response] (fun (module C) client -> 65 + Lwt.catch 66 + (fun () -> 67 + let* _ = 68 + C.query client "com.atproto.identity.resolveHandle" 69 + (`Assoc [("handle", `String "invalid")]) 70 + (fun _ -> Ok ()) 71 + in 72 + fail "should have raised Xrpc_error" ) 73 + (function 74 + | Hermes.Xrpc_error {status; error; message} -> 75 + check int "status" 400 status ; 76 + check test_string "error" "InvalidHandle" error ; 77 + check (option test_string) "message" (Some "Handle not found") 78 + message ; 79 + Lwt.return_unit 80 + | e -> 81 + Lwt.reraise e ) ) 82 + |> Lwt.map fst 83 + in 84 + Lwt.return_unit 85 + 86 + let test_query_empty_response () = 87 + run_lwt 88 + @@ fun () -> 89 + let response = Mock_http.empty_response () in 90 + let* result, _ = 91 + Test_utils.with_mock_responses [response] (fun (module C) client -> 92 + C.query client "some.endpoint" (`Assoc []) (fun _ -> Ok "empty") ) 93 + in 94 + check test_string "result" "empty" result ; 95 + Lwt.return_unit 96 + 97 + let test_query_bytes () = 98 + run_lwt 99 + @@ fun () -> 100 + let response = 101 + Mock_http.bytes_response ~content_type:"image/jpeg" "fake-image-data" 102 + in 103 + let* (data, content_type), requests = 104 + Test_utils.with_mock_responses [response] (fun (module C) client -> 105 + C.query_bytes client "com.atproto.sync.getBlob" 106 + (`Assoc [("did", `String "did:plc:123"); ("cid", `String "bafyabc")]) ) 107 + in 108 + check test_string "data" "fake-image-data" data ; 109 + check test_string "content_type" "image/jpeg" content_type ; 110 + let req = List.hd requests in 111 + Test_utils.assert_request_has_header "accept" "*/*" req ; 112 + Lwt.return_unit 113 + 114 + (** procedure tests *) 115 + 116 + let test_procedure_success () = 117 + run_lwt 118 + @@ fun () -> 119 + let response = 120 + Mock_http.json_response 121 + (`Assoc [("uri", `String "at://did:plc:123/app.bsky.feed.post/abc")]) 122 + in 123 + let* result, requests = 124 + Test_utils.with_mock_responses [response] (fun (module C) client -> 125 + C.procedure client "com.atproto.repo.createRecord" (`Assoc []) 126 + (Some 127 + (`Assoc 128 + [ ("repo", `String "did:plc:123") 129 + ; ("collection", `String "app.bsky.feed.post") 130 + ; ( "record" 131 + , `Assoc [("text", `String "This post was sent from PDSls")] 132 + ) ] ) ) 133 + (fun json -> 134 + let open Yojson.Safe.Util in 135 + Ok (json |> member "uri" |> to_string) ) ) 136 + in 137 + check test_string "uri" "at://did:plc:123/app.bsky.feed.post/abc" result ; 138 + let req = List.hd requests in 139 + Test_utils.assert_request_method `POST req ; 140 + Test_utils.assert_request_path "/xrpc/com.atproto.repo.createRecord" req ; 141 + Test_utils.assert_request_has_header "content-type" "application/json" req ; 142 + Test_utils.assert_request_body_contains "This post was sent from PDSls" req ; 143 + Lwt.return_unit 144 + 145 + let test_procedure_no_input () = 146 + run_lwt 147 + @@ fun () -> 148 + let response = Mock_http.empty_response () in 149 + let* _, requests = 150 + Test_utils.with_mock_responses [response] (fun (module C) client -> 151 + C.procedure client "com.atproto.server.deleteSession" (`Assoc []) None 152 + (fun _ -> Ok () ) ) 153 + in 154 + let req = List.hd requests in 155 + Test_utils.assert_request_method `POST req ; 156 + check (option test_string) "body" (Some "") req.body ; 157 + Lwt.return_unit 158 + 159 + let test_procedure_bytes () = 160 + run_lwt 161 + @@ fun () -> 162 + let response = Mock_http.empty_response () in 163 + let* result, requests = 164 + Test_utils.with_mock_responses [response] (fun (module C) client -> 165 + C.procedure_bytes client "com.atproto.repo.importRepo" (`Assoc []) 166 + (Some "fake-car-data") ~content_type:"application/vnd.ipld.car" ) 167 + in 168 + check (option (pair test_string test_string)) "result" None result ; 169 + let req = List.hd requests in 170 + Test_utils.assert_request_has_header "content-type" "application/vnd.ipld.car" 171 + req ; 172 + Test_utils.assert_request_has_header "accept" "*/*" req ; 173 + check (option test_string) "body" (Some "fake-car-data") req.body ; 174 + Lwt.return_unit 175 + 176 + let test_procedure_blob () = 177 + run_lwt 178 + @@ fun () -> 179 + let response = 180 + Mock_http.json_response 181 + (`Assoc 182 + [ ( "blob" 183 + , `Assoc 184 + [ ("$type", `String "blob") 185 + ; ("ref", `Assoc [("$link", `String "bafyabc")]) 186 + ; ("mimeType", `String "image/jpeg") 187 + ; ("size", `Int 1234) ] ) ] ) 188 + in 189 + let* result, requests = 190 + Test_utils.with_mock_responses [response] (fun (module C) client -> 191 + C.procedure_blob client "com.atproto.repo.uploadBlob" (`Assoc []) 192 + (Bytes.of_string "fake-image-bytes") ~content_type:"image/jpeg" 193 + (fun json -> 194 + let open Yojson.Safe.Util in 195 + Ok (json |> member "blob" |> member "mimeType" |> to_string) ) ) 196 + in 197 + check test_string "mimeType" "image/jpeg" result ; 198 + let req = List.hd requests in 199 + Test_utils.assert_request_has_header "content-type" "image/jpeg" req ; 200 + check (option test_string) "body" (Some "fake-image-bytes") req.body ; 201 + Lwt.return_unit 202 + 203 + (** authentication tests *) 204 + 205 + let test_auth_header_added () = 206 + run_lwt 207 + @@ fun () -> 208 + let response = Mock_http.json_response (`Assoc []) in 209 + let* _, requests = 210 + Test_utils.with_mock_responses [response] (fun (module C) client -> 211 + let session = Test_utils.make_test_session () in 212 + C.set_session client session ; 213 + C.query client "some.endpoint" (`Assoc []) (fun _ -> Ok ()) ) 214 + in 215 + let req = List.hd requests in 216 + Test_utils.assert_request_has_auth_header req ; 217 + Lwt.return_unit 218 + 219 + let test_session_can_be_cleared () = 220 + run_lwt 221 + @@ fun () -> 222 + let response = Mock_http.json_response (`Assoc []) in 223 + let* _, requests = 224 + Test_utils.with_mock_responses [response] (fun (module C) client -> 225 + let session = Test_utils.make_test_session () in 226 + C.set_session client session ; 227 + C.clear_session client ; 228 + C.query client "some.endpoint" (`Assoc []) (fun _ -> Ok ()) ) 229 + in 230 + let req = List.hd requests in 231 + let has_auth = Cohttp.Header.get req.headers "authorization" in 232 + check (option test_string) "no auth header" None has_auth ; 233 + Lwt.return_unit 234 + 235 + (** error handling tests *) 236 + 237 + let test_401_unauthorized () = 238 + run_lwt 239 + @@ fun () -> 240 + let response = 241 + Mock_http.error_response ~status:`Unauthorized ~error:"AuthRequired" 242 + ~message:"Authentication required" () 243 + in 244 + let* () = 245 + Test_utils.with_mock_responses [response] (fun (module C) client -> 246 + Lwt.catch 247 + (fun () -> 248 + let* _ = 249 + C.query client "some.protected.endpoint" (`Assoc []) (fun _ -> 250 + Ok () ) 251 + in 252 + fail "should have raised" ) 253 + (function 254 + | Hermes.Xrpc_error {status= 401; error= "AuthRequired"; _} -> 255 + Lwt.return_unit 256 + | e -> 257 + Lwt.reraise e ) ) 258 + |> Lwt.map fst 259 + in 260 + Lwt.return_unit 261 + 262 + let test_500_server_error () = 263 + run_lwt 264 + @@ fun () -> 265 + let response = 266 + Mock_http.error_response ~status:`Internal_server_error 267 + ~error:"InternalServerError" () 268 + in 269 + let* () = 270 + Test_utils.with_mock_responses [response] (fun (module C) client -> 271 + Lwt.catch 272 + (fun () -> 273 + let* _ = 274 + C.query client "some.endpoint" (`Assoc []) (fun _ -> Ok ()) 275 + in 276 + fail "Should have raised" ) 277 + (function 278 + | Hermes.Xrpc_error {status= 500; _} -> 279 + Lwt.return_unit 280 + | e -> 281 + Lwt.reraise e ) ) 282 + |> Lwt.map fst 283 + in 284 + Lwt.return_unit 285 + 286 + let test_malformed_error_response () = 287 + run_lwt 288 + @@ fun () -> 289 + let response = 290 + { Mock_http.status= `Bad_request 291 + ; headers= [("content-type", "application/json")] 292 + ; body= "not valid json" } 293 + in 294 + let* () = 295 + Test_utils.with_mock_responses [response] (fun (module C) client -> 296 + Lwt.catch 297 + (fun () -> 298 + let* _ = 299 + C.query client "some.endpoint" (`Assoc []) (fun _ -> Ok ()) 300 + in 301 + fail "should have raised" ) 302 + (function 303 + | Hermes.Xrpc_error {status= 400; error= "UnknownError"; _} -> 304 + Lwt.return_unit 305 + | e -> 306 + Lwt.reraise e ) ) 307 + |> Lwt.map fst 308 + in 309 + Lwt.return_unit 310 + 311 + (** client creation tests *) 312 + 313 + let test_make_client () = 314 + let client = Hermes.make_client ~service:"https://api.bsky.app" () in 315 + let service = Hermes.get_service client in 316 + check (option test_string) "host" (Some "api.bsky.app") (Uri.host service) 317 + 318 + let test_client_service_urls () = 319 + let urls = 320 + [ "https://bsky.social" 321 + ; "https://api.bsky.app" 322 + ; "http://localhost:3000" 323 + ; "https://pds.example.com:8080" ] 324 + in 325 + List.iter 326 + (fun url -> 327 + let client = Hermes.make_client ~service:url () in 328 + let service = Hermes.get_service client in 329 + check bool "service set" true (String.length (Uri.to_string service) > 0) ) 330 + urls 331 + 332 + let test_get_session_unauthenticated () = 333 + let client = Hermes.make_client ~service:"https://example.com" () in 334 + check (option reject) "no session" None (Hermes.get_session client) 335 + 336 + (** tests *) 337 + 338 + let query_tests = 339 + [ ("query success", `Quick, test_query_success) 340 + ; ("query with multiple params", `Quick, test_query_with_multiple_params) 341 + ; ("query error response", `Quick, test_query_error_response) 342 + ; ("query empty response", `Quick, test_query_empty_response) 343 + ; ("query bytes", `Quick, test_query_bytes) ] 344 + 345 + let procedure_tests = 346 + [ ("procedure success", `Quick, test_procedure_success) 347 + ; ("procedure no input", `Quick, test_procedure_no_input) 348 + ; ("procedure bytes", `Quick, test_procedure_bytes) 349 + ; ("procedure blob", `Quick, test_procedure_blob) ] 350 + 351 + let auth_tests = 352 + [ ("auth header added", `Quick, test_auth_header_added) 353 + ; ("session can be cleared", `Quick, test_session_can_be_cleared) ] 354 + 355 + let error_tests = 356 + [ ("401 unauthorized", `Quick, test_401_unauthorized) 357 + ; ("500 server error", `Quick, test_500_server_error) 358 + ; ("malformed error response", `Quick, test_malformed_error_response) ] 359 + 360 + let creation_tests = 361 + [ ("make_client", `Quick, test_make_client) 362 + ; ("service URLs", `Quick, test_client_service_urls) 363 + ; ("get_session unauthenticated", `Quick, test_get_session_unauthenticated) ] 364 + 365 + let () = 366 + run "Client" 367 + [ ("query", query_tests) 368 + ; ("procedure", procedure_tests) 369 + ; ("auth", auth_tests) 370 + ; ("errors", error_tests) 371 + ; ("creation", creation_tests) ]
+397
hermes/test/test_credential_manager.ml
··· 1 + (** tests for Hermes.Credential_manager *) 2 + 3 + open Alcotest 4 + open Lwt.Syntax 5 + open Test_support 6 + 7 + let run_lwt f = Lwt_main.run (f ()) 8 + 9 + (* helpers *) 10 + let test_string = testable Fmt.string String.equal 11 + 12 + (** login tests *) 13 + 14 + let test_login_success () = 15 + run_lwt 16 + @@ fun () -> 17 + let session = Test_utils.make_test_session () in 18 + let response = 19 + Mock_http.json_response 20 + (`Assoc 21 + [ ("accessJwt", `String session.access_jwt) 22 + ; ("refreshJwt", `String session.refresh_jwt) 23 + ; ("did", `String session.did) 24 + ; ("handle", `String session.handle) ] ) 25 + in 26 + let* client, requests = 27 + Test_utils.with_mock_credential_manager [response] 28 + (fun (module CM) (module C) manager -> 29 + let* client = 30 + CM.login manager ~identifier:"test@example.com" ~password:"secret" () 31 + in 32 + Lwt.return client ) 33 + in 34 + check int "request count" 1 (List.length requests) ; 35 + let req = List.hd requests in 36 + Test_utils.assert_request_path "/xrpc/com.atproto.server.createSession" req ; 37 + Test_utils.assert_request_method `POST req ; 38 + Test_utils.assert_request_body_contains "test@example.com" req ; 39 + Test_utils.assert_request_body_contains "secret" req ; 40 + check bool "client has session" true (Hermes.get_session client <> None) ; 41 + Lwt.return_unit 42 + 43 + let test_login_error () = 44 + run_lwt 45 + @@ fun () -> 46 + let response = 47 + Mock_http.error_response ~status:`Unauthorized 48 + ~error:"AuthenticationRequired" ~message:"Invalid credentials" () 49 + in 50 + let* () = 51 + Test_utils.with_mock_credential_manager [response] 52 + (fun (module CM) (module C : Hermes.Client.S) manager -> 53 + let _ = C.make in 54 + (* suppress unused warning *) 55 + Lwt.catch 56 + (fun () -> 57 + let* _ = 58 + CM.login manager ~identifier:"test@example.com" ~password:"wrong" 59 + () 60 + in 61 + fail "should have raised Xrpc_error" ) 62 + (function 63 + | Hermes.Xrpc_error {status; error; _} -> 64 + check int "status" 401 status ; 65 + check test_string "error" "AuthenticationRequired" error ; 66 + Lwt.return_unit 67 + | e -> 68 + Lwt.reraise e ) ) 69 + |> Lwt.map fst 70 + in 71 + Lwt.return_unit 72 + 73 + let test_login_with_auth_factor () = 74 + run_lwt 75 + @@ fun () -> 76 + let session = Test_utils.make_test_session () in 77 + let response = 78 + Mock_http.json_response 79 + (`Assoc 80 + [ ("accessJwt", `String session.access_jwt) 81 + ; ("refreshJwt", `String session.refresh_jwt) 82 + ; ("did", `String session.did) 83 + ; ("handle", `String session.handle) ] ) 84 + in 85 + let* _, requests = 86 + Test_utils.with_mock_credential_manager [response] 87 + (fun (module CM) (module C : Hermes.Client.S) manager -> 88 + let* _ = 89 + CM.login manager ~identifier:"test@example.com" ~password:"secret" 90 + ~auth_factor_token:"123456" () 91 + in 92 + Lwt.return () ) 93 + in 94 + let req = List.hd requests in 95 + Test_utils.assert_request_body_contains "123456" req ; 96 + Lwt.return_unit 97 + 98 + (* resume session tests *) 99 + 100 + let test_resume_session () = 101 + run_lwt 102 + @@ fun () -> 103 + let session = Test_utils.make_test_session () in 104 + let* client, _requests = 105 + Test_utils.with_mock_credential_manager [] 106 + (fun (module CM) (module C : Hermes.Client.S) manager -> 107 + let* client = CM.resume manager ~session () in 108 + Lwt.return client ) 109 + in 110 + let client_session = Hermes.get_session client in 111 + check bool "client has session" true (client_session <> None) ; 112 + let s = Option.get client_session in 113 + check test_string "did matches" session.did s.did ; 114 + check test_string "handle matches" session.handle s.handle ; 115 + Lwt.return_unit 116 + 117 + (** session callback tests *) 118 + 119 + let test_session_update_callback () = 120 + run_lwt 121 + @@ fun () -> 122 + let session = Test_utils.make_test_session () in 123 + let response = 124 + Mock_http.json_response 125 + (`Assoc 126 + [ ("accessJwt", `String session.access_jwt) 127 + ; ("refreshJwt", `String session.refresh_jwt) 128 + ; ("did", `String session.did) 129 + ; ("handle", `String session.handle) ] ) 130 + in 131 + let callback_called = ref false in 132 + let received_session = ref None in 133 + let* _, _ = 134 + Test_utils.with_mock_credential_manager [response] 135 + (fun (module CM) (module C : Hermes.Client.S) manager -> 136 + CM.on_session_update manager (fun s -> 137 + callback_called := true ; 138 + received_session := Some s ; 139 + Lwt.return_unit ) ; 140 + let* _ = 141 + CM.login manager ~identifier:"test@example.com" ~password:"secret" () 142 + in 143 + Lwt.return () ) 144 + in 145 + check bool "callback was called" true !callback_called ; 146 + check bool "session was received" true (!received_session <> None) ; 147 + let s = Option.get !received_session in 148 + check test_string "received did" session.did s.did ; 149 + Lwt.return_unit 150 + 151 + let test_session_expired_callback () = 152 + run_lwt 153 + @@ fun () -> 154 + (* first log in, then log out with server error; should still call expired callback *) 155 + let session = Test_utils.make_test_session () in 156 + let login_response = 157 + Mock_http.json_response 158 + (`Assoc 159 + [ ("accessJwt", `String session.access_jwt) 160 + ; ("refreshJwt", `String session.refresh_jwt) 161 + ; ("did", `String session.did) 162 + ; ("handle", `String session.handle) ] ) 163 + in 164 + (* log out endpoint returns error but should still clear session *) 165 + let logout_response = 166 + Mock_http.error_response ~status:`Internal_server_error 167 + ~error:"InternalServerError" () 168 + in 169 + let expired_called = ref false in 170 + let* _, _ = 171 + Test_utils.with_mock_credential_manager [login_response; logout_response] 172 + (fun (module CM) (module C : Hermes.Client.S) manager -> 173 + CM.on_session_expired manager (fun () -> 174 + expired_called := true ; 175 + Lwt.return_unit ) ; 176 + let* _ = 177 + CM.login manager ~identifier:"test@example.com" ~password:"secret" () 178 + in 179 + let* () = CM.logout manager in 180 + Lwt.return () ) 181 + in 182 + check bool "expired callback was called" true !expired_called ; 183 + Lwt.return_unit 184 + 185 + (** log out tests *) 186 + 187 + let test_logout_success () = 188 + run_lwt 189 + @@ fun () -> 190 + let session = Test_utils.make_test_session () in 191 + let login_response = 192 + Mock_http.json_response 193 + (`Assoc 194 + [ ("accessJwt", `String session.access_jwt) 195 + ; ("refreshJwt", `String session.refresh_jwt) 196 + ; ("did", `String session.did) 197 + ; ("handle", `String session.handle) ] ) 198 + in 199 + let logout_response = Mock_http.empty_response () in 200 + let* manager_session, requests = 201 + Test_utils.with_mock_credential_manager [login_response; logout_response] 202 + (fun (module CM) (module C : Hermes.Client.S) manager -> 203 + let* _ = 204 + CM.login manager ~identifier:"test@example.com" ~password:"secret" () 205 + in 206 + let* () = CM.logout manager in 207 + Lwt.return (CM.get_session manager) ) 208 + in 209 + check (option reject) "manager session cleared" None manager_session ; 210 + check int "request count" 2 (List.length requests) ; 211 + let logout_req = List.nth requests 1 in 212 + Test_utils.assert_request_path "/xrpc/com.atproto.server.deleteSession" 213 + logout_req ; 214 + Test_utils.assert_request_method `POST logout_req ; 215 + Lwt.return_unit 216 + 217 + let test_logout_no_session () = 218 + run_lwt 219 + @@ fun () -> 220 + let* _, requests = 221 + Test_utils.with_mock_credential_manager [] 222 + (fun (module CM) (module C : Hermes.Client.S) manager -> 223 + let* () = CM.logout manager in 224 + Lwt.return () ) 225 + in 226 + check int "no requests made" 0 (List.length requests) ; 227 + Lwt.return_unit 228 + 229 + (* ===== Get Session Tests ===== *) 230 + 231 + let test_get_session_before_login () = 232 + run_lwt 233 + @@ fun () -> 234 + let* session, _ = 235 + Test_utils.with_mock_credential_manager [] 236 + (fun (module CM) (module C : Hermes.Client.S) manager -> 237 + Lwt.return (CM.get_session manager) ) 238 + in 239 + check (option reject) "no session" None session ; 240 + Lwt.return_unit 241 + 242 + let test_get_session_after_login () = 243 + run_lwt 244 + @@ fun () -> 245 + let session = Test_utils.make_test_session () in 246 + let response = 247 + Mock_http.json_response 248 + (`Assoc 249 + [ ("accessJwt", `String session.access_jwt) 250 + ; ("refreshJwt", `String session.refresh_jwt) 251 + ; ("did", `String session.did) 252 + ; ("handle", `String session.handle) ] ) 253 + in 254 + let* manager_session, _ = 255 + Test_utils.with_mock_credential_manager [response] 256 + (fun (module CM) (module C : Hermes.Client.S) manager -> 257 + let* _ = 258 + CM.login manager ~identifier:"test@example.com" ~password:"secret" () 259 + in 260 + Lwt.return (CM.get_session manager) ) 261 + in 262 + check bool "has session" true (manager_session <> None) ; 263 + let s = Option.get manager_session in 264 + check test_string "did matches" session.did s.did ; 265 + Lwt.return_unit 266 + 267 + (** token refresh tests *) 268 + 269 + let test_automatic_token_refresh () = 270 + run_lwt 271 + @@ fun () -> 272 + (* create session with token expiring soon *) 273 + let expired_session = Test_utils.make_test_session ~exp_offset:60 () in 274 + let new_session = Test_utils.make_test_session ~exp_offset:3600 () in 275 + let login_response = 276 + Mock_http.json_response 277 + (`Assoc 278 + [ ("accessJwt", `String expired_session.access_jwt) 279 + ; ("refreshJwt", `String expired_session.refresh_jwt) 280 + ; ("did", `String expired_session.did) 281 + ; ("handle", `String expired_session.handle) ] ) 282 + in 283 + let refresh_response = 284 + Mock_http.json_response 285 + (`Assoc 286 + [ ("accessJwt", `String new_session.access_jwt) 287 + ; ("refreshJwt", `String new_session.refresh_jwt) 288 + ; ("did", `String new_session.did) 289 + ; ("handle", `String new_session.handle) ] ) 290 + in 291 + (* API call response after refresh *) 292 + let api_response = 293 + Mock_http.json_response (`Assoc [("data", `String "success")]) 294 + in 295 + let* result, requests = 296 + Test_utils.with_mock_credential_manager 297 + [login_response; refresh_response; api_response] 298 + (fun (module CM) (module C) manager -> 299 + let* client = 300 + CM.login manager ~identifier:"test@example.com" ~password:"secret" () 301 + in 302 + (* making an API call should trigger token refresh *) 303 + let* result = 304 + C.query client "test.endpoint" (`Assoc []) (fun json -> 305 + let open Yojson.Safe.Util in 306 + Ok (json |> member "data" |> to_string) ) 307 + in 308 + Lwt.return result ) 309 + in 310 + check test_string "api result" "success" result ; 311 + check int "request count" 3 (List.length requests) ; 312 + (* second request should be refresh *) 313 + let refresh_req = List.nth requests 1 in 314 + Test_utils.assert_request_path "/xrpc/com.atproto.server.refreshSession" 315 + refresh_req ; 316 + Lwt.return_unit 317 + 318 + let test_refresh_failure_clears_session () = 319 + run_lwt 320 + @@ fun () -> 321 + (* create session with expired token *) 322 + let expired_session = Test_utils.make_test_session ~exp_offset:60 () in 323 + let login_response = 324 + Mock_http.json_response 325 + (`Assoc 326 + [ ("accessJwt", `String expired_session.access_jwt) 327 + ; ("refreshJwt", `String expired_session.refresh_jwt) 328 + ; ("did", `String expired_session.did) 329 + ; ("handle", `String expired_session.handle) ] ) 330 + in 331 + (* refresh fails with ExpiredToken *) 332 + let refresh_response = 333 + Mock_http.error_response ~status:`Bad_request ~error:"ExpiredToken" 334 + ~message:"Refresh token expired" () 335 + in 336 + let expired_called = ref false in 337 + let* () = 338 + Test_utils.with_mock_credential_manager [login_response; refresh_response] 339 + (fun (module CM) (module C) manager -> 340 + CM.on_session_expired manager (fun () -> 341 + expired_called := true ; 342 + Lwt.return_unit ) ; 343 + let* client = 344 + CM.login manager ~identifier:"test@example.com" ~password:"secret" () 345 + in 346 + (* making an API call should trigger token refresh which will fail *) 347 + Lwt.catch 348 + (fun () -> 349 + let* _ = 350 + C.query client "test.endpoint" (`Assoc []) (fun _ -> Ok ()) 351 + in 352 + fail "should have raised SessionExpired" ) 353 + (function 354 + | Hermes.Xrpc_error {error= "SessionExpired"; _} -> 355 + Lwt.return_unit 356 + | e -> 357 + Lwt.reraise e ) ) 358 + |> Lwt.map fst 359 + in 360 + check bool "expired callback was called" true !expired_called ; 361 + Lwt.return_unit 362 + 363 + (** tests *) 364 + 365 + let login_tests = 366 + [ ("login success", `Quick, test_login_success) 367 + ; ("login error", `Quick, test_login_error) 368 + ; ("login with auth factor", `Quick, test_login_with_auth_factor) ] 369 + 370 + let resume_tests = [("resume session", `Quick, test_resume_session)] 371 + 372 + let callback_tests = 373 + [ ("session update callback", `Quick, test_session_update_callback) 374 + ; ("session expired callback", `Quick, test_session_expired_callback) ] 375 + 376 + let logout_tests = 377 + [ ("logout success", `Quick, test_logout_success) 378 + ; ("logout no session", `Quick, test_logout_no_session) ] 379 + 380 + let session_tests = 381 + [ ("get session before login", `Quick, test_get_session_before_login) 382 + ; ("get session after login", `Quick, test_get_session_after_login) ] 383 + 384 + let refresh_tests = 385 + [ ("automatic token refresh", `Quick, test_automatic_token_refresh) 386 + ; ( "refresh failure clears session" 387 + , `Quick 388 + , test_refresh_failure_clears_session ) ] 389 + 390 + let () = 391 + run "Credential_manager" 392 + [ ("login", login_tests) 393 + ; ("resume", resume_tests) 394 + ; ("callbacks", callback_tests) 395 + ; ("logout", logout_tests) 396 + ; ("session", session_tests) 397 + ; ("refresh", refresh_tests) ]
+160
hermes/test/test_jwt.ml
··· 1 + open Alcotest 2 + 3 + (** helpers *) 4 + let test_string = testable Fmt.string String.equal 5 + 6 + (* create a minimal jwt 7 + we only care about the (base64url encoded json) payload, so header and signature can be anything *) 8 + let make_jwt payload_json = 9 + let header = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" in 10 + (* {"alg":"HS256","typ":"JWT"} *) 11 + let payload = 12 + Base64.encode_string ~alphabet:Base64.uri_safe_alphabet ~pad:false 13 + payload_json 14 + in 15 + let signature = "signature" in 16 + header ^ "." ^ payload ^ "." ^ signature 17 + 18 + (* decoding a valid JWT *) 19 + let test_decode_valid () = 20 + let now = int_of_float (Unix.time ()) in 21 + let payload_json = 22 + Printf.sprintf {|{"sub":"did:plc:test","iat":%d,"exp":%d}|} now (now + 3600) 23 + in 24 + let jwt = make_jwt payload_json in 25 + match Hermes.Jwt.decode_payload jwt with 26 + | Ok payload -> 27 + check (option test_string) "sub" (Some "did:plc:test") payload.sub ; 28 + check (option int) "iat" (Some now) payload.iat ; 29 + check (option int) "exp" (Some (now + 3600)) payload.exp 30 + | Error e -> 31 + fail ("decode failed: " ^ e) 32 + 33 + (* decoding JWT with additional fields *) 34 + let test_decode_with_extra_fields () = 35 + let now = int_of_float (Unix.time ()) in 36 + let payload_json = 37 + Printf.sprintf 38 + {|{"sub":"did:plc:extra","iat":%d,"exp":%d,"scope":"atproto","aud":"did:web:server"}|} 39 + now (now + 3600) 40 + in 41 + let jwt = make_jwt payload_json in 42 + match Hermes.Jwt.decode_payload jwt with 43 + | Ok payload -> 44 + check (option test_string) "sub" (Some "did:plc:extra") payload.sub ; 45 + check (option int) "exp" (Some (now + 3600)) payload.exp ; 46 + check (option test_string) "aud" (Some "did:web:server") payload.aud 47 + | Error e -> 48 + fail ("decode failed: " ^ e) 49 + 50 + (* decoding invalid JWT format *) 51 + let test_decode_invalid_format () = 52 + let invalid = "not.a.jwt.with.wrong.parts" in 53 + match Hermes.Jwt.decode_payload invalid with 54 + | Ok _ -> 55 + fail "should have failed" 56 + | Error e -> 57 + check bool "has error" true (String.length e > 0) 58 + 59 + let test_decode_no_dots () = 60 + let invalid = "nodots" in 61 + match Hermes.Jwt.decode_payload invalid with 62 + | Ok _ -> 63 + fail "should have failed" 64 + | Error _ -> 65 + () 66 + 67 + (* decoding JWT with invalid base64 *) 68 + let test_decode_invalid_base64 () = 69 + let invalid = "header.!!!invalid!!!.signature" in 70 + match Hermes.Jwt.decode_payload invalid with 71 + | Ok _ -> 72 + fail "should have failed" 73 + | Error _ -> 74 + () 75 + 76 + (* decoding JWT with invalid JSON payload *) 77 + let test_decode_invalid_json () = 78 + let header = "eyJhbGciOiJIUzI1NiJ9" in 79 + let payload = 80 + Base64.encode_string ~alphabet:Base64.uri_safe_alphabet ~pad:false 81 + "not json" 82 + in 83 + let jwt = header ^ "." ^ payload ^ ".sig" in 84 + match Hermes.Jwt.decode_payload jwt with 85 + | Ok _ -> 86 + fail "should have failed" 87 + | Error _ -> 88 + () 89 + 90 + (* test is_expired with expired token *) 91 + let test_expired_token () = 92 + let past = int_of_float (Unix.time ()) - 3600 in 93 + (* 1 hour ago *) 94 + let payload_json = 95 + Printf.sprintf {|{"sub":"test","iat":%d,"exp":%d}|} past past 96 + in 97 + let jwt = make_jwt payload_json in 98 + check bool "is expired" true (Hermes.Jwt.is_expired jwt) 99 + 100 + (* test is_expired with valid token *) 101 + let test_valid_token () = 102 + let now = int_of_float (Unix.time ()) in 103 + let future = now + 3600 in 104 + (* 1 hour from now *) 105 + let payload_json = 106 + Printf.sprintf {|{"sub":"test","iat":%d,"exp":%d}|} now future 107 + in 108 + let jwt = make_jwt payload_json in 109 + check bool "is not expired" false (Hermes.Jwt.is_expired jwt) 110 + 111 + (* test is_expired with buffer *) 112 + let test_expired_with_buffer () = 113 + let now = int_of_float (Unix.time ()) in 114 + let almost_expired = now + 30 in 115 + (* expires in 30 seconds *) 116 + let payload_json = 117 + Printf.sprintf {|{"sub":"test","iat":%d,"exp":%d}|} now almost_expired 118 + in 119 + let jwt = make_jwt payload_json in 120 + (* default buffer is 60 seconds, so this should be considered expired *) 121 + check bool "expired with buffer" true (Hermes.Jwt.is_expired jwt) 122 + 123 + (* is_expired with invalid JWT returns true *) 124 + let test_expired_invalid_jwt () = 125 + check bool "invalid JWT treated as expired" true 126 + (Hermes.Jwt.is_expired "invalid") 127 + 128 + (* test get_expiration *) 129 + let test_get_expiration () = 130 + let now = int_of_float (Unix.time ()) in 131 + let exp_time = now + 3600 in 132 + let payload_json = Printf.sprintf {|{"sub":"test","exp":%d}|} exp_time in 133 + let jwt = make_jwt payload_json in 134 + check (option int) "expiration" (Some exp_time) 135 + (Hermes.Jwt.get_expiration jwt) 136 + 137 + let test_get_expiration_missing () = 138 + let payload_json = {|{"sub":"test"}|} in 139 + let jwt = make_jwt payload_json in 140 + check (option int) "no expiration" None (Hermes.Jwt.get_expiration jwt) 141 + 142 + (** tests *) 143 + 144 + let decode_tests = 145 + [ ("decode valid JWT", `Quick, test_decode_valid) 146 + ; ("decode with extra fields", `Quick, test_decode_with_extra_fields) 147 + ; ("decode invalid format", `Quick, test_decode_invalid_format) 148 + ; ("decode no dots", `Quick, test_decode_no_dots) 149 + ; ("decode invalid base64", `Quick, test_decode_invalid_base64) 150 + ; ("decode invalid json", `Quick, test_decode_invalid_json) ] 151 + 152 + let expiry_tests = 153 + [ ("expired token", `Quick, test_expired_token) 154 + ; ("valid token", `Quick, test_valid_token) 155 + ; ("expired with buffer", `Quick, test_expired_with_buffer) 156 + ; ("invalid JWT is expired", `Quick, test_expired_invalid_jwt) 157 + ; ("get_expiration", `Quick, test_get_expiration) 158 + ; ("get_expiration missing", `Quick, test_get_expiration_missing) ] 159 + 160 + let () = run "Jwt" [("decode", decode_tests); ("expiry", expiry_tests)]
+172
hermes/test/test_types.ml
··· 1 + open Alcotest 2 + 3 + (** helpers *) 4 + let test_string = testable Fmt.string String.equal 5 + 6 + let test_int64 = testable Fmt.int64 Int64.equal 7 + 8 + (* blob tests *) 9 + let test_blob_to_yojson () = 10 + let cid = 11 + Cid.of_string "bafyreib7k7m7h7x6qrvxqrwqe2m6q5p5zklvp4fqq2g6hh6h6t6x6x6x6y" 12 + in 13 + match cid with 14 + | Ok cid -> ( 15 + let blob : Hermes.blob = 16 + {ref_= cid; mime_type= "image/png"; size= 12345L} 17 + in 18 + let json = Hermes.blob_to_yojson blob in 19 + let json_str = Yojson.Safe.to_string json in 20 + check bool "contains mimeType" true 21 + (String.length json_str > 0 && String.sub json_str 0 1 = "{") ; 22 + (* verify valid JSON with expected structure *) 23 + match json with 24 + | `Assoc pairs -> 25 + check bool "has mimeType key" true 26 + (List.exists (fun (k, _) -> k = "mimeType") pairs) ; 27 + check bool "has size key" true 28 + (List.exists (fun (k, _) -> k = "size") pairs) 29 + | _ -> 30 + fail "expected object" ) 31 + | Error _ -> 32 + fail "couldn't parse cid constant" 33 + 34 + let test_blob_roundtrip () = 35 + let cid = 36 + Cid.of_string "bafyreib7k7m7h7x6qrvxqrwqe2m6q5p5zklvp4fqq2g6hh6h6t6x6x6x6y" 37 + in 38 + match cid with 39 + | Ok cid -> ( 40 + let original : Hermes.blob = 41 + {ref_= cid; mime_type= "image/jpeg"; size= 54321L} 42 + in 43 + let json = Hermes.blob_to_yojson original in 44 + match Hermes.blob_of_yojson json with 45 + | Ok decoded -> 46 + check test_string "mime_type matches" original.mime_type 47 + decoded.mime_type ; 48 + check test_int64 "size matches" original.size decoded.size 49 + | Error e -> 50 + fail ("roundtrip failed: " ^ e) ) 51 + | Error _ -> 52 + () 53 + 54 + (** session tests *) 55 + let test_session_to_yojson () = 56 + let session : Hermes.session = 57 + { access_jwt= "eyJ..." 58 + ; refresh_jwt= "eyJ..." 59 + ; did= "did:plc:example" 60 + ; handle= "user.bsky.social" 61 + ; pds_uri= Some "https://pds.example.com" 62 + ; email= Some "user@example.com" 63 + ; email_confirmed= Some true 64 + ; email_auth_factor= Some false 65 + ; active= Some true 66 + ; status= None } 67 + in 68 + let json = Hermes.session_to_yojson session in 69 + match json with 70 + | `Assoc pairs -> 71 + check bool "has accessJwt" true 72 + (List.exists (fun (k, _) -> k = "accessJwt") pairs) ; 73 + check bool "has refreshJwt" true 74 + (List.exists (fun (k, _) -> k = "refreshJwt") pairs) ; 75 + check bool "has did" true (List.exists (fun (k, _) -> k = "did") pairs) ; 76 + check bool "has handle" true 77 + (List.exists (fun (k, _) -> k = "handle") pairs) 78 + | _ -> 79 + fail "expected object" 80 + 81 + let test_session_of_yojson_full () = 82 + let json = 83 + Yojson.Safe.from_string 84 + {|{ 85 + "accessJwt": "access_token", 86 + "refreshJwt": "refresh_token", 87 + "did": "did:plc:test", 88 + "handle": "test.bsky.social", 89 + "pdsUri": "https://pds.test.com", 90 + "email": "test@example.com", 91 + "emailConfirmed": true, 92 + "emailAuthFactor": false, 93 + "active": true, 94 + "status": "active" 95 + }|} 96 + in 97 + match Hermes.session_of_yojson json with 98 + | Ok session -> 99 + check test_string "access_jwt" "access_token" session.access_jwt ; 100 + check test_string "refresh_jwt" "refresh_token" session.refresh_jwt ; 101 + check test_string "did" "did:plc:test" session.did ; 102 + check test_string "handle" "test.bsky.social" session.handle ; 103 + check (option test_string) "pds_uri" (Some "https://pds.test.com") 104 + session.pds_uri ; 105 + check (option test_string) "email" (Some "test@example.com") session.email ; 106 + check (option bool) "email_confirmed" (Some true) session.email_confirmed ; 107 + check (option bool) "email_auth_factor" (Some false) 108 + session.email_auth_factor ; 109 + check (option bool) "active" (Some true) session.active ; 110 + check (option test_string) "status" (Some "active") session.status 111 + | Error e -> 112 + fail ("parse failed: " ^ e) 113 + 114 + let test_session_of_yojson_minimal () = 115 + let json = 116 + Yojson.Safe.from_string 117 + {|{ 118 + "accessJwt": "access_token", 119 + "refreshJwt": "refresh_token", 120 + "did": "did:plc:test", 121 + "handle": "test.bsky.social" 122 + }|} 123 + in 124 + match Hermes.session_of_yojson json with 125 + | Ok session -> 126 + check test_string "access_jwt" "access_token" session.access_jwt ; 127 + check test_string "did" "did:plc:test" session.did ; 128 + check (option test_string) "pds_uri" None session.pds_uri ; 129 + check (option test_string) "email" None session.email ; 130 + check (option bool) "active" None session.active 131 + | Error e -> 132 + fail ("parse failed: " ^ e) 133 + 134 + let test_session_roundtrip () = 135 + let original : Hermes.session = 136 + { access_jwt= "access123" 137 + ; refresh_jwt= "refresh456" 138 + ; did= "did:plc:roundtrip" 139 + ; handle= "roundtrip.test" 140 + ; pds_uri= None 141 + ; email= Some "rt@test.com" 142 + ; email_confirmed= None 143 + ; email_auth_factor= None 144 + ; active= Some true 145 + ; status= None } 146 + in 147 + let json = Hermes.session_to_yojson original in 148 + match Hermes.session_of_yojson json with 149 + | Ok decoded -> 150 + check test_string "access_jwt" original.access_jwt decoded.access_jwt ; 151 + check test_string "refresh_jwt" original.refresh_jwt decoded.refresh_jwt ; 152 + check test_string "did" original.did decoded.did ; 153 + check test_string "handle" original.handle decoded.handle ; 154 + check (option test_string) "pds_uri" original.pds_uri decoded.pds_uri ; 155 + check (option test_string) "email" original.email decoded.email ; 156 + check (option bool) "active" original.active decoded.active 157 + | Error e -> 158 + fail ("roundtrip failed: " ^ e) 159 + 160 + (** tests *) 161 + 162 + let blob_tests = 163 + [ ("blob_to_yojson", `Quick, test_blob_to_yojson) 164 + ; ("blob roundtrip", `Quick, test_blob_roundtrip) ] 165 + 166 + let session_tests = 167 + [ ("session_to_yojson", `Quick, test_session_to_yojson) 168 + ; ("session_of_yojson full", `Quick, test_session_of_yojson_full) 169 + ; ("session_of_yojson minimal", `Quick, test_session_of_yojson_minimal) 170 + ; ("session roundtrip", `Quick, test_session_roundtrip) ] 171 + 172 + let () = run "Types" [("blob", blob_tests); ("session", session_tests)]
+150
hermes/test/test_utils.ml
··· 1 + (** test utilities *) 2 + 3 + open Lwt.Syntax 4 + 5 + (* run a test with a mock HTTP client using queued responses *) 6 + let with_mock_responses responses f = 7 + let queue = Mock_http.Queue.create responses in 8 + let handler_ref = ref (Mock_http.Queue.handler queue) in 9 + let module MockHttp = Mock_http.Make (struct 10 + let handler = handler_ref 11 + end) in 12 + let module MockClient = Hermes.Client.Make (MockHttp) in 13 + let client = MockClient.make ~service:"https://test.example.com" () in 14 + let* result = f (module MockClient : Hermes.Client.S) client in 15 + Lwt.return (result, Mock_http.Queue.get_requests queue) 16 + 17 + (* run a test with a mock HTTP client using pattern matching *) 18 + let with_mock_patterns ?default rules f = 19 + let pattern = Mock_http.Pattern.create ?default rules in 20 + let handler_ref = ref (Mock_http.Pattern.handler pattern) in 21 + let module MockHttp = Mock_http.Make (struct 22 + let handler = handler_ref 23 + end) in 24 + let module MockClient = Hermes.Client.Make (MockHttp) in 25 + let client = MockClient.make ~service:"https://test.example.com" () in 26 + let* result = f (module MockClient : Hermes.Client.S) client in 27 + Lwt.return (result, Mock_http.Pattern.get_requests pattern) 28 + 29 + (* create a valid JWT for testing *) 30 + let make_test_jwt ?(exp_offset = 3600) ?(sub = "did:plc:test") () = 31 + let now = int_of_float (Unix.time ()) in 32 + let exp = now + exp_offset in 33 + let header = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9" in 34 + let payload_json = 35 + Printf.sprintf {|{"sub":"%s","iat":%d,"exp":%d}|} sub now exp 36 + in 37 + let payload = 38 + Base64.encode_string ~alphabet:Base64.uri_safe_alphabet ~pad:false 39 + payload_json 40 + in 41 + header ^ "." ^ payload ^ ".fake_signature" 42 + 43 + (* create a test session *) 44 + let make_test_session ?(exp_offset = 3600) () : Hermes.session = 45 + let jwt = make_test_jwt ~exp_offset () in 46 + { access_jwt= jwt 47 + ; refresh_jwt= make_test_jwt ~exp_offset:86400 () 48 + ; did= "did:plc:testuser123" 49 + ; handle= "test.bsky.social" 50 + ; pds_uri= Some "https://pds.example.com" 51 + ; email= Some "test@example.com" 52 + ; email_confirmed= Some true 53 + ; email_auth_factor= Some false 54 + ; active= Some true 55 + ; status= None } 56 + 57 + (* create a session response JSON *) 58 + let session_response_json ?(exp_offset = 3600) () = 59 + let session = make_test_session ~exp_offset () in 60 + `Assoc 61 + [ ("accessJwt", `String session.access_jwt) 62 + ; ("refreshJwt", `String session.refresh_jwt) 63 + ; ("did", `String session.did) 64 + ; ("handle", `String session.handle) ] 65 + 66 + (** assert helpers for requests *) 67 + 68 + let assert_request_path expected_path (req : Mock_http.request) = 69 + let actual = Uri.path req.uri in 70 + if actual <> expected_path then 71 + failwith (Printf.sprintf "expected path %s but got %s" expected_path actual) 72 + 73 + let assert_request_method expected_meth (req : Mock_http.request) = 74 + if req.meth <> expected_meth then 75 + let expected_str = 76 + match expected_meth with `GET -> "GET" | `POST -> "POST" 77 + in 78 + let actual_str = match req.meth with `GET -> "GET" | `POST -> "POST" in 79 + failwith 80 + (Printf.sprintf "expected method %s but got %s" expected_str actual_str) 81 + 82 + let assert_request_has_header name value (req : Mock_http.request) = 83 + match Cohttp.Header.get req.headers name with 84 + | Some v when v = value -> 85 + () 86 + | Some v -> 87 + failwith (Printf.sprintf "header %s: expected %s but got %s" name value v) 88 + | None -> 89 + failwith (Printf.sprintf "header %s not found" name) 90 + 91 + let assert_request_has_auth_header (req : Mock_http.request) = 92 + match Cohttp.Header.get req.headers "authorization" with 93 + | Some v when String.length v > 7 && String.sub v 0 7 = "Bearer " -> 94 + () 95 + | Some v -> 96 + failwith (Printf.sprintf "invalid auth header: %s" v) 97 + | None -> 98 + failwith "authorization header not found" 99 + 100 + let assert_request_query_param name expected_value (req : Mock_http.request) = 101 + let query = Uri.query req.uri in 102 + match List.assoc_opt name query with 103 + | Some [v] when v = expected_value -> 104 + () 105 + | Some [v] -> 106 + failwith 107 + (Printf.sprintf "query param %s: expected %s but got %s" name 108 + expected_value v ) 109 + | Some vs -> 110 + failwith 111 + (Printf.sprintf "query param %s has multiple values: %s" name 112 + (String.concat ", " vs) ) 113 + | None -> 114 + failwith (Printf.sprintf "query param %s not found" name) 115 + 116 + let assert_request_body_contains substring (req : Mock_http.request) = 117 + match req.body with 118 + | Some body when String.length body > 0 -> 119 + if 120 + not 121 + ( String.length substring <= String.length body 122 + && 123 + let rec check i = 124 + if i > String.length body - String.length substring then false 125 + else if String.sub body i (String.length substring) = substring then 126 + true 127 + else check (i + 1) 128 + in 129 + check 0 ) 130 + then failwith (Printf.sprintf "body does not contain '%s'" substring) 131 + | _ -> 132 + failwith "expected request body but none found" 133 + 134 + (* run a test with a mock credential manager *) 135 + let with_mock_credential_manager responses f = 136 + let queue = Mock_http.Queue.create responses in 137 + let handler_ref = ref (Mock_http.Queue.handler queue) in 138 + let module MockHttp = Mock_http.Make (struct 139 + let handler = handler_ref 140 + end) in 141 + let module MockClient = Hermes.Client.Make (MockHttp) in 142 + let module MockCredManager = Hermes.Credential_manager.Make (MockClient) in 143 + let manager = MockCredManager.make ~service:"https://test.example.com" () in 144 + let* result = 145 + f 146 + (module MockCredManager : Hermes.Credential_manager.S) 147 + (module MockClient : Hermes.Client.S) 148 + manager 149 + in 150 + Lwt.return (result, Mock_http.Queue.get_requests queue)
+30
hermes_ppx.opam
··· 1 + # This file is generated by dune, edit dune-project instead 2 + opam-version: "2.0" 3 + synopsis: "PPX extension for Hermes XRPC calls" 4 + maintainer: ["futurGH"] 5 + authors: ["futurGH"] 6 + license: "MPL-2.0" 7 + homepage: "https://github.com/futurGH/pegasus" 8 + bug-reports: "https://github.com/futurGH/pegasus/issues" 9 + depends: [ 10 + "ocaml" {= "5.2.1"} 11 + "dune" {>= "3.20"} 12 + "ppxlib" {>= "0.32.0"} 13 + "odoc" {with-doc} 14 + ] 15 + build: [ 16 + ["dune" "subst"] {dev} 17 + [ 18 + "dune" 19 + "build" 20 + "-p" 21 + name 22 + "-j" 23 + jobs 24 + "@install" 25 + "@runtest" {with-test} 26 + "@doc" {with-doc} 27 + ] 28 + ] 29 + dev-repo: "git+https://github.com/futurGH/pegasus.git" 30 + x-maintenance-intent: ["(latest)"]
+6
hermes_ppx/lib/dune
··· 1 + (library 2 + (name hermes_ppx) 3 + (kind ppx_rewriter) 4 + (libraries ppxlib) 5 + (preprocess 6 + (pps ppxlib.metaquot)))
+77
hermes_ppx/lib/hermes_ppx.ml
··· 1 + open Ppxlib 2 + 3 + (* convert nsid to module path: "app.bsky.graph.get" -> ["App"; "Bsky"; "Graph"; "Get"] *) 4 + let nsid_to_module_path nsid = 5 + String.split_on_char '.' nsid |> List.map String.capitalize_ascii 6 + 7 + (* convert nsid to flat module name: "com.atproto.identity.resolveHandle" -> "Com_atproto_identity_resolveHandle" *) 8 + let nsid_to_flat_module_name nsid = 9 + let flat = String.concat "_" (String.split_on_char '.' nsid) in 10 + String.capitalize_ascii flat 11 + 12 + (* build module access expression from path: ["App"; "Bsky"] -> App.Bsky *) 13 + let build_module_path ~loc path = 14 + match path with 15 + | [] -> 16 + Location.raise_errorf ~loc "Empty module path" 17 + | first :: rest -> 18 + List.fold_left 19 + (fun acc part -> 20 + let lid = Loc.make ~loc (Longident.Ldot (acc.txt, part)) in 21 + lid ) 22 + (Loc.make ~loc (Longident.Lident first)) 23 + rest 24 + 25 + (* build full expression for flat module structure: Module_name.Main.call *) 26 + let build_call_expr_flat ~loc nsid = 27 + let module_name = nsid_to_flat_module_name nsid in 28 + (* Build: Module_name.Main.call *) 29 + let lid = Longident.(Ldot (Ldot (Lident module_name, "Main"), "call")) in 30 + Ast_builder.Default.pexp_ident ~loc (Loc.make ~loc lid) 31 + 32 + (* build full expression: Module.Path.call (nested style, kept for compatibility) *) 33 + let build_call_expr ~loc nsid = 34 + let parts = nsid_to_module_path nsid in 35 + let module_lid = build_module_path ~loc parts in 36 + let call_lid = Loc.make ~loc (Longident.Ldot (module_lid.txt, "call")) in 37 + Ast_builder.Default.pexp_ident ~loc call_lid 38 + 39 + (* parse method and nsid from structure items *) 40 + let parse_method_and_nsid ~loc str = 41 + match str with 42 + | [{pstr_desc= Pstr_eval (expr, _); _}] -> ( 43 + match expr.pexp_desc with 44 + (* [%xrpc get "nsid"] *) 45 + | Pexp_apply 46 + ( {pexp_desc= Pexp_ident {txt= Lident method_; _}; _} 47 + , [(Nolabel, {pexp_desc= Pexp_constant (Pconst_string (nsid, _, _)); _})] 48 + ) -> 49 + let method_lower = String.lowercase_ascii method_ in 50 + if method_lower = "get" || method_lower = "post" then 51 + (method_lower, nsid) 52 + else 53 + Location.raise_errorf ~loc "Expected 'get' or 'post', got '%s'" 54 + method_ 55 + (* [%xrpc "nsid"] - assume get *) 56 + | Pexp_constant (Pconst_string (nsid, _, _)) -> 57 + ("get", nsid) 58 + | _ -> 59 + Location.raise_errorf ~loc 60 + "Expected [%%xrpc get \"nsid\"] or [%%xrpc post \"nsid\"]" ) 61 + | _ -> 62 + Location.raise_errorf ~loc 63 + "Expected [%%xrpc get \"nsid\"] or [%%xrpc post \"nsid\"]" 64 + 65 + let expand ~ctxt str = 66 + let loc = Expansion_context.Extension.extension_point_loc ctxt in 67 + let _method, nsid = parse_method_and_nsid ~loc str in 68 + build_call_expr_flat ~loc nsid 69 + 70 + let xrpc_extension = 71 + Extension.V3.declare "xrpc" Extension.Context.expression 72 + Ast_pattern.(pstr __) 73 + expand 74 + 75 + let rule = Context_free.Rule.extension xrpc_extension 76 + 77 + let () = Driver.register_transformation "hermes_ppx" ~rules:[rule]
+4
hermes_ppx/test/dune
··· 1 + (test 2 + (name test_ppx) 3 + (libraries alcotest hermes_ppx ppxlib) 4 + (preprocess (pps hermes_ppx)))
+84
hermes_ppx/test/test_ppx.ml
··· 1 + open Alcotest 2 + 3 + let loc = Location.none 4 + 5 + let test_nsid_to_module_path_simple () = 6 + let result = Hermes_ppx.nsid_to_module_path "app.bsky.graph" in 7 + check (list string) "simple path" ["App"; "Bsky"; "Graph"] result 8 + 9 + let test_nsid_to_module_path_camel_case () = 10 + let result = 11 + Hermes_ppx.nsid_to_module_path "app.bsky.graph.getRelationships" 12 + in 13 + check (list string) "camelCase" 14 + ["App"; "Bsky"; "Graph"; "GetRelationships"] 15 + result 16 + 17 + let test_nsid_to_module_path_single () = 18 + let result = Hermes_ppx.nsid_to_module_path "test" in 19 + check (list string) "single segment" ["Test"] result 20 + 21 + let test_build_module_path_single () = 22 + let result = Hermes_ppx.build_module_path ~loc ["App"] in 23 + check string "single module" "App" (Ppxlib.Longident.name result.txt) 24 + 25 + let test_build_module_path_nested () = 26 + let result = Hermes_ppx.build_module_path ~loc ["App"; "Bsky"; "Graph"] in 27 + check string "nested module" "App.Bsky.Graph" 28 + (Ppxlib.Longident.name result.txt) 29 + 30 + let test_build_call_expr () = 31 + let result = Hermes_ppx.build_call_expr ~loc "app.bsky.graph.getProfile" in 32 + let expected_str = "App.Bsky.Graph.GetProfile.call" in 33 + check string "call expr" expected_str 34 + (Ppxlib.Pprintast.string_of_expression result) 35 + 36 + let expand_xrpc code = 37 + let lexbuf = Lexing.from_string code in 38 + let structure = Ppxlib.Parse.implementation lexbuf in 39 + let transformed = Ppxlib.Driver.map_structure structure in 40 + match transformed with 41 + | [{Ppxlib.pstr_desc= Ppxlib.Pstr_eval (expr, _); _}] -> 42 + expr 43 + | [{Ppxlib.pstr_desc= Ppxlib.Pstr_value (_, [{pvb_expr; _}]); _}] -> 44 + pvb_expr 45 + | _ -> 46 + failwith "unexpected structure after expansion" 47 + 48 + let test_expand_get_nsid () = 49 + let actual = expand_xrpc {|[%xrpc get "app.bsky.graph.getRelationships"]|} in 50 + let expected_str = "App.Bsky.Graph.GetRelationships.call" in 51 + check string "get expansion" expected_str 52 + (Ppxlib.Pprintast.string_of_expression actual) 53 + 54 + let test_expand_post_nsid () = 55 + let actual = 56 + expand_xrpc {|[%xrpc post "com.atproto.server.createSession"]|} 57 + in 58 + let expected_str = "Com.Atproto.Server.CreateSession.call" in 59 + check string "post expansion" expected_str 60 + (Ppxlib.Pprintast.string_of_expression actual) 61 + 62 + let test_expand_nsid_only () = 63 + (* [%xrpc "nsid"] defaults to get *) 64 + let actual = expand_xrpc {|[%xrpc "app.bsky.actor.getProfile"]|} in 65 + let expected_str = "App.Bsky.Actor.GetProfile.call" in 66 + check string "nsid only expansion" expected_str 67 + (Ppxlib.Pprintast.string_of_expression actual) 68 + 69 + let unit_tests = 70 + [ ("nsid_to_module_path simple", `Quick, test_nsid_to_module_path_simple) 71 + ; ( "nsid_to_module_path camelCase" 72 + , `Quick 73 + , test_nsid_to_module_path_camel_case ) 74 + ; ("nsid_to_module_path single", `Quick, test_nsid_to_module_path_single) 75 + ; ("build_module_path single", `Quick, test_build_module_path_single) 76 + ; ("build_module_path nested", `Quick, test_build_module_path_nested) 77 + ; ("build_call_expr", `Quick, test_build_call_expr) ] 78 + 79 + let expansion_tests = 80 + [ ("expand get nsid", `Quick, test_expand_get_nsid) 81 + ; ("expand post nsid", `Quick, test_expand_post_nsid) 82 + ; ("expand nsid only", `Quick, test_expand_nsid_only) ] 83 + 84 + let () = run "hermes_ppx" [("unit", unit_tests); ("expansion", expansion_tests)]