1// Package lex generates Go code for lexicons.
2//
3// (It is not a lexer.)
4package lex
5
6import (
7 "bytes"
8 "encoding/json"
9 "fmt"
10 "io"
11 "os"
12 "path/filepath"
13 "sort"
14 "strings"
15
16 "golang.org/x/tools/imports"
17)
18
// MIME content types, named as lexicon encodings.
const (
	// EncodingCBOR is the CBOR content type.
	EncodingCBOR = "application/cbor"
	// EncodingJSON is the JSON content type.
	EncodingJSON = "application/json"
	// EncodingJSONL is the JSON Lines (newline-delimited JSON) content type.
	EncodingJSONL = "application/jsonl"
	// EncodingCAR is the IPLD CAR archive content type.
	EncodingCAR = "application/vnd.ipld.car"
	// EncodingMP4 is the MP4 video content type.
	EncodingMP4 = "video/mp4"
	// EncodingANY is the wildcard content type matching any media type.
	EncodingANY = "*/*"
)
27
28type outputType struct {
29 Name string
30 Type *TypeSchema
31 NeedsCbor bool
32 NeedsType bool
33}
34
35// Build total map of all types defined inside schemas.
36// Return map from fully qualified type name to its *TypeSchema
37func BuildExtDefMap(ss []*Schema, packages []Package) map[string]*ExtDef {
38 out := make(map[string]*ExtDef)
39 for _, s := range ss {
40 for k, d := range s.Defs {
41 d.defMap = out
42 d.id = s.ID
43 d.defName = k
44
45 var pref string
46 for _, pkg := range packages {
47 if strings.HasPrefix(s.ID, pkg.Prefix) {
48 pref = pkg.Prefix
49 break
50 }
51 }
52 d.prefix = pref
53
54 n := s.ID
55 if k != "main" {
56 n = s.ID + "#" + k
57 }
58 out[n] = &ExtDef{
59 Type: d,
60 }
61 }
62 }
63 return out
64}
65
66type ExtDef struct {
67 Type *TypeSchema
68}
69
70// TODO: this method is necessary because in lexicon there is no way to know if
71// a type needs to be marshaled with a "$type" field up front, you can only
72// know for sure by seeing where the type is used.
73func FixRecordReferences(schemas []*Schema, defmap map[string]*ExtDef, prefix string) {
74 for _, s := range schemas {
75 if !strings.HasPrefix(s.ID, prefix) {
76 continue
77 }
78
79 tps := s.AllTypes(prefix, defmap)
80 for _, t := range tps {
81 if t.Type.Type == "record" {
82 t.NeedsType = true
83 t.Type.needsType = true
84 }
85
86 if t.Type.Type == "union" {
87 for _, r := range t.Type.Refs {
88 if r[0] == '#' {
89 r = s.ID + r
90 }
91
92 if _, known := defmap[r]; known != true {
93 panic(fmt.Sprintf("reference to unknown record type: %s", r))
94 }
95
96 if t.NeedsCbor {
97 defmap[r].Type.needsCbor = true
98 }
99 }
100 }
101 }
102 }
103}
104
// printerf returns a printf-style helper that formats its arguments with
// fmt.Fprintf and writes the result to w. Write errors are discarded.
func printerf(w io.Writer) func(format string, args ...any) {
	printf := func(format string, args ...any) {
		// Errors are intentionally ignored; several callers in this file
		// write into in-memory bytes.Buffers.
		_, _ = fmt.Fprintf(w, format, args...)
	}
	return printf
}
110
111func GenCodeForSchema(pkg Package, reqcode bool, s *Schema, packages []Package, defmap map[string]*ExtDef) error {
112 err := os.MkdirAll(pkg.Outdir, 0755)
113 if err != nil {
114 return fmt.Errorf("%s: could not mkdir, %w", pkg.Outdir, err)
115 }
116 fname := filepath.Join(pkg.Outdir, s.Name()+".go")
117 buf := new(bytes.Buffer)
118 pf := printerf(buf)
119
120 s.prefix = pkg.Prefix
121 for _, d := range s.Defs {
122 d.prefix = pkg.Prefix
123 }
124
125 // Add the standard Go generated code header as recognized by GitHub, VS Code, etc.
126 // See https://golang.org/s/generatedcode.
127 pf("// Code generated by cmd/lexgen (see Makefile's lexgen); DO NOT EDIT.\n\n")
128
129 pf("package %s\n\n", pkg.GoPackage)
130
131 pf("// schema: %s\n\n", s.ID)
132
133 pf("import (\n")
134 pf("\t\"context\"\n")
135 pf("\t\"fmt\"\n")
136 pf("\t\"encoding/json\"\n")
137 pf("\tcbg \"github.com/whyrusleeping/cbor-gen\"\n")
138 pf("\t\"github.com/bluesky-social/indigo/lex/util\"\n")
139 for _, xpkg := range packages {
140 if xpkg.Prefix != pkg.Prefix {
141 pf("\t%s %q\n", importNameForPrefix(xpkg.Prefix), xpkg.Import)
142 }
143 }
144 pf(")\n\n")
145
146 tps := s.AllTypes(pkg.Prefix, defmap)
147
148 if err := writeDecoderRegister(buf, tps); err != nil {
149 return err
150 }
151
152 sort.Slice(tps, func(i, j int) bool {
153 return tps[i].Name < tps[j].Name
154 })
155 for _, ot := range tps {
156 fmt.Println("TYPE: ", ot.Name, ot.NeedsCbor, ot.NeedsType)
157 if err := ot.Type.WriteType(ot.Name, buf); err != nil {
158 return err
159 }
160 }
161
162 // reqcode is always True
163 if reqcode {
164 name := nameFromID(s.ID, pkg.Prefix)
165 main, ok := s.Defs["main"]
166 if ok {
167 if err := writeMethods(name, main, buf); err != nil {
168 return err
169 }
170 }
171 }
172
173 if err := writeCodeFile(buf.Bytes(), fname); err != nil {
174 return err
175 }
176
177 return nil
178}
179
180func writeDecoderRegister(w io.Writer, tps []outputType) error {
181 var buf bytes.Buffer
182 outf := printerf(&buf)
183
184 for _, t := range tps {
185 if t.Type.needsType && !strings.Contains(t.Name, "_") {
186 id := t.Type.id
187 if t.Type.defName != "" {
188 id = id + "#" + t.Type.defName
189 }
190 if buf.Len() == 0 {
191 outf("func init() {\n")
192 }
193 outf("util.RegisterType(%q, &%s{})\n", id, t.Name)
194 }
195 }
196 if buf.Len() == 0 {
197 return nil
198 }
199 outf("}")
200 _, err := w.Write(buf.Bytes())
201 return err
202}
203
204func writeCodeFile(b []byte, fname string) error {
205 fixed, err := imports.Process(fname, b, nil)
206 if err != nil {
207 werr := os.WriteFile("temp", b, 0664)
208 if werr != nil {
209 return werr
210 }
211 return fmt.Errorf("failed to format output of %q with goimports: %w (wrote failed file to ./temp)", fname, err)
212 }
213
214 if err := os.WriteFile(fname, fixed, 0664); err != nil {
215 return err
216 }
217
218 return nil
219}
220
221func writeMethods(typename string, ts *TypeSchema, w io.Writer) error {
222 switch ts.Type {
223 case "token":
224 n := ts.id
225 if ts.defName != "main" {
226 n += "#" + ts.defName
227 }
228
229 fmt.Fprintf(w, "const %s = %q\n", typename, n)
230 return nil
231 case "record":
232 return nil
233 case "query":
234 return ts.WriteRPC(w, typename, fmt.Sprintf("%s_Input", typename))
235 case "procedure":
236 if ts.Input == nil || ts.Input.Schema == nil || ts.Input.Schema.Type == "object" {
237 return ts.WriteRPC(w, typename, fmt.Sprintf("%s_Input", typename))
238 } else if ts.Input.Schema.Type == "ref" {
239 inputname, _ := ts.namesFromRef(ts.Input.Schema.Ref)
240 return ts.WriteRPC(w, typename, inputname)
241 } else {
242 return fmt.Errorf("unhandled input type: %s", ts.Input.Schema.Type)
243 }
244 case "object", "string":
245 return nil
246 case "subscription":
247 // TODO: should probably have some methods generated for this
248 return nil
249 default:
250 return fmt.Errorf("unrecognized lexicon type %q", ts.Type)
251 }
252}
253
// nameFromID turns a lexicon ID into a Go type name. It trims prefix from id,
// title-cases each remaining dot-separated segment with strings.Title, and
// concatenates them, so "feed.getPostThread" becomes "FeedGetPostThread".
func nameFromID(id, prefix string) string {
	var sb strings.Builder
	for _, seg := range strings.Split(strings.TrimPrefix(id, prefix), ".") {
		sb.WriteString(strings.Title(seg))
	}
	return sb.String()
}
264
// orderedMapIter calls cb for every entry of m in ascending key order. It
// stops at, and returns, the first non-nil error from cb. Otherwise it
// returns nil.
func orderedMapIter[T any](m map[string]T, cb func(string, T) error) error {
	sorted := make([]string, 0, len(m))
	for key := range m {
		sorted = append(sorted, key)
	}
	sort.Strings(sorted)

	for _, key := range sorted {
		if err := cb(key, m[key]); err != nil {
			return err
		}
	}
	return nil
}
280
281func CreateHandlerStub(pkg string, impmap map[string]string, dir string, schemas []*Schema, handlers bool) error {
282 buf := new(bytes.Buffer)
283
284 if err := WriteXrpcServer(buf, schemas, pkg, impmap); err != nil {
285 return err
286 }
287
288 fname := filepath.Join(dir, "stubs.go")
289 if err := writeCodeFile(buf.Bytes(), fname); err != nil {
290 return err
291 }
292
293 if handlers {
294 buf := new(bytes.Buffer)
295
296 if err := WriteServerHandlers(buf, schemas, pkg, impmap); err != nil {
297 return err
298 }
299
300 fname := filepath.Join(dir, "handlers.go")
301 if err := writeCodeFile(buf.Bytes(), fname); err != nil {
302 return err
303 }
304
305 }
306
307 return nil
308}
309
// importNameForPrefix returns the import alias used for the generated types
// package of a lexicon prefix: the prefix with every "." removed, followed by
// "types".
func importNameForPrefix(prefix string) string {
	return strings.ReplaceAll(prefix, ".", "") + "types"
}
313
314func WriteServerHandlers(w io.Writer, schemas []*Schema, pkg string, impmap map[string]string) error {
315 pf := printerf(w)
316 pf("package %s\n\n", pkg)
317 pf("import (\n")
318 pf("\t\"context\"\n")
319 pf("\t\"fmt\"\n")
320 pf("\t\"encoding/json\"\n")
321 pf("\t\"github.com/bluesky-social/indigo/xrpc\"\n")
322 for k, v := range impmap {
323 pf("\t%s\"%s\"\n", importNameForPrefix(k), v)
324 }
325 pf(")\n\n")
326
327 for _, s := range schemas {
328
329 var prefix string
330 for k := range impmap {
331 if strings.HasPrefix(s.ID, k) {
332 prefix = k
333 break
334 }
335 }
336
337 main, ok := s.Defs["main"]
338 if !ok {
339 fmt.Printf("WARNING: schema %q doesn't have a main def\n", s.ID)
340 continue
341 }
342
343 if main.Type == "procedure" || main.Type == "query" {
344 fname := idToTitle(s.ID)
345 tname := nameFromID(s.ID, prefix)
346 impname := importNameForPrefix(prefix)
347 if err := main.WriteHandlerStub(w, fname, tname, impname); err != nil {
348 return err
349 }
350 }
351 }
352
353 return nil
354}
355
356func WriteXrpcServer(w io.Writer, schemas []*Schema, pkg string, impmap map[string]string) error {
357 pf := printerf(w)
358 pf("package %s\n\n", pkg)
359 pf("import (\n")
360 pf("\t\"context\"\n")
361 pf("\t\"fmt\"\n")
362 pf("\t\"encoding/json\"\n")
363 pf("\t\"github.com/bluesky-social/indigo/xrpc\"\n")
364 pf("\t\"github.com/labstack/echo/v4\"\n")
365
366 var prefixes []string
367 orderedMapIter[string](impmap, func(k, v string) error {
368 prefixes = append(prefixes, k)
369 pf("\t%s\"%s\"\n", importNameForPrefix(k), v)
370 return nil
371 })
372 pf(")\n\n")
373
374 ssets := make(map[string][]*Schema)
375 for _, s := range schemas {
376 var pref string
377 for _, p := range prefixes {
378 if strings.HasPrefix(s.ID, p) {
379 pref = p
380 break
381 }
382 }
383 if pref == "" {
384 return fmt.Errorf("no matching prefix for schema %q (tried %s)", s.ID, prefixes)
385 }
386
387 ssets[pref] = append(ssets[pref], s)
388 }
389
390 for _, p := range prefixes {
391 ss := ssets[p]
392
393 pf("func (s *Server) RegisterHandlers%s(e *echo.Echo) error {\n", idToTitle(p))
394 for _, s := range ss {
395
396 main, ok := s.Defs["main"]
397 if !ok {
398 continue
399 }
400
401 var verb string
402 switch main.Type {
403 case "query":
404 verb = "GET"
405 case "procedure":
406 verb = "POST"
407 default:
408 continue
409 }
410
411 pf("e.%s(\"/xrpc/%s\", s.Handle%s)\n", verb, s.ID, idToTitle(s.ID))
412 }
413
414 pf("return nil\n}\n\n")
415
416 for _, s := range ss {
417
418 var prefix string
419 for k := range impmap {
420 if strings.HasPrefix(s.ID, k) {
421 prefix = k
422 break
423 }
424 }
425
426 main, ok := s.Defs["main"]
427 if !ok {
428 continue
429 }
430
431 if main.Type == "procedure" || main.Type == "query" {
432 fname := idToTitle(s.ID)
433 tname := nameFromID(s.ID, prefix)
434 impname := importNameForPrefix(prefix)
435 if err := main.WriteRPCHandler(w, fname, tname, impname); err != nil {
436 return fmt.Errorf("writing handler for %s: %w", s.ID, err)
437 }
438 }
439 }
440 }
441
442 return nil
443}
444
// idToTitle converts a full lexicon ID into a Go identifier fragment. It
// title-cases each dot-separated segment with strings.Title and joins them
// without separators, so "app.bsky.feed.getPostThread" becomes
// "AppBskyFeedGetPostThread".
func idToTitle(id string) string {
	segs := strings.Split(id, ".")
	for i := range segs {
		segs[i] = strings.Title(segs[i])
	}
	return strings.Join(segs, "")
}
452
// Package describes one Go package that generated lexicon code is written to.
type Package struct {
	// GoPackage is the name written in the generated files' package clause.
	GoPackage string `json:"package"`
	// Prefix selects the schemas belonging to this package: those whose
	// lexicon ID starts with it.
	Prefix string `json:"prefix"`
	// Outdir is the directory generated files are written into.
	Outdir string `json:"outdir"`
	// Import is the Go import path other generated packages use for this one.
	Import string `json:"import"`
}

// ParsePackages decodes jsonBytes, which should hold a JSON array of Package
// objects. On a decoding error it returns a nil slice and the error.
func ParsePackages(jsonBytes []byte) ([]Package, error) {
	var packages []Package
	if err := json.Unmarshal(jsonBytes, &packages); err != nil {
		return nil, err
	}
	return packages, nil
}
469
470func Run(schemas []*Schema, externalSchemas []*Schema, packages []Package) error {
471 defmap := BuildExtDefMap(append(schemas, externalSchemas...), packages)
472
473 for _, pkg := range packages {
474 prefix := pkg.Prefix
475 FixRecordReferences(schemas, defmap, prefix)
476 }
477
478 for _, pkg := range packages {
479 for _, s := range schemas {
480 if !strings.HasPrefix(s.ID, pkg.Prefix) {
481 continue
482 }
483
484 if err := GenCodeForSchema(pkg, true, s, packages, defmap); err != nil {
485 return fmt.Errorf("failed to process schema %q: %w", s.path, err)
486 }
487 }
488 }
489 return nil
490}