1package lex
2
3import (
4 "fmt"
5 "io"
6 "slices"
7 "strings"
8)
9
// OutputType is the "output" section of a lexicon query or procedure.
//
// Encoding is compared against the Encoding* constants (JSON, CBOR, CAR,
// JSONL, MP4, ANY) when generating client and server code. Schema is only
// consulted for JSON output, where a "ref" schema names the Go result type.
type OutputType struct {
	Encoding string      `json:"encoding"`
	Schema   *TypeSchema `json:"schema"`
}
14
// InputType is the "input" (request body) section of a lexicon procedure.
//
// Encoding selects how the body is passed in generated code: JSON bodies are
// decoded into a generated <name>_Input type, while binary encodings are
// passed through as an io.Reader.
type InputType struct {
	Encoding string      `json:"encoding"`
	Schema   *TypeSchema `json:"schema"`
}
19
// TypeSchema is the content of a lexicon schema file "defs" section.
// https://atproto.com/specs/lexicon
//
// The same struct is used for every node of a lexicon definition: top-level
// defs as well as nested properties, parameters, array items and input/output
// schemas. This is why most fields are only meaningful for some values of Type.
//
// The exported fields are decoded from the lexicon JSON. The unexported fields
// are generator bookkeeping; TypeName and namesFromRef panic when id or prefix
// have not been filled in.
type TypeSchema struct {
	prefix  string // prefix of a major package being processed, e.g. com.atproto
	id      string // parent Schema.ID
	defName string // parent Schema.Defs[defName] points to this *TypeSchema
	// defMap maps ref strings to definitions; lookupRef expands local
	// "#name" refs with id before looking them up here.
	defMap map[string]*ExtDef
	// needsCbor makes writeTypeMethods also emit MarshalCBOR/UnmarshalCBOR
	// for union types.
	needsCbor bool
	// needsType makes writeTypeDefinition emit a RECORDTYPE marker comment
	// and a const LexiconTypeID ("$type") field for object types.
	needsType bool

	// Type is the lexicon type name, e.g. "object", "union", "array", "ref",
	// "query", "procedure", "string", "integer", "boolean".
	Type        string `json:"type"`
	Key         string `json:"key"`
	Description string `json:"description"`
	// Parameters holds a query/procedure's parameters: its Properties are the
	// individual parameters and its Required/Nullable lists apply to them.
	Parameters *TypeSchema  `json:"parameters"`
	Input      *InputType   `json:"input"`
	Output     *OutputType  `json:"output"`
	Record     *TypeSchema  `json:"record"`

	// Ref names the referenced definition of a "ref"; Refs lists the
	// variants of a "union".
	Ref        string                 `json:"ref"`
	Refs       []string               `json:"refs"`
	Required   []string               `json:"required"` // names of required properties
	Nullable   []string               `json:"nullable"` // names of properties that may be null
	Properties map[string]*TypeSchema `json:"properties"`
	MaxLength  int                    `json:"maxLength"`
	Items      *TypeSchema            `json:"items"` // element schema of an "array"
	Const      any                    `json:"const"`
	Enum       []string               `json:"enum"`
	// Closed unions make the generated unmarshalers reject unknown $type values.
	Closed bool `json:"closed"`

	// Default, Minimum and Maximum are untyped because their JSON type depends
	// on Type. WriteRPCHandler asserts Default to float64 for "integer" and
	// bool for "boolean" parameters.
	Default any `json:"default"`
	Minimum any `json:"minimum"`
	Maximum any `json:"maximum"`
}
53
54func (s *TypeSchema) WriteRPC(w io.Writer, typename, inputname string) error {
55 pf := printerf(w)
56 fname := typename
57
58 params := "ctx context.Context, c util.LexClient"
59 inpvar := "nil"
60 inpenc := ""
61
62 if s.Input != nil {
63 inpvar = "input"
64 inpenc = s.Input.Encoding
65 switch s.Input.Encoding {
66 case EncodingCBOR, EncodingCAR, EncodingANY, EncodingMP4:
67 params = fmt.Sprintf("%s, input io.Reader", params)
68 case EncodingJSON:
69 params = fmt.Sprintf("%s, input *%s", params, inputname)
70
71 default:
72 return fmt.Errorf("unsupported input encoding (RPC input): %q", s.Input.Encoding)
73 }
74 }
75
76 if s.Parameters != nil {
77 if err := orderedMapIter(s.Parameters.Properties, func(name string, t *TypeSchema) error {
78 tn, err := s.typeNameForField(name, "", *t)
79 if err != nil {
80 return err
81 }
82
83 // TODO: deal with optional params
84 params = params + fmt.Sprintf(", %s %s", name, tn)
85 return nil
86 }); err != nil {
87 return err
88 }
89 }
90
91 out := "error"
92 if s.Output != nil {
93 switch s.Output.Encoding {
94 case EncodingCBOR, EncodingCAR, EncodingANY, EncodingJSONL, EncodingMP4:
95 out = "([]byte, error)"
96 case EncodingJSON:
97 outname := fname + "_Output"
98 if s.Output.Schema.Type == "ref" {
99 _, outname = s.namesFromRef(s.Output.Schema.Ref)
100 }
101
102 out = fmt.Sprintf("(*%s, error)", outname)
103 default:
104 return fmt.Errorf("unrecognized encoding scheme (RPC output): %q", s.Output.Encoding)
105 }
106 }
107
108 pf("// %s calls the XRPC method %q.\n", fname, s.id)
109 if s.Parameters != nil && len(s.Parameters.Properties) > 0 {
110 pf("//\n")
111 if err := orderedMapIter(s.Parameters.Properties, func(name string, t *TypeSchema) error {
112 if t.Description != "" {
113 pf("// %s: %s\n", name, t.Description)
114 }
115 return nil
116 }); err != nil {
117 return err
118 }
119 }
120 pf("func %s(%s) %s {\n", fname, params, out)
121
122 outvar := "nil"
123 errRet := "err"
124 outRet := "nil"
125 if s.Output != nil {
126 switch s.Output.Encoding {
127 case EncodingCBOR, EncodingCAR, EncodingANY, EncodingJSONL, EncodingMP4:
128 pf("buf := new(bytes.Buffer)\n")
129 outvar = "buf"
130 errRet = "nil, err"
131 outRet = "buf.Bytes(), nil"
132 case EncodingJSON:
133 outname := fname + "_Output"
134 if s.Output.Schema.Type == "ref" {
135 _, outname = s.namesFromRef(s.Output.Schema.Ref)
136 }
137 pf("\tvar out %s\n", outname)
138 outvar = "&out"
139 errRet = "nil, err"
140 outRet = "&out, nil"
141 default:
142 return fmt.Errorf("unrecognized output encoding (func signature): %q", s.Output.Encoding)
143 }
144 }
145
146 queryparams := "nil"
147 if s.Parameters != nil {
148 queryparams = "params"
149 pf("\n\tparams := map[string]interface{}{}\n")
150 if err := orderedMapIter(s.Parameters.Properties, func(name string, t *TypeSchema) error {
151 if slices.Contains(s.Parameters.Required, name) || slices.Contains(s.Parameters.Nullable, name) {
152 pf("params[\"%s\"] = %s\n", name, name)
153 } else {
154 // if parameter isn't required, only include conditionally
155 switch t.Type {
156 case "integer":
157 pf("if %s != 0 { params[\"%s\"] = %s }\n", name, name, name)
158 case "string":
159 pf("if %s != \"\" { params[\"%s\"] = %s }\n", name, name, name)
160 case "array":
161 pf("if len(%s) != 0 { params[\"%s\"] = %s }\n", name, name, name)
162 case "boolean":
163 pf("if %s { params[\"%s\"] = %s }\n", name, name, name)
164 default:
165 return fmt.Errorf("unhandled query param type: %s", t.Type)
166 }
167 }
168 return nil
169 }); err != nil {
170 return err
171 }
172 }
173
174 var reqtype string
175 switch s.Type {
176 case "procedure":
177 reqtype = "util.Procedure"
178 case "query":
179 reqtype = "util.Query"
180 default:
181 return fmt.Errorf("can only generate RPC for Query or Procedure (got %s)", s.Type)
182 }
183
184 pf("\tif err := c.LexDo(ctx, %s, %q, \"%s\", %s, %s, %s); err != nil {\n", reqtype, inpenc, s.id, queryparams, inpvar, outvar)
185 pf("\t\treturn %s\n", errRet)
186 pf("\t}\n\n")
187 pf("\treturn %s\n", outRet)
188 pf("}\n\n")
189
190 return nil
191}
192
193func (s *TypeSchema) WriteHandlerStub(w io.Writer, fname, shortname, impname string) error {
194 pf := printerf(w)
195 paramtypes := []string{"ctx context.Context"}
196 if s.Type == "query" {
197
198 if s.Parameters != nil {
199 var required map[string]bool
200 if s.Parameters.Required != nil {
201 required = make(map[string]bool)
202 for _, r := range s.Required {
203 required[r] = true
204 }
205 }
206 orderedMapIter[*TypeSchema](s.Parameters.Properties, func(k string, t *TypeSchema) error {
207 switch t.Type {
208 case "string":
209 paramtypes = append(paramtypes, k+" string")
210 case "integer":
211 // TODO(bnewbold) could be handling "nullable" here
212 if required != nil && !required[k] {
213 paramtypes = append(paramtypes, k+" *int")
214 } else {
215 paramtypes = append(paramtypes, k+" int")
216 }
217 case "float":
218 return fmt.Errorf("non-integer numbers currently unsupported")
219 case "array":
220 paramtypes = append(paramtypes, k+"[]"+t.Items.Type)
221 default:
222 return fmt.Errorf("unsupported handler parameter type: %s", t.Type)
223 }
224 return nil
225 })
226 }
227 }
228
229 returndef := "error"
230 if s.Output != nil {
231 switch s.Output.Encoding {
232 case "application/json":
233 outname := shortname + "_Output"
234 if s.Output.Schema.Type == "ref" {
235 outname, _ = s.namesFromRef(s.Output.Schema.Ref)
236 }
237 returndef = fmt.Sprintf("(*%s.%s, error)", impname, outname)
238 case "application/cbor", "application/vnd.ipld.car", "*/*":
239 returndef = "(io.Reader, error)"
240 default:
241 return fmt.Errorf("unrecognized output encoding (handler stub): %q", s.Output.Encoding)
242 }
243 }
244
245 if s.Input != nil {
246 switch s.Input.Encoding {
247 case "application/json":
248 paramtypes = append(paramtypes, fmt.Sprintf("input *%s.%s_Input", impname, shortname))
249 case "application/cbor":
250 paramtypes = append(paramtypes, "r io.Reader")
251 }
252 }
253
254 pf("func (s *Server) handle%s(%s) %s {\n", fname, strings.Join(paramtypes, ","), returndef)
255 pf("panic(\"not yet implemented\")\n}\n\n")
256
257 return nil
258}
259
// WriteRPCHandler writes an echo handler method
// `func (s *Server) Handle<fname>(c echo.Context) error` for this query or
// procedure schema.
//
// The generated handler:
//   - starts an OpenTelemetry span;
//   - decodes the request: query parameters for a "query", the request body
//     for a "procedure";
//   - calls the hand-written `s.handle<fname>` method with the decoded values;
//   - writes that method's result with status 200.
//
// A comment with the expected signature of handle<fname> is emitted next to
// the call.
//
// shortname is the per-method type prefix used for <shortname>_Input and
// <shortname>_Output; impname is the package qualifier for those types.
//
// An error is returned for schema types other than "query"/"procedure" and for
// unsupported parameter types or encodings. Output is written as it is
// generated, so on error w may already hold part of the handler.
func (s *TypeSchema) WriteRPCHandler(w io.Writer, fname, shortname, impname string) error {
	pf := printerf(w)
	tname := shortname

	pf("func (s *Server) Handle%s(c echo.Context) error {\n", fname)

	pf("ctx, span := otel.Tracer(\"server\").Start(c.Request().Context(), %q)\n", "Handle"+fname)
	pf("defer span.End()\n")

	// paramtypes collects the signature of handle<fname> (emitted only as a
	// comment); params collects the matching argument expressions.
	paramtypes := []string{"ctx context.Context"}
	params := []string{"ctx"}
	if s.Type == "query" {
		if s.Parameters != nil {
			// TODO(bnewbold): could be handling 'nullable' here
			// Parameters listed as required, or that have a default value, are
			// passed by value; all others are passed as pointers so that an
			// absent parameter is nil.
			required := make(map[string]bool)
			for _, r := range s.Parameters.Required {
				required[r] = true
			}
			for k, v := range s.Parameters.Properties {
				if v.Default != nil {
					required[k] = true
				}
			}
			if err := orderedMapIter(s.Parameters.Properties, func(k string, t *TypeSchema) error {
				switch t.Type {
				case "string":
					params = append(params, k)
					paramtypes = append(paramtypes, k+" string")
					pf("%s := c.QueryParam(\"%s\")\n", k, k)
				case "integer":
					params = append(params, k)

					if !required[k] {
						// Optional: left nil unless the parameter is present.
						paramtypes = append(paramtypes, k+" *int")
						pf(`
var %s *int
if p := c.QueryParam("%s"); p != "" {
	%s_val, err := strconv.Atoi(p)
	if err != nil {
		return err
	}
	%s = &%s_val
}
`, k, k, k, k, k)
					} else if t.Default != nil {
						// Defaulted: the lexicon default is baked into the
						// generated code. t.Default is asserted to float64 (how
						// encoding/json decodes numbers into `any`) and this
						// panics for any other dynamic type.
						paramtypes = append(paramtypes, k+" int")
						pf(`
var %s int
if p := c.QueryParam("%s"); p != "" {
var err error
%s, err = strconv.Atoi(p)
if err != nil {
	return err
}
} else {
	%s = %d
}
`, k, k, k, k, int(t.Default.(float64)))
					} else {
						// Required without default: the value is always parsed and
						// a parse failure is returned from the handler.
						paramtypes = append(paramtypes, k+" int")
						pf(`
%s, err := strconv.Atoi(c.QueryParam("%s"))
if err != nil {
	return err
}
`, k, k)
					}

				case "float":
					return fmt.Errorf("non-integer numbers currently unsupported")
				case "boolean":
					// Same three shapes as "integer", parsed with strconv.ParseBool.
					params = append(params, k)
					if !required[k] {
						paramtypes = append(paramtypes, k+" *bool")
						pf(`
var %s *bool
if p := c.QueryParam("%s"); p != "" {
	%s_val, err := strconv.ParseBool(p)
	if err != nil {
		return err
	}
	%s = &%s_val
}
`, k, k, k, k, k)
					} else if t.Default != nil {
						// t.Default is asserted to bool; other dynamic types panic.
						paramtypes = append(paramtypes, k+" bool")
						pf(`
var %s bool
if p := c.QueryParam("%s"); p != "" {
var err error
%s, err = strconv.ParseBool(p)
if err != nil {
	return err
}
} else {
	%s = %v
}
`, k, k, k, k, t.Default.(bool))
					} else {
						paramtypes = append(paramtypes, k+" bool")
						pf(`
%s, err := strconv.ParseBool(c.QueryParam("%s"))
if err != nil {
	return err
}
`, k, k)
					}

				case "array":
					// Repeated query parameters (?k=a&k=b) are collected as a slice.
					if t.Items.Type != "string" {
						return fmt.Errorf("currently only string arrays are supported in query params")
					}
					paramtypes = append(paramtypes, k+" []string")
					params = append(params, k)
					pf(`
%s := c.QueryParams()["%s"]
`, k, k)

				default:
					return fmt.Errorf("unsupported handler parameter type: %s", t.Type)
				}
				return nil
			}); err != nil {
				return err
			}
		}
	} else if s.Type == "procedure" {
		if s.Input != nil {
			intname := impname + "." + tname + "_Input"
			switch s.Input.Encoding {
			case EncodingJSON:
				// Decode the request body into the generated input type.
				pf(`
var body %s
if err := c.Bind(&body); err != nil {
	return err
}
`, intname)
				paramtypes = append(paramtypes, "body *"+intname)
				params = append(params, "&body")
			case EncodingCBOR:
				// Binary bodies are handed over as the raw request stream.
				pf("body := c.Request().Body\n")
				paramtypes = append(paramtypes, "r io.Reader")
				params = append(params, "body")
			case EncodingANY:
				// Arbitrary bodies also pass along the request's Content-Type.
				pf("body := c.Request().Body\n")
				pf("contentType := c.Request().Header.Get(\"Content-Type\")\n")
				paramtypes = append(paramtypes, "r io.Reader", "contentType string")
				params = append(params, "body", "contentType")
			case EncodingMP4:
				pf("body := c.Request().Body\n")
				paramtypes = append(paramtypes, "r io.Reader")
				params = append(params, "body")
			default:
				return fmt.Errorf("unrecognized input encoding: %q", s.Input.Encoding)
			}
		}
	} else {
		return fmt.Errorf("can only generate handlers for queries or procedures")
	}

	// assign is the left-hand side of the generated call; methods with output
	// return (out, err), methods without output return only err.
	assign := "handleErr"
	returndef := "error"
	if s.Output != nil {
		switch s.Output.Encoding {
		case EncodingJSON:
			assign = "out, handleErr"
			outname := tname + "_Output"
			if s.Output.Schema.Type == "ref" {
				outname, _ = s.namesFromRef(s.Output.Schema.Ref)
			}
			pf("var out *%s.%s\n", impname, outname)
			returndef = fmt.Sprintf("(*%s.%s, error)", impname, outname)
		case EncodingCBOR, EncodingCAR, EncodingANY, EncodingJSONL, EncodingMP4:
			assign = "out, handleErr"
			pf("var out io.Reader\n")
			returndef = "(io.Reader, error)"
		default:
			return fmt.Errorf("unrecognized output encoding (RPC output handler): %q", s.Output.Encoding)
		}
	}
	pf("var handleErr error\n")
	pf("// func (s *Server) handle%s(%s) %s\n", fname, strings.Join(paramtypes, ","), returndef)
	pf("%s = s.handle%s(%s)\n", assign, fname, strings.Join(params, ","))
	pf("if handleErr != nil {\nreturn handleErr\n}\n")

	// Write the result: JSON is serialized, everything else is streamed with
	// a content type chosen per encoding.
	if s.Output != nil {
		switch s.Output.Encoding {
		case EncodingJSON:
			pf("return c.JSON(200, out)\n}\n\n")
		case EncodingANY:
			pf("return c.Stream(200, \"application/octet-stream\", out)\n}\n\n")
		case EncodingCBOR:
			pf("return c.Stream(200, \"application/octet-stream\", out)\n}\n\n")
		case EncodingCAR:
			pf("return c.Stream(200, \"application/vnd.ipld.car\", out)\n}\n\n")
		case EncodingJSONL:
			pf("return c.Stream(200, \"application/jsonl\", out)\n}\n\n")
		case EncodingMP4:
			pf("return c.Stream(200, \"video/mp4\", out)\n}\n\n")
		default:
			return fmt.Errorf("unrecognized output encoding (RPC output handler return): %q", s.Output.Encoding)
		}
	} else {
		pf("return nil\n}\n\n")
	}

	return nil
}
470
471func (s *TypeSchema) namesFromRef(r string) (string, string) {
472 ts, err := s.lookupRef(r)
473 if err != nil {
474 panic(err)
475 }
476
477 if ts.prefix == "" {
478 panic(fmt.Sprintf("no prefix for referenced type: %s", ts.id))
479 }
480
481 if s.prefix == "" {
482 panic(fmt.Sprintf("no prefix for referencing type: %q %q", s.id, s.defName))
483 }
484
485 // TODO: probably not technically correct, but i'm kinda over how lexicon
486 // tries to enforce application logic in a schema language
487 if ts.Type == "string" {
488 return "INVALID", "string"
489 }
490
491 var pkg string
492 if ts.prefix != s.prefix {
493 pkg = importNameForPrefix(ts.prefix) + "."
494 }
495
496 tname := pkg + ts.TypeName()
497 vname := tname
498 if strings.Contains(vname, ".") {
499 // Trim the package name from the variable name
500 vname = strings.Split(vname, ".")[1]
501 }
502
503 return vname, tname
504}
505
506func (s *TypeSchema) TypeName() string {
507 if s.id == "" {
508 panic("type schema hint fields not set")
509 }
510 if s.prefix == "" {
511 panic("why no prefix?")
512 }
513 n := nameFromID(s.id, s.prefix)
514 if s.defName != "main" {
515 n += "_" + strings.Title(s.defName)
516 }
517
518 if s.Type == "array" {
519 n = "[]" + n
520
521 if s.Items.Type == "union" {
522 n = n + "_Elem"
523 }
524 }
525
526 return n
527}
528
529// name: enclosing type name
530// k: field name
531// v: field TypeSchema
532func (s *TypeSchema) typeNameForField(name, k string, v TypeSchema) (string, error) {
533 switch v.Type {
534 case "string":
535 return "string", nil
536 case "float":
537 return "float64", nil
538 case "integer":
539 return "int64", nil
540 case "boolean":
541 return "bool", nil
542 case "object":
543 return "*" + name + "_" + strings.Title(k), nil
544 case "ref":
545 _, tn := s.namesFromRef(v.Ref)
546 if tn[0] == '[' {
547 return tn, nil
548 }
549 return "*" + tn, nil
550 case "datetime":
551 // TODO: maybe do a native type?
552 return "string", nil
553 case "unknown":
554 // NOTE: sometimes a record, for which we want LexiconTypeDecoder, sometimes any object
555 if k == "didDoc" || k == "plcOp" || k == "meta" {
556 return "interface{}", nil
557 } else {
558 return "*util.LexiconTypeDecoder", nil
559 }
560 case "union":
561 if len(v.Refs) > 0 {
562 return "*" + name + "_" + strings.Title(k), nil
563 } else {
564 // an empty union is effectively an 'unknown', but with mandatory type indicator
565 return "*util.LexiconTypeDecoder", nil
566 }
567 case "blob":
568 return "*util.LexBlob", nil
569 case "array":
570 subt, err := s.typeNameForField(name+"_"+strings.Title(k), "Elem", *v.Items)
571 if err != nil {
572 return "", err
573 }
574
575 return "[]" + subt, nil
576 case "cid-link":
577 return "util.LexLink", nil
578 case "bytes":
579 return "util.LexBytes", nil
580 default:
581 return "", fmt.Errorf("field %q in %s has unsupported type name (%s)", k, name, v.Type)
582 }
583}
584
585func (ts *TypeSchema) lookupRef(ref string) (*TypeSchema, error) {
586 fqref := ref
587 if strings.HasPrefix(ref, "#") {
588 fmt.Println("updating fqref: ", ts.id)
589 fqref = ts.id + ref
590 }
591 rr, ok := ts.defMap[fqref]
592 if !ok {
593 fmt.Println(ts.defMap)
594 panic(fmt.Sprintf("no such ref: %q", fqref))
595 }
596
597 return rr.Type, nil
598}
599
600// name is the top level type name from outputType
601// WriteType is only called on a top level TypeSchema
602func (ts *TypeSchema) WriteType(name string, w io.Writer) error {
603 name = strings.Title(name)
604 if err := ts.writeTypeDefinition(name, w); err != nil {
605 return err
606 }
607
608 if err := ts.writeTypeMethods(name, w); err != nil {
609 return err
610 }
611
612 return nil
613}
614
615// name is the top level type name from outputType
616// writeTypeDefinition is not called recursively, but only on a top level TypeSchema
617func (ts *TypeSchema) writeTypeDefinition(name string, w io.Writer) error {
618 pf := printerf(w)
619
620 switch {
621 case strings.HasSuffix(name, "_Output"):
622 pf("// %s is the output of a %s call.\n", name, ts.id)
623 case strings.HasSuffix(name, "Input"):
624 pf("// %s is the input argument to a %s call.\n", name, ts.id)
625 case ts.defName != "":
626 pf("// %s is a %q in the %s schema.\n", name, ts.defName, ts.id)
627 }
628 if ts.Description != "" {
629 pf("//\n// %s\n", ts.Description)
630 }
631
632 switch ts.Type {
633 case "string":
634 // TODO: deal with max length
635 pf("type %s string\n", name)
636 case "float":
637 pf("type %s float64\n", name)
638 case "integer":
639 pf("type %s int64\n", name)
640 case "boolean":
641 pf("type %s bool\n", name)
642 case "object":
643 if ts.needsType {
644 pf("//\n// RECORDTYPE: %s\n", name)
645 }
646
647 pf("type %s struct {\n", name)
648
649 if ts.needsType {
650 var omit string
651 if ts.id == "com.atproto.repo.strongRef" { // TODO: hack
652 omit = ",omitempty"
653 }
654 cval := ts.id
655 if ts.defName != "" && ts.defName != "main" {
656 cval += "#" + ts.defName
657 }
658 pf("\tLexiconTypeID string `json:\"$type,const=%s%s\" cborgen:\"$type,const=%s%s\"`\n", cval, omit, cval, omit)
659 } else {
660 //pf("\tLexiconTypeID string `json:\"$type,omitempty\" cborgen:\"$type,omitempty\"`\n")
661 }
662
663 required := make(map[string]bool)
664 for _, req := range ts.Required {
665 required[req] = true
666 }
667
668 nullable := make(map[string]bool)
669 for _, req := range ts.Nullable {
670 nullable[req] = true
671 }
672
673 if err := orderedMapIter(ts.Properties, func(k string, v *TypeSchema) error {
674 goname := strings.Title(k)
675
676 tname, err := ts.typeNameForField(name, k, *v)
677 if err != nil {
678 return err
679 }
680
681 var ptr string
682 var omit string
683 if !required[k] {
684 omit = ",omitempty"
685 if !strings.HasPrefix(tname, "*") && !strings.HasPrefix(tname, "[]") {
686 ptr = "*"
687 }
688 }
689 if nullable[k] {
690 omit = ""
691 if !strings.HasPrefix(tname, "*") && !strings.HasPrefix(tname, "[]") {
692 ptr = "*"
693 }
694 }
695
696 jsonOmit, cborOmit := omit, omit
697
698 // Don't generate pointers to lexbytes, as it's already a pointer.
699 if ptr == "*" && tname == "util.LexBytes" {
700 ptr = ""
701 }
702
703 // TODO: hard-coded hacks for now, making this type (with underlying type []byte)
704 // be omitempty.
705 if ptr == "" && tname == "util.LexBytes" {
706 jsonOmit = ",omitempty"
707 cborOmit = ",omitempty"
708 }
709
710 if name == "LabelDefs_SelfLabels" && k == "values" {
711 // TODO: regularize weird hack?
712 cborOmit += ",preservenil"
713 }
714
715 if v.Description != "" {
716 pf("\t// %s: %s\n", k, v.Description)
717 }
718 pf("\t%s %s%s `json:\"%s%s\" cborgen:\"%s%s\"`\n", goname, ptr, tname, k, jsonOmit, k, cborOmit)
719 return nil
720 }); err != nil {
721 return err
722 }
723
724 pf("}\n\n")
725
726 case "array":
727 tname, err := ts.typeNameForField(name, "elem", *ts.Items)
728 if err != nil {
729 return err
730 }
731
732 pf("type %s []%s\n", name, tname)
733
734 case "union":
735 if len(ts.Refs) > 0 {
736 pf("type %s struct {\n", name)
737 for _, r := range ts.Refs {
738 vname, tname := ts.namesFromRef(r)
739 pf("\t%s *%s\n", vname, tname)
740 }
741 pf("}\n\n")
742 }
743 default:
744 return fmt.Errorf("%s has unrecognized type: %s", name, ts.Type)
745 }
746
747 return nil
748}
749
750func (ts *TypeSchema) writeTypeMethods(name string, w io.Writer) error {
751 switch ts.Type {
752 case "string", "float", "array", "boolean", "integer":
753 return nil
754 case "object":
755 if err := ts.writeJsonMarshalerObject(name, w); err != nil {
756 return err
757 }
758
759 if err := ts.writeJsonUnmarshalerObject(name, w); err != nil {
760 return err
761 }
762
763 return nil
764 case "union":
765 if len(ts.Refs) > 0 {
766 reft, err := ts.lookupRef(ts.Refs[0])
767 if err != nil {
768 return err
769 }
770
771 if reft.Type == "string" {
772 return nil
773 }
774
775 if err := ts.writeJsonMarshalerEnum(name, w); err != nil {
776 return err
777 }
778
779 if err := ts.writeJsonUnmarshalerEnum(name, w); err != nil {
780 return err
781 }
782
783 if ts.needsCbor {
784 if err := ts.writeCborMarshalerEnum(name, w); err != nil {
785 return err
786 }
787
788 if err := ts.writeCborUnmarshalerEnum(name, w); err != nil {
789 return err
790 }
791 }
792
793 return nil
794 }
795
796 return fmt.Errorf("%q unsupported for marshaling", name)
797 default:
798 return fmt.Errorf("%q has unrecognized type: %s", name, ts.Type)
799 }
800}
801
802func (ts *TypeSchema) writeJsonMarshalerObject(name string, w io.Writer) error {
803 return nil // no need for a special json marshaler right now
804}
805
806func (ts *TypeSchema) writeJsonMarshalerEnum(name string, w io.Writer) error {
807 pf := printerf(w)
808 pf("func (t *%s) MarshalJSON() ([]byte, error) {\n", name)
809
810 for _, e := range ts.Refs {
811 vname, _ := ts.namesFromRef(e)
812 if strings.HasPrefix(e, "#") {
813 e = ts.id + e
814 }
815
816 pf("\tif t.%s != nil {\n", vname)
817 pf("\tt.%s.LexiconTypeID = %q\n", vname, e)
818 pf("\t\treturn json.Marshal(t.%s)\n\t}\n", vname)
819 }
820
821 pf("\treturn nil, fmt.Errorf(\"cannot marshal empty enum\")\n}\n")
822 return nil
823}
824
825func (s *TypeSchema) writeJsonUnmarshalerObject(name string, w io.Writer) error {
826 // TODO: would be nice to add some validation...
827 return nil
828 //pf("func (t *%s) UnmarshalJSON(b []byte) (error) {\n", name)
829}
830
831func (ts *TypeSchema) writeJsonUnmarshalerEnum(name string, w io.Writer) error {
832 pf := printerf(w)
833 pf("func (t *%s) UnmarshalJSON(b []byte) (error) {\n", name)
834 pf("\ttyp, err := util.TypeExtract(b)\n")
835 pf("\tif err != nil {\n\t\treturn err\n\t}\n\n")
836 pf("\tswitch typ {\n")
837 for _, e := range ts.Refs {
838 if strings.HasPrefix(e, "#") {
839 e = ts.id + e
840 }
841
842 vname, goname := ts.namesFromRef(e)
843
844 pf("\t\tcase \"%s\":\n", e)
845 pf("\t\t\tt.%s = new(%s)\n", vname, goname)
846 pf("\t\t\treturn json.Unmarshal(b, t.%s)\n", vname)
847 }
848
849 if ts.Closed {
850 pf(`
851 default:
852 return fmt.Errorf("closed enums must have a matching value")
853 `)
854 } else {
855 pf(`
856 default:
857 return nil
858 `)
859
860 }
861
862 pf("\t}\n")
863 pf("}\n\n")
864
865 return nil
866}
867
868func (ts *TypeSchema) writeCborMarshalerEnum(name string, w io.Writer) error {
869 pf := printerf(w)
870 pf("func (t *%s) MarshalCBOR(w io.Writer) error {\n", name)
871 pf(`
872 if t == nil {
873 _, err := w.Write(cbg.CborNull)
874 return err
875 }
876`)
877
878 for _, e := range ts.Refs {
879 vname, _ := ts.namesFromRef(e)
880 pf("\tif t.%s != nil {\n", vname)
881 pf("\t\treturn t.%s.MarshalCBOR(w)\n\t}\n", vname)
882 }
883
884 pf("\treturn fmt.Errorf(\"cannot cbor marshal empty enum\")\n}\n")
885 return nil
886}
887
888func (ts *TypeSchema) writeCborUnmarshalerEnum(name string, w io.Writer) error {
889 pf := printerf(w)
890 pf("func (t *%s) UnmarshalCBOR(r io.Reader) error {\n", name)
891 pf("\ttyp, b, err := util.CborTypeExtractReader(r)\n")
892 pf("\tif err != nil {\n\t\treturn err\n\t}\n\n")
893 pf("\tswitch typ {\n")
894 for _, e := range ts.Refs {
895 if strings.HasPrefix(e, "#") {
896 e = ts.id + e
897 }
898
899 vname, goname := ts.namesFromRef(e)
900
901 pf("\t\tcase \"%s\":\n", e)
902 pf("\t\t\tt.%s = new(%s)\n", vname, goname)
903 pf("\t\t\treturn t.%s.UnmarshalCBOR(bytes.NewReader(b))\n", vname)
904 }
905
906 if ts.Closed {
907 pf(`
908 default:
909 return fmt.Errorf("closed enums must have a matching value")
910 `)
911 } else {
912 pf(`
913 default:
914 return nil
915 `)
916
917 }
918
919 pf("\t}\n")
920 pf("}\n\n")
921
922 return nil
923}