1// Package lex generates Go code for lexicons.
2//
3// (It is not a lexer.)
4package lex
5
6import (
7 "bytes"
8 "encoding/json"
9 "fmt"
10 "io"
11 "os"
12 "path/filepath"
13 "sort"
14 "strings"
15
16 "golang.org/x/tools/imports"
17)
18
// MIME type strings for body encodings. They are not referenced in this file;
// presumably they are compared against a lexicon's declared encoding elsewhere
// in the package — TODO confirm.
const (
	// EncodingCBOR is the CBOR media type.
	EncodingCBOR = "application/cbor"
	// EncodingJSON is the JSON media type.
	EncodingJSON = "application/json"
	// EncodingJSONL is the newline-delimited JSON media type.
	EncodingJSONL = "application/jsonl"
	// EncodingCAR is the IPLD CAR file media type.
	EncodingCAR = "application/vnd.ipld.car"
	// EncodingMP4 is the MP4 video media type.
	EncodingMP4 = "video/mp4"
	// EncodingANY is the "*/*" wildcard media range.
	EncodingANY = "*/*"
)
27
28type outputType struct {
29 Name string
30 Type *TypeSchema
31 NeedsCbor bool
32 NeedsType bool
33}
34
35// Build total map of all types defined inside schemas.
36// Return map from fully qualified type name to its *TypeSchema
37func BuildExtDefMap(ss []*Schema, packages []Package) map[string]*ExtDef {
38 out := make(map[string]*ExtDef)
39 for _, s := range ss {
40 for k, d := range s.Defs {
41 d.defMap = out
42 d.id = s.ID
43 d.defName = k
44
45 var pref string
46 for _, pkg := range packages {
47 if strings.HasPrefix(s.ID, pkg.Prefix) {
48 pref = pkg.Prefix
49 break
50 }
51 }
52 d.prefix = pref
53
54 n := s.ID
55 if k != "main" {
56 n = s.ID + "#" + k
57 }
58 out[n] = &ExtDef{
59 Type: d,
60 }
61 }
62 }
63 return out
64}
65
66type ExtDef struct {
67 Type *TypeSchema
68}
69
70// TODO: this method is necessary because in lexicon there is no way to know if
71// a type needs to be marshaled with a "$type" field up front, you can only
72// know for sure by seeing where the type is used.
73func FixRecordReferences(schemas []*Schema, defmap map[string]*ExtDef, prefix string) {
74 for _, s := range schemas {
75 if !strings.HasPrefix(s.ID, prefix) {
76 continue
77 }
78
79 tps := s.AllTypes(prefix, defmap)
80 for _, t := range tps {
81 if t.Type.Type == "record" {
82 t.NeedsType = true
83 t.Type.needsType = true
84 }
85
86 if t.Type.Type == "union" {
87 for _, r := range t.Type.Refs {
88 if r[0] == '#' {
89 r = s.ID + r
90 }
91
92 if _, known := defmap[r]; known != true {
93 panic(fmt.Sprintf("reference to unknown record type: %s", r))
94 }
95
96 if t.NeedsCbor {
97 defmap[r].Type.needsCbor = true
98 }
99 }
100 }
101 }
102 }
103}
104
// printerf returns a Printf-style helper that formats straight into w.
// Write errors reported by w are deliberately ignored.
func printerf(w io.Writer) func(format string, args ...any) {
	printf := func(format string, args ...any) {
		_, _ = fmt.Fprintf(w, format, args...)
	}
	return printf
}
110
111func GenCodeForSchema(pkg Package, reqcode bool, s *Schema, packages []Package, defmap map[string]*ExtDef) error {
112 err := os.MkdirAll(pkg.Outdir, 0755)
113 if err != nil {
114 return fmt.Errorf("%s: could not mkdir, %w", pkg.Outdir, err)
115 }
116 fname := filepath.Join(pkg.Outdir, s.Name()+".go")
117 buf := new(bytes.Buffer)
118 pf := printerf(buf)
119
120 s.prefix = pkg.Prefix
121 for _, d := range s.Defs {
122 d.prefix = pkg.Prefix
123 }
124
125 // Add the standard Go generated code header as recognized by GitHub, VS Code, etc.
126 // See https://golang.org/s/generatedcode.
127 pf("// Code generated by cmd/lexgen (see Makefile's lexgen); DO NOT EDIT.\n\n")
128
129 pf("package %s\n\n", pkg.GoPackage)
130
131 pf("// schema: %s\n\n", s.ID)
132
133 pf("import (\n")
134 pf("\t\"context\"\n")
135 pf("\t\"fmt\"\n")
136 pf("\t\"encoding/json\"\n")
137 pf("\tcbg \"github.com/whyrusleeping/cbor-gen\"\n")
138 pf("\t\"github.com/bluesky-social/indigo/lex/util\"\n")
139 for _, xpkg := range packages {
140 if xpkg.Prefix != pkg.Prefix {
141 pf("\t%s %q\n", importNameForPrefix(xpkg.Prefix), xpkg.Import)
142 }
143 }
144 pf(")\n\n")
145
146 tps := s.AllTypes(pkg.Prefix, defmap)
147
148 // write NSIDs for top level types
149 pf("const (")
150 // for _, t := range tps {
151 // // this seems to be the way to ignore non-top level types
152 // if !strings.Contains(t.Name, "_") {
153 // pf("%sNSID = %q", t.Name, t.Type.id)
154 // }
155 // }
156 if reqcode {
157 name := nameFromID(s.ID, pkg.Prefix)
158 if _, ok := s.Defs["main"]; ok {
159 pf("%sNSID = %q\n", name, s.ID)
160 }
161 }
162 pf(")\n\n")
163
164 if err := writeDecoderRegister(buf, tps); err != nil {
165 return err
166 }
167
168 sort.Slice(tps, func(i, j int) bool {
169 return tps[i].Name < tps[j].Name
170 })
171 for _, ot := range tps {
172 fmt.Println("TYPE: ", ot.Name, ot.NeedsCbor, ot.NeedsType)
173 if err := ot.Type.WriteType(ot.Name, buf); err != nil {
174 return err
175 }
176 }
177
178 // reqcode is always True
179 if reqcode {
180 name := nameFromID(s.ID, pkg.Prefix)
181 main, ok := s.Defs["main"]
182 if ok {
183 if err := writeMethods(name, main, buf); err != nil {
184 return err
185 }
186 }
187 }
188
189 if err := writeCodeFile(buf.Bytes(), fname); err != nil {
190 return err
191 }
192
193 return nil
194}
195
196func writeDecoderRegister(w io.Writer, tps []outputType) error {
197 var buf bytes.Buffer
198 outf := printerf(&buf)
199
200 for _, t := range tps {
201 if t.Type.needsType && !strings.Contains(t.Name, "_") {
202 id := t.Type.id
203 if t.Type.defName != "" {
204 id = id + "#" + t.Type.defName
205 }
206 if buf.Len() == 0 {
207 outf("func init() {\n")
208 }
209 outf("util.RegisterType(%q, &%s{})\n", id, t.Name)
210 }
211 }
212 if buf.Len() == 0 {
213 return nil
214 }
215 outf("}")
216 _, err := w.Write(buf.Bytes())
217 return err
218}
219
220func writeCodeFile(b []byte, fname string) error {
221 fixed, err := imports.Process(fname, b, nil)
222 if err != nil {
223 werr := os.WriteFile("temp", b, 0664)
224 if werr != nil {
225 return werr
226 }
227 return fmt.Errorf("failed to format output of %q with goimports: %w (wrote failed file to ./temp)", fname, err)
228 }
229
230 if err := os.WriteFile(fname, fixed, 0664); err != nil {
231 return err
232 }
233
234 return nil
235}
236
237func writeMethods(typename string, ts *TypeSchema, w io.Writer) error {
238 switch ts.Type {
239 case "token":
240 n := ts.id
241 if ts.defName != "main" {
242 n += "#" + ts.defName
243 }
244
245 fmt.Fprintf(w, "const %s = %q\n", typename, n)
246 return nil
247 case "record":
248 return nil
249 case "query":
250 return ts.WriteRPC(w, typename, fmt.Sprintf("%s_Input", typename))
251 case "procedure":
252 if ts.Input == nil || ts.Input.Schema == nil || ts.Input.Schema.Type == "object" {
253 return ts.WriteRPC(w, typename, fmt.Sprintf("%s_Input", typename))
254 } else if ts.Input.Schema.Type == "ref" {
255 inputname, _ := ts.namesFromRef(ts.Input.Schema.Ref)
256 return ts.WriteRPC(w, typename, inputname)
257 } else {
258 return fmt.Errorf("unhandled input type: %s", ts.Input.Schema.Type)
259 }
260 case "object", "string":
261 return nil
262 case "subscription":
263 // TODO: should probably have some methods generated for this
264 return nil
265 default:
266 return fmt.Errorf("unrecognized lexicon type %q", ts.Type)
267 }
268}
269
// nameFromID turns a lexicon ID into a Go type name. It removes prefix, then
// title-cases each "."-separated segment and concatenates the segments,
// e.g. ("app.bsky.feed.post", "app.bsky.") -> "FeedPost".
func nameFromID(id, prefix string) string {
	var sb strings.Builder
	for _, seg := range strings.Split(strings.TrimPrefix(id, prefix), ".") {
		sb.WriteString(strings.Title(seg))
	}
	return sb.String()
}
280
// orderedMapIter calls cb for each entry of m in ascending key order. It stops
// at the first error returned by cb and returns that error.
func orderedMapIter[T any](m map[string]T, cb func(string, T) error) error {
	// The key count is known up front, so size the slice once.
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	for _, k := range keys {
		if err := cb(k, m[k]); err != nil {
			return err
		}
	}
	return nil
}
296
297func CreateHandlerStub(pkg string, impmap map[string]string, dir string, schemas []*Schema, handlers bool) error {
298 buf := new(bytes.Buffer)
299
300 if err := WriteXrpcServer(buf, schemas, pkg, impmap); err != nil {
301 return err
302 }
303
304 fname := filepath.Join(dir, "stubs.go")
305 if err := writeCodeFile(buf.Bytes(), fname); err != nil {
306 return err
307 }
308
309 if handlers {
310 buf := new(bytes.Buffer)
311
312 if err := WriteServerHandlers(buf, schemas, pkg, impmap); err != nil {
313 return err
314 }
315
316 fname := filepath.Join(dir, "handlers.go")
317 if err := writeCodeFile(buf.Bytes(), fname); err != nil {
318 return err
319 }
320
321 }
322
323 return nil
324}
325
// importNameForPrefix derives the import alias for a lexicon prefix by removing
// every "." and appending "types", e.g. "app.bsky." -> "appbskytypes".
func importNameForPrefix(prefix string) string {
	return strings.ReplaceAll(prefix, ".", "") + "types"
}
329
330func WriteServerHandlers(w io.Writer, schemas []*Schema, pkg string, impmap map[string]string) error {
331 pf := printerf(w)
332 pf("package %s\n\n", pkg)
333 pf("import (\n")
334 pf("\t\"context\"\n")
335 pf("\t\"fmt\"\n")
336 pf("\t\"encoding/json\"\n")
337 pf("\t\"github.com/bluesky-social/indigo/xrpc\"\n")
338 for k, v := range impmap {
339 pf("\t%s\"%s\"\n", importNameForPrefix(k), v)
340 }
341 pf(")\n\n")
342
343 for _, s := range schemas {
344
345 var prefix string
346 for k := range impmap {
347 if strings.HasPrefix(s.ID, k) {
348 prefix = k
349 break
350 }
351 }
352
353 main, ok := s.Defs["main"]
354 if !ok {
355 fmt.Printf("WARNING: schema %q doesn't have a main def\n", s.ID)
356 continue
357 }
358
359 if main.Type == "procedure" || main.Type == "query" {
360 fname := idToTitle(s.ID)
361 tname := nameFromID(s.ID, prefix)
362 impname := importNameForPrefix(prefix)
363 if err := main.WriteHandlerStub(w, fname, tname, impname); err != nil {
364 return err
365 }
366 }
367 }
368
369 return nil
370}
371
372func WriteXrpcServer(w io.Writer, schemas []*Schema, pkg string, impmap map[string]string) error {
373 pf := printerf(w)
374 pf("package %s\n\n", pkg)
375 pf("import (\n")
376 pf("\t\"context\"\n")
377 pf("\t\"fmt\"\n")
378 pf("\t\"encoding/json\"\n")
379 pf("\t\"github.com/bluesky-social/indigo/xrpc\"\n")
380 pf("\t\"github.com/labstack/echo/v4\"\n")
381
382 var prefixes []string
383 orderedMapIter[string](impmap, func(k, v string) error {
384 prefixes = append(prefixes, k)
385 pf("\t%s\"%s\"\n", importNameForPrefix(k), v)
386 return nil
387 })
388 pf(")\n\n")
389
390 ssets := make(map[string][]*Schema)
391 for _, s := range schemas {
392 var pref string
393 for _, p := range prefixes {
394 if strings.HasPrefix(s.ID, p) {
395 pref = p
396 break
397 }
398 }
399 if pref == "" {
400 return fmt.Errorf("no matching prefix for schema %q (tried %s)", s.ID, prefixes)
401 }
402
403 ssets[pref] = append(ssets[pref], s)
404 }
405
406 for _, p := range prefixes {
407 ss := ssets[p]
408
409 pf("func (s *Server) RegisterHandlers%s(e *echo.Echo) error {\n", idToTitle(p))
410 for _, s := range ss {
411
412 main, ok := s.Defs["main"]
413 if !ok {
414 continue
415 }
416
417 var verb string
418 switch main.Type {
419 case "query":
420 verb = "GET"
421 case "procedure":
422 verb = "POST"
423 default:
424 continue
425 }
426
427 pf("e.%s(\"/xrpc/%s\", s.Handle%s)\n", verb, s.ID, idToTitle(s.ID))
428 }
429
430 pf("return nil\n}\n\n")
431
432 for _, s := range ss {
433
434 var prefix string
435 for k := range impmap {
436 if strings.HasPrefix(s.ID, k) {
437 prefix = k
438 break
439 }
440 }
441
442 main, ok := s.Defs["main"]
443 if !ok {
444 continue
445 }
446
447 if main.Type == "procedure" || main.Type == "query" {
448 fname := idToTitle(s.ID)
449 tname := nameFromID(s.ID, prefix)
450 impname := importNameForPrefix(prefix)
451 if err := main.WriteRPCHandler(w, fname, tname, impname); err != nil {
452 return fmt.Errorf("writing handler for %s: %w", s.ID, err)
453 }
454 }
455 }
456 }
457
458 return nil
459}
460
// idToTitle converts a dotted lexicon ID into an exported Go identifier by
// title-casing each "."-separated segment and concatenating them,
// e.g. "app.bsky.feed.getTimeline" -> "AppBskyFeedGetTimeline".
func idToTitle(id string) string {
	segments := strings.Split(id, ".")
	var b strings.Builder
	for i := range segments {
		b.WriteString(strings.Title(segments[i]))
	}
	return b.String()
}
468
// Package describes one Go package that lexicon code is generated into. A
// list of them is normally loaded from JSON with ParsePackages.
type Package struct {
	// GoPackage is the name used in the generated files' package clause.
	GoPackage string `json:"package"`
	// Prefix selects the schemas that belong to this package: those whose
	// lexicon ID starts with it.
	Prefix string `json:"prefix"`
	// Outdir is the directory that generated files are written to.
	Outdir string `json:"outdir"`
	// Import is the Go import path that other generated packages use to
	// import this one.
	Import string `json:"import"`
}

// ParsePackages decodes jsonBytes, which must be a JSON array of Package
// objects keyed by the struct tags above. On failure it returns a nil slice
// and the decoding error, wrapped with context.
func ParsePackages(jsonBytes []byte) ([]Package, error) {
	var packages []Package
	if err := json.Unmarshal(jsonBytes, &packages); err != nil {
		return nil, fmt.Errorf("parsing package list: %w", err)
	}
	return packages, nil
}
485
486func Run(schemas []*Schema, externalSchemas []*Schema, packages []Package) error {
487 defmap := BuildExtDefMap(append(schemas, externalSchemas...), packages)
488
489 for _, pkg := range packages {
490 prefix := pkg.Prefix
491 FixRecordReferences(schemas, defmap, prefix)
492 }
493
494 for _, pkg := range packages {
495 for _, s := range schemas {
496 if !strings.HasPrefix(s.ID, pkg.Prefix) {
497 continue
498 }
499
500 if err := GenCodeForSchema(pkg, true, s, packages, defmap); err != nil {
501 return fmt.Errorf("failed to process schema %q: %w", s.path, err)
502 }
503 }
504 }
505 return nil
506}