Fork of indigo with a slightly nicer lexgen.
at main 11 kB view raw
1// Package lex generates Go code for lexicons. 2// 3// (It is not a lexer.) 4package lex 5 6import ( 7 "bytes" 8 "encoding/json" 9 "fmt" 10 "io" 11 "os" 12 "path/filepath" 13 "sort" 14 "strings" 15 16 "golang.org/x/tools/imports" 17) 18 19const ( 20 EncodingCBOR = "application/cbor" 21 EncodingJSON = "application/json" 22 EncodingJSONL = "application/jsonl" 23 EncodingCAR = "application/vnd.ipld.car" 24 EncodingMP4 = "video/mp4" 25 EncodingANY = "*/*" 26) 27 28type outputType struct { 29 Name string 30 Type *TypeSchema 31 NeedsCbor bool 32 NeedsType bool 33} 34 35// Build total map of all types defined inside schemas. 36// Return map from fully qualified type name to its *TypeSchema 37func BuildExtDefMap(ss []*Schema, packages []Package) map[string]*ExtDef { 38 out := make(map[string]*ExtDef) 39 for _, s := range ss { 40 for k, d := range s.Defs { 41 d.defMap = out 42 d.id = s.ID 43 d.defName = k 44 45 var pref string 46 for _, pkg := range packages { 47 if strings.HasPrefix(s.ID, pkg.Prefix) { 48 pref = pkg.Prefix 49 break 50 } 51 } 52 d.prefix = pref 53 54 n := s.ID 55 if k != "main" { 56 n = s.ID + "#" + k 57 } 58 out[n] = &ExtDef{ 59 Type: d, 60 } 61 } 62 } 63 return out 64} 65 66type ExtDef struct { 67 Type *TypeSchema 68} 69 70// TODO: this method is necessary because in lexicon there is no way to know if 71// a type needs to be marshaled with a "$type" field up front, you can only 72// know for sure by seeing where the type is used. 
73func FixRecordReferences(schemas []*Schema, defmap map[string]*ExtDef, prefix string) { 74 for _, s := range schemas { 75 if !strings.HasPrefix(s.ID, prefix) { 76 continue 77 } 78 79 tps := s.AllTypes(prefix, defmap) 80 for _, t := range tps { 81 if t.Type.Type == "record" { 82 t.NeedsType = true 83 t.Type.needsType = true 84 } 85 86 if t.Type.Type == "union" { 87 for _, r := range t.Type.Refs { 88 if r[0] == '#' { 89 r = s.ID + r 90 } 91 92 if _, known := defmap[r]; known != true { 93 panic(fmt.Sprintf("reference to unknown record type: %s", r)) 94 } 95 96 if t.NeedsCbor { 97 defmap[r].Type.needsCbor = true 98 } 99 } 100 } 101 } 102 } 103} 104 105func printerf(w io.Writer) func(format string, args ...any) { 106 return func(format string, args ...any) { 107 fmt.Fprintf(w, format, args...) 108 } 109} 110 111func GenCodeForSchema(pkg Package, reqcode bool, s *Schema, packages []Package, defmap map[string]*ExtDef) error { 112 err := os.MkdirAll(pkg.Outdir, 0755) 113 if err != nil { 114 return fmt.Errorf("%s: could not mkdir, %w", pkg.Outdir, err) 115 } 116 fname := filepath.Join(pkg.Outdir, s.Name()+".go") 117 buf := new(bytes.Buffer) 118 pf := printerf(buf) 119 120 s.prefix = pkg.Prefix 121 for _, d := range s.Defs { 122 d.prefix = pkg.Prefix 123 } 124 125 // Add the standard Go generated code header as recognized by GitHub, VS Code, etc. 126 // See https://golang.org/s/generatedcode. 
127 pf("// Code generated by cmd/lexgen (see Makefile's lexgen); DO NOT EDIT.\n\n") 128 129 pf("package %s\n\n", pkg.GoPackage) 130 131 pf("// schema: %s\n\n", s.ID) 132 133 pf("import (\n") 134 pf("\t\"context\"\n") 135 pf("\t\"fmt\"\n") 136 pf("\t\"encoding/json\"\n") 137 pf("\tcbg \"github.com/whyrusleeping/cbor-gen\"\n") 138 pf("\t\"github.com/bluesky-social/indigo/lex/util\"\n") 139 for _, xpkg := range packages { 140 if xpkg.Prefix != pkg.Prefix { 141 pf("\t%s %q\n", importNameForPrefix(xpkg.Prefix), xpkg.Import) 142 } 143 } 144 pf(")\n\n") 145 146 tps := s.AllTypes(pkg.Prefix, defmap) 147 148 // write NSIDs for top level types 149 pf("const (") 150 // for _, t := range tps { 151 // // this seems to be the way to ignore non-top level types 152 // if !strings.Contains(t.Name, "_") { 153 // pf("%sNSID = %q", t.Name, t.Type.id) 154 // } 155 // } 156 if reqcode { 157 name := nameFromID(s.ID, pkg.Prefix) 158 if _, ok := s.Defs["main"]; ok { 159 pf("%sNSID = %q\n", name, s.ID) 160 } 161 } 162 pf(")\n\n") 163 164 if err := writeDecoderRegister(buf, tps); err != nil { 165 return err 166 } 167 168 sort.Slice(tps, func(i, j int) bool { 169 return tps[i].Name < tps[j].Name 170 }) 171 for _, ot := range tps { 172 fmt.Println("TYPE: ", ot.Name, ot.NeedsCbor, ot.NeedsType) 173 if err := ot.Type.WriteType(ot.Name, buf); err != nil { 174 return err 175 } 176 } 177 178 // reqcode is always True 179 if reqcode { 180 name := nameFromID(s.ID, pkg.Prefix) 181 main, ok := s.Defs["main"] 182 if ok { 183 if err := writeMethods(name, main, buf); err != nil { 184 return err 185 } 186 } 187 } 188 189 if err := writeCodeFile(buf.Bytes(), fname); err != nil { 190 return err 191 } 192 193 return nil 194} 195 196func writeDecoderRegister(w io.Writer, tps []outputType) error { 197 var buf bytes.Buffer 198 outf := printerf(&buf) 199 200 for _, t := range tps { 201 if t.Type.needsType && !strings.Contains(t.Name, "_") { 202 id := t.Type.id 203 if t.Type.defName != "" { 204 id = id + "#" + 
t.Type.defName 205 } 206 if buf.Len() == 0 { 207 outf("func init() {\n") 208 } 209 outf("util.RegisterType(%q, &%s{})\n", id, t.Name) 210 } 211 } 212 if buf.Len() == 0 { 213 return nil 214 } 215 outf("}") 216 _, err := w.Write(buf.Bytes()) 217 return err 218} 219 220func writeCodeFile(b []byte, fname string) error { 221 fixed, err := imports.Process(fname, b, nil) 222 if err != nil { 223 werr := os.WriteFile("temp", b, 0664) 224 if werr != nil { 225 return werr 226 } 227 return fmt.Errorf("failed to format output of %q with goimports: %w (wrote failed file to ./temp)", fname, err) 228 } 229 230 if err := os.WriteFile(fname, fixed, 0664); err != nil { 231 return err 232 } 233 234 return nil 235} 236 237func writeMethods(typename string, ts *TypeSchema, w io.Writer) error { 238 switch ts.Type { 239 case "token": 240 n := ts.id 241 if ts.defName != "main" { 242 n += "#" + ts.defName 243 } 244 245 fmt.Fprintf(w, "const %s = %q\n", typename, n) 246 return nil 247 case "record": 248 return nil 249 case "query": 250 return ts.WriteRPC(w, typename, fmt.Sprintf("%s_Input", typename)) 251 case "procedure": 252 if ts.Input == nil || ts.Input.Schema == nil || ts.Input.Schema.Type == "object" { 253 return ts.WriteRPC(w, typename, fmt.Sprintf("%s_Input", typename)) 254 } else if ts.Input.Schema.Type == "ref" { 255 inputname, _ := ts.namesFromRef(ts.Input.Schema.Ref) 256 return ts.WriteRPC(w, typename, inputname) 257 } else { 258 return fmt.Errorf("unhandled input type: %s", ts.Input.Schema.Type) 259 } 260 case "object", "string": 261 return nil 262 case "subscription": 263 // TODO: should probably have some methods generated for this 264 return nil 265 default: 266 return fmt.Errorf("unrecognized lexicon type %q", ts.Type) 267 } 268} 269 270func nameFromID(id, prefix string) string { 271 parts := strings.Split(strings.TrimPrefix(id, prefix), ".") 272 var tname string 273 for _, s := range parts { 274 tname += strings.Title(s) 275 } 276 277 return tname 278 279} 280 281func 
orderedMapIter[T any](m map[string]T, cb func(string, T) error) error { 282 var keys []string 283 for k := range m { 284 keys = append(keys, k) 285 } 286 287 sort.Strings(keys) 288 289 for _, k := range keys { 290 if err := cb(k, m[k]); err != nil { 291 return err 292 } 293 } 294 return nil 295} 296 297func CreateHandlerStub(pkg string, impmap map[string]string, dir string, schemas []*Schema, handlers bool) error { 298 buf := new(bytes.Buffer) 299 300 if err := WriteXrpcServer(buf, schemas, pkg, impmap); err != nil { 301 return err 302 } 303 304 fname := filepath.Join(dir, "stubs.go") 305 if err := writeCodeFile(buf.Bytes(), fname); err != nil { 306 return err 307 } 308 309 if handlers { 310 buf := new(bytes.Buffer) 311 312 if err := WriteServerHandlers(buf, schemas, pkg, impmap); err != nil { 313 return err 314 } 315 316 fname := filepath.Join(dir, "handlers.go") 317 if err := writeCodeFile(buf.Bytes(), fname); err != nil { 318 return err 319 } 320 321 } 322 323 return nil 324} 325 326func importNameForPrefix(prefix string) string { 327 return strings.Join(strings.Split(prefix, "."), "") + "types" 328} 329 330func WriteServerHandlers(w io.Writer, schemas []*Schema, pkg string, impmap map[string]string) error { 331 pf := printerf(w) 332 pf("package %s\n\n", pkg) 333 pf("import (\n") 334 pf("\t\"context\"\n") 335 pf("\t\"fmt\"\n") 336 pf("\t\"encoding/json\"\n") 337 pf("\t\"github.com/bluesky-social/indigo/xrpc\"\n") 338 for k, v := range impmap { 339 pf("\t%s\"%s\"\n", importNameForPrefix(k), v) 340 } 341 pf(")\n\n") 342 343 for _, s := range schemas { 344 345 var prefix string 346 for k := range impmap { 347 if strings.HasPrefix(s.ID, k) { 348 prefix = k 349 break 350 } 351 } 352 353 main, ok := s.Defs["main"] 354 if !ok { 355 fmt.Printf("WARNING: schema %q doesn't have a main def\n", s.ID) 356 continue 357 } 358 359 if main.Type == "procedure" || main.Type == "query" { 360 fname := idToTitle(s.ID) 361 tname := nameFromID(s.ID, prefix) 362 impname := 
importNameForPrefix(prefix) 363 if err := main.WriteHandlerStub(w, fname, tname, impname); err != nil { 364 return err 365 } 366 } 367 } 368 369 return nil 370} 371 372func WriteXrpcServer(w io.Writer, schemas []*Schema, pkg string, impmap map[string]string) error { 373 pf := printerf(w) 374 pf("package %s\n\n", pkg) 375 pf("import (\n") 376 pf("\t\"context\"\n") 377 pf("\t\"fmt\"\n") 378 pf("\t\"encoding/json\"\n") 379 pf("\t\"github.com/bluesky-social/indigo/xrpc\"\n") 380 pf("\t\"github.com/labstack/echo/v4\"\n") 381 382 var prefixes []string 383 orderedMapIter[string](impmap, func(k, v string) error { 384 prefixes = append(prefixes, k) 385 pf("\t%s\"%s\"\n", importNameForPrefix(k), v) 386 return nil 387 }) 388 pf(")\n\n") 389 390 ssets := make(map[string][]*Schema) 391 for _, s := range schemas { 392 var pref string 393 for _, p := range prefixes { 394 if strings.HasPrefix(s.ID, p) { 395 pref = p 396 break 397 } 398 } 399 if pref == "" { 400 return fmt.Errorf("no matching prefix for schema %q (tried %s)", s.ID, prefixes) 401 } 402 403 ssets[pref] = append(ssets[pref], s) 404 } 405 406 for _, p := range prefixes { 407 ss := ssets[p] 408 409 pf("func (s *Server) RegisterHandlers%s(e *echo.Echo) error {\n", idToTitle(p)) 410 for _, s := range ss { 411 412 main, ok := s.Defs["main"] 413 if !ok { 414 continue 415 } 416 417 var verb string 418 switch main.Type { 419 case "query": 420 verb = "GET" 421 case "procedure": 422 verb = "POST" 423 default: 424 continue 425 } 426 427 pf("e.%s(\"/xrpc/%s\", s.Handle%s)\n", verb, s.ID, idToTitle(s.ID)) 428 } 429 430 pf("return nil\n}\n\n") 431 432 for _, s := range ss { 433 434 var prefix string 435 for k := range impmap { 436 if strings.HasPrefix(s.ID, k) { 437 prefix = k 438 break 439 } 440 } 441 442 main, ok := s.Defs["main"] 443 if !ok { 444 continue 445 } 446 447 if main.Type == "procedure" || main.Type == "query" { 448 fname := idToTitle(s.ID) 449 tname := nameFromID(s.ID, prefix) 450 impname := 
importNameForPrefix(prefix) 451 if err := main.WriteRPCHandler(w, fname, tname, impname); err != nil { 452 return fmt.Errorf("writing handler for %s: %w", s.ID, err) 453 } 454 } 455 } 456 } 457 458 return nil 459} 460 461func idToTitle(id string) string { 462 var fname string 463 for _, p := range strings.Split(id, ".") { 464 fname += strings.Title(p) 465 } 466 return fname 467} 468 469type Package struct { 470 GoPackage string `json:"package"` 471 Prefix string `json:"prefix"` 472 Outdir string `json:"outdir"` 473 Import string `json:"import"` 474} 475 476// ParsePackages reads a json blob which should be an array of Package{} objects. 477func ParsePackages(jsonBytes []byte) ([]Package, error) { 478 var packages []Package 479 err := json.Unmarshal(jsonBytes, &packages) 480 if err != nil { 481 return nil, err 482 } 483 return packages, nil 484} 485 486func Run(schemas []*Schema, externalSchemas []*Schema, packages []Package) error { 487 defmap := BuildExtDefMap(append(schemas, externalSchemas...), packages) 488 489 for _, pkg := range packages { 490 prefix := pkg.Prefix 491 FixRecordReferences(schemas, defmap, prefix) 492 } 493 494 for _, pkg := range packages { 495 for _, s := range schemas { 496 if !strings.HasPrefix(s.ID, pkg.Prefix) { 497 continue 498 } 499 500 if err := GenCodeForSchema(pkg, true, s, packages, defmap); err != nil { 501 return fmt.Errorf("failed to process schema %q: %w", s.path, err) 502 } 503 } 504 } 505 return nil 506}