-1049
lex/gen.go
-1049
lex/gen.go
···
25
25
EncodingANY = "*/*"
26
26
)
27
27
28
-
type Schema struct {
29
-
path string
30
-
prefix string
31
-
32
-
Lexicon int `json:"lexicon"`
33
-
ID string `json:"id"`
34
-
Defs map[string]*TypeSchema `json:"defs"`
35
-
}
36
-
37
-
// TODO(bnewbold): suspect this param needs updating for lex refactors
38
-
type Param struct {
39
-
Type string `json:"type"`
40
-
Maximum int `json:"maximum"`
41
-
Required bool `json:"required"`
42
-
}
43
-
44
-
type OutputType struct {
45
-
Encoding string `json:"encoding"`
46
-
Schema *TypeSchema `json:"schema"`
47
-
}
48
-
49
-
type InputType struct {
50
-
Encoding string `json:"encoding"`
51
-
Schema *TypeSchema `json:"schema"`
52
-
}
53
-
54
-
type TypeSchema struct {
55
-
prefix string // prefix of a major package being processed, e.g. com.atproto
56
-
id string // parent Schema.ID
57
-
defName string // parent Schema.Defs[defName] points to this *TypeSchema
58
-
defMap map[string]*ExtDef
59
-
needsCbor bool
60
-
needsType bool
61
-
62
-
Type string `json:"type"`
63
-
Key string `json:"key"`
64
-
Description string `json:"description"`
65
-
Parameters *TypeSchema `json:"parameters"`
66
-
Input *InputType `json:"input"`
67
-
Output *OutputType `json:"output"`
68
-
Record *TypeSchema `json:"record"`
69
-
70
-
Ref string `json:"ref"`
71
-
Refs []string `json:"refs"`
72
-
Required []string `json:"required"`
73
-
Nullable []string `json:"nullable"`
74
-
Properties map[string]*TypeSchema `json:"properties"`
75
-
MaxLength int `json:"maxLength"`
76
-
Items *TypeSchema `json:"items"`
77
-
Const any `json:"const"`
78
-
Enum []string `json:"enum"`
79
-
Closed bool `json:"closed"`
80
-
81
-
Default any `json:"default"`
82
-
Minimum any `json:"minimum"`
83
-
Maximum any `json:"maximum"`
84
-
}
85
-
86
-
func (s *Schema) Name() string {
87
-
p := strings.Split(s.ID, ".")
88
-
return p[len(p)-2] + p[len(p)-1]
89
-
}
90
-
91
28
type outputType struct {
92
29
Name string
93
30
Type *TypeSchema
···
95
32
NeedsType bool
96
33
}
97
34
98
-
func (s *Schema) AllTypes(prefix string, defMap map[string]*ExtDef) []outputType {
99
-
var out []outputType
100
-
101
-
var walk func(name string, ts *TypeSchema, needsCbor bool)
102
-
walk = func(name string, ts *TypeSchema, needsCbor bool) {
103
-
if ts == nil {
104
-
panic(fmt.Sprintf("nil type schema in %q (%s)", name, s.ID))
105
-
}
106
-
107
-
if needsCbor {
108
-
fmt.Println("Setting to record: ", name)
109
-
if name == "EmbedImages_View" {
110
-
panic("not ok")
111
-
}
112
-
ts.needsCbor = true
113
-
}
114
-
115
-
if name == "LabelDefs_SelfLabels" {
116
-
ts.needsType = true
117
-
}
118
-
119
-
ts.prefix = prefix
120
-
ts.id = s.ID
121
-
ts.defMap = defMap
122
-
if ts.Type == "object" ||
123
-
(ts.Type == "union" && len(ts.Refs) > 0) {
124
-
out = append(out, outputType{
125
-
Name: name,
126
-
Type: ts,
127
-
NeedsCbor: ts.needsCbor,
128
-
})
129
-
130
-
for _, r := range ts.Refs {
131
-
refname := r
132
-
if strings.HasPrefix(refname, "#") {
133
-
refname = s.ID + r
134
-
}
135
-
136
-
ed, ok := defMap[refname]
137
-
if !ok {
138
-
panic(fmt.Sprintf("cannot find: %q", refname))
139
-
}
140
-
141
-
fmt.Println("UNION REF", refname, name, needsCbor)
142
-
143
-
if needsCbor {
144
-
ed.Type.needsCbor = true
145
-
}
146
-
147
-
ed.Type.needsType = true
148
-
}
149
-
}
150
-
151
-
if ts.Type == "ref" {
152
-
refname := ts.Ref
153
-
if strings.HasPrefix(refname, "#") {
154
-
refname = s.ID + ts.Ref
155
-
}
156
-
157
-
sub, ok := defMap[refname]
158
-
if !ok {
159
-
panic(fmt.Sprintf("missing ref: %q", refname))
160
-
}
161
-
162
-
if needsCbor {
163
-
sub.Type.needsCbor = true
164
-
}
165
-
}
166
-
167
-
for childname, val := range ts.Properties {
168
-
walk(name+"_"+strings.Title(childname), val, ts.needsCbor)
169
-
}
170
-
171
-
if ts.Items != nil {
172
-
walk(name+"_Elem", ts.Items, ts.needsCbor)
173
-
}
174
-
175
-
if ts.Input != nil {
176
-
if ts.Input.Schema == nil {
177
-
if ts.Input.Encoding != EncodingCBOR &&
178
-
ts.Input.Encoding != EncodingANY &&
179
-
ts.Input.Encoding != EncodingCAR &&
180
-
ts.Input.Encoding != EncodingMP4 {
181
-
panic(fmt.Sprintf("strange input type def in %s", s.ID))
182
-
}
183
-
} else {
184
-
walk(name+"_Input", ts.Input.Schema, ts.needsCbor)
185
-
}
186
-
}
187
-
188
-
if ts.Output != nil {
189
-
if ts.Output.Schema == nil {
190
-
if ts.Output.Encoding != EncodingCBOR &&
191
-
ts.Output.Encoding != EncodingCAR &&
192
-
ts.Output.Encoding != EncodingANY &&
193
-
ts.Output.Encoding != EncodingJSONL &&
194
-
ts.Output.Encoding != EncodingMP4 {
195
-
panic(fmt.Sprintf("strange output type def in %s", s.ID))
196
-
}
197
-
} else {
198
-
walk(name+"_Output", ts.Output.Schema, ts.needsCbor)
199
-
}
200
-
}
201
-
202
-
if ts.Type == "record" {
203
-
ts.Record.needsType = true
204
-
walk(name, ts.Record, true)
205
-
}
206
-
207
-
}
208
-
209
-
tname := nameFromID(s.ID, prefix)
210
-
211
-
for name, def := range s.Defs {
212
-
n := tname + "_" + strings.Title(name)
213
-
if name == "main" {
214
-
n = tname
215
-
}
216
-
walk(n, def, def.needsCbor)
217
-
}
218
-
219
-
return out
220
-
}
221
-
222
-
func ReadSchema(f string) (*Schema, error) {
223
-
fi, err := os.Open(f)
224
-
if err != nil {
225
-
return nil, err
226
-
}
227
-
defer fi.Close()
228
-
229
-
var s Schema
230
-
if err := json.NewDecoder(fi).Decode(&s); err != nil {
231
-
return nil, err
232
-
}
233
-
s.path = f
234
-
235
-
return &s, nil
236
-
}
237
-
238
35
// Build total map of all types defined inside schemas.
239
36
// Return map from fully qualified type name to its *TypeSchema
240
37
func BuildExtDefMap(ss []*Schema, packages []Package) map[string]*ExtDef {
···
475
272
return nil
476
273
}
477
274
478
-
func (s *TypeSchema) WriteRPC(w io.Writer, typename string) error {
479
-
pf := printerf(w)
480
-
fname := typename
481
-
482
-
params := "ctx context.Context, c *xrpc.Client"
483
-
inpvar := "nil"
484
-
inpenc := ""
485
-
486
-
if s.Input != nil {
487
-
inpvar = "input"
488
-
inpenc = s.Input.Encoding
489
-
switch s.Input.Encoding {
490
-
case EncodingCBOR, EncodingCAR, EncodingANY, EncodingMP4:
491
-
params = fmt.Sprintf("%s, input io.Reader", params)
492
-
case EncodingJSON:
493
-
params = fmt.Sprintf("%s, input *%s_Input", params, fname)
494
-
495
-
default:
496
-
return fmt.Errorf("unsupported input encoding (RPC input): %q", s.Input.Encoding)
497
-
}
498
-
}
499
-
500
-
if s.Parameters != nil {
501
-
if err := orderedMapIter(s.Parameters.Properties, func(name string, t *TypeSchema) error {
502
-
tn, err := s.typeNameForField(name, "", *t)
503
-
if err != nil {
504
-
return err
505
-
}
506
-
507
-
// TODO: deal with optional params
508
-
params = params + fmt.Sprintf(", %s %s", name, tn)
509
-
return nil
510
-
}); err != nil {
511
-
return err
512
-
}
513
-
}
514
-
515
-
out := "error"
516
-
if s.Output != nil {
517
-
switch s.Output.Encoding {
518
-
case EncodingCBOR, EncodingCAR, EncodingANY, EncodingJSONL, EncodingMP4:
519
-
out = "([]byte, error)"
520
-
case EncodingJSON:
521
-
outname := fname + "_Output"
522
-
if s.Output.Schema.Type == "ref" {
523
-
_, outname = s.namesFromRef(s.Output.Schema.Ref)
524
-
}
525
-
526
-
out = fmt.Sprintf("(*%s, error)", outname)
527
-
default:
528
-
return fmt.Errorf("unrecognized encoding scheme (RPC output): %q", s.Output.Encoding)
529
-
}
530
-
}
531
-
532
-
pf("// %s calls the XRPC method %q.\n", fname, s.id)
533
-
if s.Parameters != nil && len(s.Parameters.Properties) > 0 {
534
-
pf("//\n")
535
-
if err := orderedMapIter(s.Parameters.Properties, func(name string, t *TypeSchema) error {
536
-
if t.Description != "" {
537
-
pf("// %s: %s\n", name, t.Description)
538
-
}
539
-
return nil
540
-
}); err != nil {
541
-
return err
542
-
}
543
-
}
544
-
pf("func %s(%s) %s {\n", fname, params, out)
545
-
546
-
outvar := "nil"
547
-
errRet := "err"
548
-
outRet := "nil"
549
-
if s.Output != nil {
550
-
switch s.Output.Encoding {
551
-
case EncodingCBOR, EncodingCAR, EncodingANY, EncodingJSONL, EncodingMP4:
552
-
pf("buf := new(bytes.Buffer)\n")
553
-
outvar = "buf"
554
-
errRet = "nil, err"
555
-
outRet = "buf.Bytes(), nil"
556
-
case EncodingJSON:
557
-
outname := fname + "_Output"
558
-
if s.Output.Schema.Type == "ref" {
559
-
_, outname = s.namesFromRef(s.Output.Schema.Ref)
560
-
}
561
-
pf("\tvar out %s\n", outname)
562
-
outvar = "&out"
563
-
errRet = "nil, err"
564
-
outRet = "&out, nil"
565
-
default:
566
-
return fmt.Errorf("unrecognized output encoding (func signature): %q", s.Output.Encoding)
567
-
}
568
-
}
569
-
570
-
queryparams := "nil"
571
-
if s.Parameters != nil {
572
-
queryparams = "params"
573
-
pf(`
574
-
params := map[string]interface{}{
575
-
`)
576
-
if err := orderedMapIter(s.Parameters.Properties, func(name string, t *TypeSchema) error {
577
-
pf(`"%s": %s,
578
-
`, name, name)
579
-
return nil
580
-
}); err != nil {
581
-
return err
582
-
}
583
-
pf("}\n")
584
-
}
585
-
586
-
var reqtype string
587
-
switch s.Type {
588
-
case "procedure":
589
-
reqtype = "xrpc.Procedure"
590
-
case "query":
591
-
reqtype = "xrpc.Query"
592
-
default:
593
-
return fmt.Errorf("can only generate RPC for Query or Procedure (got %s)", s.Type)
594
-
}
595
-
596
-
pf("\tif err := c.Do(ctx, %s, %q, \"%s\", %s, %s, %s); err != nil {\n", reqtype, inpenc, s.id, queryparams, inpvar, outvar)
597
-
pf("\t\treturn %s\n", errRet)
598
-
pf("\t}\n\n")
599
-
pf("\treturn %s\n", outRet)
600
-
pf("}\n\n")
601
-
602
-
return nil
603
-
}
604
-
605
275
func CreateHandlerStub(pkg string, impmap map[string]string, dir string, schemas []*Schema, handlers bool) error {
606
276
buf := new(bytes.Buffer)
607
277
···
772
442
fname += strings.Title(p)
773
443
}
774
444
return fname
775
-
}
776
-
777
-
func (s *TypeSchema) WriteHandlerStub(w io.Writer, fname, shortname, impname string) error {
778
-
pf := printerf(w)
779
-
paramtypes := []string{"ctx context.Context"}
780
-
if s.Type == "query" {
781
-
782
-
if s.Parameters != nil {
783
-
var required map[string]bool
784
-
if s.Parameters.Required != nil {
785
-
required = make(map[string]bool)
786
-
for _, r := range s.Required {
787
-
required[r] = true
788
-
}
789
-
}
790
-
orderedMapIter[*TypeSchema](s.Parameters.Properties, func(k string, t *TypeSchema) error {
791
-
switch t.Type {
792
-
case "string":
793
-
paramtypes = append(paramtypes, k+" string")
794
-
case "integer":
795
-
// TODO(bnewbold) could be handling "nullable" here
796
-
if required != nil && !required[k] {
797
-
paramtypes = append(paramtypes, k+" *int")
798
-
} else {
799
-
paramtypes = append(paramtypes, k+" int")
800
-
}
801
-
case "float":
802
-
return fmt.Errorf("non-integer numbers currently unsupported")
803
-
case "array":
804
-
paramtypes = append(paramtypes, k+"[]"+t.Items.Type)
805
-
default:
806
-
return fmt.Errorf("unsupported handler parameter type: %s", t.Type)
807
-
}
808
-
return nil
809
-
})
810
-
}
811
-
}
812
-
813
-
returndef := "error"
814
-
if s.Output != nil {
815
-
switch s.Output.Encoding {
816
-
case "application/json":
817
-
outname := shortname + "_Output"
818
-
if s.Output.Schema.Type == "ref" {
819
-
outname, _ = s.namesFromRef(s.Output.Schema.Ref)
820
-
}
821
-
returndef = fmt.Sprintf("(*%s.%s, error)", impname, outname)
822
-
case "application/cbor", "application/vnd.ipld.car", "*/*":
823
-
returndef = fmt.Sprintf("(io.Reader, error)")
824
-
default:
825
-
return fmt.Errorf("unrecognized output encoding (handler stub): %q", s.Output.Encoding)
826
-
}
827
-
}
828
-
829
-
if s.Input != nil {
830
-
switch s.Input.Encoding {
831
-
case "application/json":
832
-
paramtypes = append(paramtypes, fmt.Sprintf("input *%s.%s_Input", impname, shortname))
833
-
case "application/cbor":
834
-
paramtypes = append(paramtypes, "r io.Reader")
835
-
}
836
-
}
837
-
838
-
pf("func (s *Server) handle%s(%s) %s {\n", fname, strings.Join(paramtypes, ","), returndef)
839
-
pf("panic(\"not yet implemented\")\n}\n\n")
840
-
841
-
return nil
842
-
}
843
-
844
-
func (s *TypeSchema) WriteRPCHandler(w io.Writer, fname, shortname, impname string) error {
845
-
pf := printerf(w)
846
-
tname := shortname
847
-
848
-
pf("func (s *Server) Handle%s(c echo.Context) error {\n", fname)
849
-
850
-
pf("ctx, span := otel.Tracer(\"server\").Start(c.Request().Context(), %q)\n", "Handle"+fname)
851
-
pf("defer span.End()\n")
852
-
853
-
paramtypes := []string{"ctx context.Context"}
854
-
params := []string{"ctx"}
855
-
if s.Type == "query" {
856
-
if s.Parameters != nil {
857
-
// TODO(bnewbold): could be handling 'nullable' here
858
-
required := make(map[string]bool)
859
-
for _, r := range s.Parameters.Required {
860
-
required[r] = true
861
-
}
862
-
for k, v := range s.Parameters.Properties {
863
-
if v.Default != nil {
864
-
required[k] = true
865
-
}
866
-
}
867
-
if err := orderedMapIter(s.Parameters.Properties, func(k string, t *TypeSchema) error {
868
-
switch t.Type {
869
-
case "string":
870
-
params = append(params, k)
871
-
paramtypes = append(paramtypes, k+" string")
872
-
pf("%s := c.QueryParam(\"%s\")\n", k, k)
873
-
case "integer":
874
-
params = append(params, k)
875
-
876
-
if !required[k] {
877
-
paramtypes = append(paramtypes, k+" *int")
878
-
pf(`
879
-
var %s *int
880
-
if p := c.QueryParam("%s"); p != "" {
881
-
%s_val, err := strconv.Atoi(p)
882
-
if err != nil {
883
-
return err
884
-
}
885
-
%s = &%s_val
886
-
}
887
-
`, k, k, k, k, k)
888
-
} else if t.Default != nil {
889
-
paramtypes = append(paramtypes, k+" int")
890
-
pf(`
891
-
var %s int
892
-
if p := c.QueryParam("%s"); p != "" {
893
-
var err error
894
-
%s, err = strconv.Atoi(p)
895
-
if err != nil {
896
-
return err
897
-
}
898
-
} else {
899
-
%s = %d
900
-
}
901
-
`, k, k, k, k, int(t.Default.(float64)))
902
-
} else {
903
-
904
-
paramtypes = append(paramtypes, k+" int")
905
-
pf(`
906
-
%s, err := strconv.Atoi(c.QueryParam("%s"))
907
-
if err != nil {
908
-
return err
909
-
}
910
-
`, k, k)
911
-
}
912
-
913
-
case "float":
914
-
return fmt.Errorf("non-integer numbers currently unsupported")
915
-
case "boolean":
916
-
params = append(params, k)
917
-
if !required[k] {
918
-
paramtypes = append(paramtypes, k+" *bool")
919
-
pf(`
920
-
var %s *bool
921
-
if p := c.QueryParam("%s"); p != "" {
922
-
%s_val, err := strconv.ParseBool(p)
923
-
if err != nil {
924
-
return err
925
-
}
926
-
%s = &%s_val
927
-
}
928
-
`, k, k, k, k, k)
929
-
} else if t.Default != nil {
930
-
paramtypes = append(paramtypes, k+" bool")
931
-
pf(`
932
-
var %s bool
933
-
if p := c.QueryParam("%s"); p != "" {
934
-
var err error
935
-
%s, err = strconv.ParseBool(p)
936
-
if err != nil {
937
-
return err
938
-
}
939
-
} else {
940
-
%s = %v
941
-
}
942
-
`, k, k, k, k, t.Default.(bool))
943
-
} else {
944
-
945
-
paramtypes = append(paramtypes, k+" bool")
946
-
pf(`
947
-
%s, err := strconv.ParseBool(c.QueryParam("%s"))
948
-
if err != nil {
949
-
return err
950
-
}
951
-
`, k, k)
952
-
}
953
-
954
-
case "array":
955
-
if t.Items.Type != "string" {
956
-
return fmt.Errorf("currently only string arrays are supported in query params")
957
-
}
958
-
paramtypes = append(paramtypes, k+" []string")
959
-
params = append(params, k)
960
-
pf(`
961
-
%s := c.QueryParams()["%s"]
962
-
`, k, k)
963
-
964
-
default:
965
-
return fmt.Errorf("unsupported handler parameter type: %s", t.Type)
966
-
}
967
-
return nil
968
-
}); err != nil {
969
-
return err
970
-
}
971
-
}
972
-
} else if s.Type == "procedure" {
973
-
if s.Input != nil {
974
-
intname := impname + "." + tname + "_Input"
975
-
switch s.Input.Encoding {
976
-
case EncodingJSON:
977
-
pf(`
978
-
var body %s
979
-
if err := c.Bind(&body); err != nil {
980
-
return err
981
-
}
982
-
`, intname)
983
-
paramtypes = append(paramtypes, "body *"+intname)
984
-
params = append(params, "&body")
985
-
case EncodingCBOR:
986
-
pf("body := c.Request().Body\n")
987
-
paramtypes = append(paramtypes, "r io.Reader")
988
-
params = append(params, "body")
989
-
case EncodingANY:
990
-
pf("body := c.Request().Body\n")
991
-
pf("contentType := c.Request().Header.Get(\"Content-Type\")\n")
992
-
paramtypes = append(paramtypes, "r io.Reader", "contentType string")
993
-
params = append(params, "body", "contentType")
994
-
case EncodingMP4:
995
-
pf("body := c.Request().Body\n")
996
-
paramtypes = append(paramtypes, "r io.Reader")
997
-
params = append(params, "body")
998
-
default:
999
-
return fmt.Errorf("unrecognized input encoding: %q", s.Input.Encoding)
1000
-
}
1001
-
}
1002
-
} else {
1003
-
return fmt.Errorf("can only generate handlers for queries or procedures")
1004
-
}
1005
-
1006
-
assign := "handleErr"
1007
-
returndef := "error"
1008
-
if s.Output != nil {
1009
-
switch s.Output.Encoding {
1010
-
case EncodingJSON:
1011
-
assign = "out, handleErr"
1012
-
outname := tname + "_Output"
1013
-
if s.Output.Schema.Type == "ref" {
1014
-
outname, _ = s.namesFromRef(s.Output.Schema.Ref)
1015
-
}
1016
-
pf("var out *%s.%s\n", impname, outname)
1017
-
returndef = fmt.Sprintf("(*%s.%s, error)", impname, outname)
1018
-
case EncodingCBOR, EncodingCAR, EncodingANY, EncodingJSONL, EncodingMP4:
1019
-
assign = "out, handleErr"
1020
-
pf("var out io.Reader\n")
1021
-
returndef = "(io.Reader, error)"
1022
-
default:
1023
-
return fmt.Errorf("unrecognized output encoding (RPC output handler): %q", s.Output.Encoding)
1024
-
}
1025
-
}
1026
-
pf("var handleErr error\n")
1027
-
pf("// func (s *Server) handle%s(%s) %s\n", fname, strings.Join(paramtypes, ","), returndef)
1028
-
pf("%s = s.handle%s(%s)\n", assign, fname, strings.Join(params, ","))
1029
-
pf("if handleErr != nil {\nreturn handleErr\n}\n")
1030
-
1031
-
if s.Output != nil {
1032
-
switch s.Output.Encoding {
1033
-
case EncodingJSON:
1034
-
pf("return c.JSON(200, out)\n}\n\n")
1035
-
case EncodingANY:
1036
-
pf("return c.Stream(200, \"application/octet-stream\", out)\n}\n\n")
1037
-
case EncodingCBOR:
1038
-
pf("return c.Stream(200, \"application/octet-stream\", out)\n}\n\n")
1039
-
case EncodingCAR:
1040
-
pf("return c.Stream(200, \"application/vnd.ipld.car\", out)\n}\n\n")
1041
-
case EncodingJSONL:
1042
-
pf("return c.Stream(200, \"application/jsonl\", out)\n}\n\n")
1043
-
case EncodingMP4:
1044
-
pf("return c.Stream(200, \"video/mp4\", out)\n}\n\n")
1045
-
default:
1046
-
return fmt.Errorf("unrecognized output encoding (RPC output handler return): %q", s.Output.Encoding)
1047
-
}
1048
-
} else {
1049
-
pf("return nil\n}\n\n")
1050
-
}
1051
-
1052
-
return nil
1053
-
}
1054
-
1055
-
func (s *TypeSchema) namesFromRef(r string) (string, string) {
1056
-
ts, err := s.lookupRef(r)
1057
-
if err != nil {
1058
-
panic(err)
1059
-
}
1060
-
1061
-
if ts.prefix == "" {
1062
-
panic(fmt.Sprintf("no prefix for referenced type: %s", ts.id))
1063
-
}
1064
-
1065
-
if s.prefix == "" {
1066
-
panic(fmt.Sprintf("no prefix for referencing type: %q %q", s.id, s.defName))
1067
-
}
1068
-
1069
-
// TODO: probably not technically correct, but i'm kinda over how lexicon
1070
-
// tries to enforce application logic in a schema language
1071
-
if ts.Type == "string" {
1072
-
return "INVALID", "string"
1073
-
}
1074
-
1075
-
var pkg string
1076
-
if ts.prefix != s.prefix {
1077
-
pkg = importNameForPrefix(ts.prefix) + "."
1078
-
}
1079
-
1080
-
tname := pkg + ts.TypeName()
1081
-
vname := tname
1082
-
if strings.Contains(vname, ".") {
1083
-
// Trim the package name from the variable name
1084
-
vname = strings.Split(vname, ".")[1]
1085
-
}
1086
-
1087
-
return vname, tname
1088
-
}
1089
-
1090
-
func (s *TypeSchema) TypeName() string {
1091
-
if s.id == "" {
1092
-
panic("type schema hint fields not set")
1093
-
}
1094
-
if s.prefix == "" {
1095
-
panic("why no prefix?")
1096
-
}
1097
-
n := nameFromID(s.id, s.prefix)
1098
-
if s.defName != "main" {
1099
-
n += "_" + strings.Title(s.defName)
1100
-
}
1101
-
1102
-
if s.Type == "array" {
1103
-
n = "[]" + n
1104
-
1105
-
if s.Items.Type == "union" {
1106
-
n = n + "_Elem"
1107
-
}
1108
-
}
1109
-
1110
-
return n
1111
-
}
1112
-
1113
-
func (s *TypeSchema) typeNameForField(name, k string, v TypeSchema) (string, error) {
1114
-
switch v.Type {
1115
-
case "string":
1116
-
return "string", nil
1117
-
case "float":
1118
-
return "float64", nil
1119
-
case "integer":
1120
-
return "int64", nil
1121
-
case "boolean":
1122
-
return "bool", nil
1123
-
case "object":
1124
-
return "*" + name + "_" + strings.Title(k), nil
1125
-
case "ref":
1126
-
_, tn := s.namesFromRef(v.Ref)
1127
-
if tn[0] == '[' {
1128
-
return tn, nil
1129
-
}
1130
-
return "*" + tn, nil
1131
-
case "datetime":
1132
-
// TODO: maybe do a native type?
1133
-
return "string", nil
1134
-
case "unknown":
1135
-
// NOTE: sometimes a record, for which we want LexiconTypeDecoder, sometimes any object
1136
-
if k == "didDoc" || k == "plcOp" {
1137
-
return "interface{}", nil
1138
-
} else {
1139
-
return "*util.LexiconTypeDecoder", nil
1140
-
}
1141
-
case "union":
1142
-
return "*" + name + "_" + strings.Title(k), nil
1143
-
case "blob":
1144
-
return "*util.LexBlob", nil
1145
-
case "array":
1146
-
subt, err := s.typeNameForField(name+"_"+strings.Title(k), "Elem", *v.Items)
1147
-
if err != nil {
1148
-
return "", err
1149
-
}
1150
-
1151
-
return "[]" + subt, nil
1152
-
case "cid-link":
1153
-
return "util.LexLink", nil
1154
-
case "bytes":
1155
-
return "util.LexBytes", nil
1156
-
default:
1157
-
return "", fmt.Errorf("field %q in %s has unsupported type name (%s)", k, name, v.Type)
1158
-
}
1159
-
}
1160
-
1161
-
func (ts *TypeSchema) lookupRef(ref string) (*TypeSchema, error) {
1162
-
fqref := ref
1163
-
if strings.HasPrefix(ref, "#") {
1164
-
fmt.Println("updating fqref: ", ts.id)
1165
-
fqref = ts.id + ref
1166
-
}
1167
-
rr, ok := ts.defMap[fqref]
1168
-
if !ok {
1169
-
fmt.Println(ts.defMap)
1170
-
panic(fmt.Sprintf("no such ref: %q", fqref))
1171
-
}
1172
-
1173
-
return rr.Type, nil
1174
-
}
1175
-
1176
-
func (ts *TypeSchema) WriteType(name string, w io.Writer) error {
1177
-
name = strings.Title(name)
1178
-
if err := ts.writeTypeDefinition(name, w); err != nil {
1179
-
return err
1180
-
}
1181
-
1182
-
if err := ts.writeTypeMethods(name, w); err != nil {
1183
-
return err
1184
-
}
1185
-
1186
-
return nil
1187
-
}
1188
-
1189
-
func (ts *TypeSchema) writeTypeDefinition(name string, w io.Writer) error {
1190
-
pf := printerf(w)
1191
-
1192
-
switch {
1193
-
case strings.HasSuffix(name, "_Output"):
1194
-
pf("// %s is the output of a %s call.\n", name, ts.id)
1195
-
case strings.HasSuffix(name, "Input"):
1196
-
pf("// %s is the input argument to a %s call.\n", name, ts.id)
1197
-
case ts.defName != "":
1198
-
pf("// %s is a %q in the %s schema.\n", name, ts.defName, ts.id)
1199
-
}
1200
-
if ts.Description != "" {
1201
-
pf("//\n// %s\n", ts.Description)
1202
-
}
1203
-
1204
-
switch ts.Type {
1205
-
case "string":
1206
-
// TODO: deal with max length
1207
-
pf("type %s string\n", name)
1208
-
case "float":
1209
-
pf("type %s float64\n", name)
1210
-
case "integer":
1211
-
pf("type %s int64\n", name)
1212
-
case "boolean":
1213
-
pf("type %s bool\n", name)
1214
-
case "object":
1215
-
if ts.needsType {
1216
-
pf("//\n// RECORDTYPE: %s\n", name)
1217
-
}
1218
-
1219
-
pf("type %s struct {\n", name)
1220
-
1221
-
if ts.needsType {
1222
-
var omit string
1223
-
if ts.id == "com.atproto.repo.strongRef" { // TODO: hack
1224
-
omit = ",omitempty"
1225
-
}
1226
-
cval := ts.id
1227
-
if ts.defName != "" && ts.defName != "main" {
1228
-
cval += "#" + ts.defName
1229
-
}
1230
-
pf("\tLexiconTypeID string `json:\"$type,const=%s%s\" cborgen:\"$type,const=%s%s\"`\n", cval, omit, cval, omit)
1231
-
} else {
1232
-
//pf("\tLexiconTypeID string `json:\"$type,omitempty\" cborgen:\"$type,omitempty\"`\n")
1233
-
}
1234
-
1235
-
required := make(map[string]bool)
1236
-
for _, req := range ts.Required {
1237
-
required[req] = true
1238
-
}
1239
-
1240
-
nullable := make(map[string]bool)
1241
-
for _, req := range ts.Nullable {
1242
-
nullable[req] = true
1243
-
}
1244
-
1245
-
if err := orderedMapIter(ts.Properties, func(k string, v *TypeSchema) error {
1246
-
goname := strings.Title(k)
1247
-
1248
-
tname, err := ts.typeNameForField(name, k, *v)
1249
-
if err != nil {
1250
-
return err
1251
-
}
1252
-
1253
-
var ptr string
1254
-
var omit string
1255
-
if !required[k] {
1256
-
omit = ",omitempty"
1257
-
if !strings.HasPrefix(tname, "*") && !strings.HasPrefix(tname, "[]") {
1258
-
ptr = "*"
1259
-
}
1260
-
}
1261
-
if nullable[k] {
1262
-
omit = ""
1263
-
if !strings.HasPrefix(tname, "*") && !strings.HasPrefix(tname, "[]") {
1264
-
ptr = "*"
1265
-
}
1266
-
}
1267
-
1268
-
jsonOmit, cborOmit := omit, omit
1269
-
1270
-
// Don't generate pointers to lexbytes, as it's already a pointer.
1271
-
if ptr == "*" && tname == "util.LexBytes" {
1272
-
ptr = ""
1273
-
}
1274
-
1275
-
// TODO: hard-coded hacks for now, making this type (with underlying type []byte)
1276
-
// be omitempty.
1277
-
if ptr == "" && tname == "util.LexBytes" {
1278
-
jsonOmit = ",omitempty"
1279
-
cborOmit = ",omitempty"
1280
-
}
1281
-
1282
-
if name == "LabelDefs_SelfLabels" && k == "values" {
1283
-
cborOmit += ",preservenil"
1284
-
}
1285
-
1286
-
if v.Description != "" {
1287
-
pf("\t// %s: %s\n", k, v.Description)
1288
-
}
1289
-
pf("\t%s %s%s `json:\"%s%s\" cborgen:\"%s%s\"`\n", goname, ptr, tname, k, jsonOmit, k, cborOmit)
1290
-
return nil
1291
-
}); err != nil {
1292
-
return err
1293
-
}
1294
-
1295
-
pf("}\n\n")
1296
-
1297
-
case "array":
1298
-
tname, err := ts.typeNameForField(name, "elem", *ts.Items)
1299
-
if err != nil {
1300
-
return err
1301
-
}
1302
-
1303
-
pf("type %s []%s\n", name, tname)
1304
-
1305
-
case "union":
1306
-
if len(ts.Refs) > 0 {
1307
-
pf("type %s struct {\n", name)
1308
-
for _, r := range ts.Refs {
1309
-
vname, tname := ts.namesFromRef(r)
1310
-
pf("\t%s *%s\n", vname, tname)
1311
-
}
1312
-
pf("}\n\n")
1313
-
}
1314
-
default:
1315
-
return fmt.Errorf("%s has unrecognized type: %s", name, ts.Type)
1316
-
}
1317
-
1318
-
return nil
1319
-
}
1320
-
1321
-
func (ts *TypeSchema) writeTypeMethods(name string, w io.Writer) error {
1322
-
switch ts.Type {
1323
-
case "string", "float", "array", "boolean", "integer":
1324
-
return nil
1325
-
case "object":
1326
-
if err := ts.writeJsonMarshalerObject(name, w); err != nil {
1327
-
return err
1328
-
}
1329
-
1330
-
if err := ts.writeJsonUnmarshalerObject(name, w); err != nil {
1331
-
return err
1332
-
}
1333
-
1334
-
return nil
1335
-
case "union":
1336
-
if len(ts.Refs) > 0 {
1337
-
reft, err := ts.lookupRef(ts.Refs[0])
1338
-
if err != nil {
1339
-
return err
1340
-
}
1341
-
1342
-
if reft.Type == "string" {
1343
-
return nil
1344
-
}
1345
-
1346
-
if err := ts.writeJsonMarshalerEnum(name, w); err != nil {
1347
-
return err
1348
-
}
1349
-
1350
-
if err := ts.writeJsonUnmarshalerEnum(name, w); err != nil {
1351
-
return err
1352
-
}
1353
-
1354
-
if ts.needsCbor {
1355
-
if err := ts.writeCborMarshalerEnum(name, w); err != nil {
1356
-
return err
1357
-
}
1358
-
1359
-
if err := ts.writeCborUnmarshalerEnum(name, w); err != nil {
1360
-
return err
1361
-
}
1362
-
}
1363
-
1364
-
return nil
1365
-
}
1366
-
1367
-
return fmt.Errorf("%q unsupported for marshaling", name)
1368
-
default:
1369
-
return fmt.Errorf("%q has unrecognized type: %s", name, ts.Type)
1370
-
}
1371
-
}
1372
-
1373
-
func (ts *TypeSchema) writeJsonMarshalerObject(name string, w io.Writer) error {
1374
-
return nil // no need for a special json marshaler right now
1375
-
}
1376
-
1377
-
func (ts *TypeSchema) writeJsonMarshalerEnum(name string, w io.Writer) error {
1378
-
pf := printerf(w)
1379
-
pf("func (t *%s) MarshalJSON() ([]byte, error) {\n", name)
1380
-
1381
-
for _, e := range ts.Refs {
1382
-
vname, _ := ts.namesFromRef(e)
1383
-
if strings.HasPrefix(e, "#") {
1384
-
e = ts.id + e
1385
-
}
1386
-
1387
-
pf("\tif t.%s != nil {\n", vname)
1388
-
pf("\tt.%s.LexiconTypeID = %q\n", vname, e)
1389
-
pf("\t\treturn json.Marshal(t.%s)\n\t}\n", vname)
1390
-
}
1391
-
1392
-
pf("\treturn nil, fmt.Errorf(\"cannot marshal empty enum\")\n}\n")
1393
-
return nil
1394
-
}
1395
-
1396
-
func (s *TypeSchema) writeJsonUnmarshalerObject(name string, w io.Writer) error {
1397
-
// TODO: would be nice to add some validation...
1398
-
return nil
1399
-
//pf("func (t *%s) UnmarshalJSON(b []byte) (error) {\n", name)
1400
-
}
1401
-
1402
-
func (ts *TypeSchema) writeJsonUnmarshalerEnum(name string, w io.Writer) error {
1403
-
pf := printerf(w)
1404
-
pf("func (t *%s) UnmarshalJSON(b []byte) (error) {\n", name)
1405
-
pf("\ttyp, err := util.TypeExtract(b)\n")
1406
-
pf("\tif err != nil {\n\t\treturn err\n\t}\n\n")
1407
-
pf("\tswitch typ {\n")
1408
-
for _, e := range ts.Refs {
1409
-
if strings.HasPrefix(e, "#") {
1410
-
e = ts.id + e
1411
-
}
1412
-
1413
-
vname, goname := ts.namesFromRef(e)
1414
-
1415
-
pf("\t\tcase \"%s\":\n", e)
1416
-
pf("\t\t\tt.%s = new(%s)\n", vname, goname)
1417
-
pf("\t\t\treturn json.Unmarshal(b, t.%s)\n", vname)
1418
-
}
1419
-
1420
-
if ts.Closed {
1421
-
pf(`
1422
-
default:
1423
-
return fmt.Errorf("closed enums must have a matching value")
1424
-
`)
1425
-
} else {
1426
-
pf(`
1427
-
default:
1428
-
return nil
1429
-
`)
1430
-
1431
-
}
1432
-
1433
-
pf("\t}\n")
1434
-
pf("}\n\n")
1435
-
1436
-
return nil
1437
-
}
1438
-
1439
-
func (ts *TypeSchema) writeCborMarshalerEnum(name string, w io.Writer) error {
1440
-
pf := printerf(w)
1441
-
pf("func (t *%s) MarshalCBOR(w io.Writer) error {\n", name)
1442
-
pf(`
1443
-
if t == nil {
1444
-
_, err := w.Write(cbg.CborNull)
1445
-
return err
1446
-
}
1447
-
`)
1448
-
1449
-
for _, e := range ts.Refs {
1450
-
vname, _ := ts.namesFromRef(e)
1451
-
pf("\tif t.%s != nil {\n", vname)
1452
-
pf("\t\treturn t.%s.MarshalCBOR(w)\n\t}\n", vname)
1453
-
}
1454
-
1455
-
pf("\treturn fmt.Errorf(\"cannot cbor marshal empty enum\")\n}\n")
1456
-
return nil
1457
-
}
1458
-
1459
-
func (ts *TypeSchema) writeCborUnmarshalerEnum(name string, w io.Writer) error {
1460
-
pf := printerf(w)
1461
-
pf("func (t *%s) UnmarshalCBOR(r io.Reader) error {\n", name)
1462
-
pf("\ttyp, b, err := util.CborTypeExtractReader(r)\n")
1463
-
pf("\tif err != nil {\n\t\treturn err\n\t}\n\n")
1464
-
pf("\tswitch typ {\n")
1465
-
for _, e := range ts.Refs {
1466
-
if strings.HasPrefix(e, "#") {
1467
-
e = ts.id + e
1468
-
}
1469
-
1470
-
vname, goname := ts.namesFromRef(e)
1471
-
1472
-
pf("\t\tcase \"%s\":\n", e)
1473
-
pf("\t\t\tt.%s = new(%s)\n", vname, goname)
1474
-
pf("\t\t\treturn t.%s.UnmarshalCBOR(bytes.NewReader(b))\n", vname)
1475
-
}
1476
-
1477
-
if ts.Closed {
1478
-
pf(`
1479
-
default:
1480
-
return fmt.Errorf("closed enums must have a matching value")
1481
-
`)
1482
-
} else {
1483
-
pf(`
1484
-
default:
1485
-
return nil
1486
-
`)
1487
-
1488
-
}
1489
-
1490
-
pf("\t}\n")
1491
-
pf("}\n\n")
1492
-
1493
-
return nil
1494
445
}
1495
446
1496
447
type Package struct {
+169
lex/schema.go
+169
lex/schema.go
···
1
+
package lex
2
+
3
+
import (
4
+
"encoding/json"
5
+
"fmt"
6
+
"os"
7
+
"strings"
8
+
)
9
+
10
+
// Schema is a lexicon json file
11
+
// e.g. atproto/lexicons/app/bsky/feed/post.json
12
+
// https://atproto.com/specs/lexicon
13
+
type Schema struct {
14
+
// path of json file read
15
+
path string
16
+
17
+
// prefix of lexicon group, e.g. "app.bsky" or "com.atproto"
18
+
prefix string
19
+
20
+
// Lexicon version, e.g. 1
21
+
Lexicon int `json:"lexicon"`
22
+
ID string `json:"id"`
23
+
Defs map[string]*TypeSchema `json:"defs"`
24
+
}
25
+
26
+
func ReadSchema(f string) (*Schema, error) {
27
+
fi, err := os.Open(f)
28
+
if err != nil {
29
+
return nil, err
30
+
}
31
+
defer fi.Close()
32
+
33
+
var s Schema
34
+
if err := json.NewDecoder(fi).Decode(&s); err != nil {
35
+
return nil, err
36
+
}
37
+
s.path = f
38
+
39
+
return &s, nil
40
+
}
41
+
42
+
func (s *Schema) Name() string {
43
+
p := strings.Split(s.ID, ".")
44
+
return p[len(p)-2] + p[len(p)-1]
45
+
}
46
+
47
+
func (s *Schema) AllTypes(prefix string, defMap map[string]*ExtDef) []outputType {
48
+
var out []outputType
49
+
50
+
var walk func(name string, ts *TypeSchema, needsCbor bool)
51
+
walk = func(name string, ts *TypeSchema, needsCbor bool) {
52
+
if ts == nil {
53
+
panic(fmt.Sprintf("nil type schema in %q (%s)", name, s.ID))
54
+
}
55
+
56
+
if needsCbor {
57
+
fmt.Println("Setting to record: ", name)
58
+
if name == "EmbedImages_View" {
59
+
panic("not ok")
60
+
}
61
+
ts.needsCbor = true
62
+
}
63
+
64
+
if name == "LabelDefs_SelfLabels" {
65
+
ts.needsType = true
66
+
}
67
+
68
+
ts.prefix = prefix
69
+
ts.id = s.ID
70
+
ts.defMap = defMap
71
+
if ts.Type == "object" ||
72
+
(ts.Type == "union" && len(ts.Refs) > 0) {
73
+
out = append(out, outputType{
74
+
Name: name,
75
+
Type: ts,
76
+
NeedsCbor: ts.needsCbor,
77
+
})
78
+
79
+
for _, r := range ts.Refs {
80
+
refname := r
81
+
if strings.HasPrefix(refname, "#") {
82
+
refname = s.ID + r
83
+
}
84
+
85
+
ed, ok := defMap[refname]
86
+
if !ok {
87
+
panic(fmt.Sprintf("cannot find: %q", refname))
88
+
}
89
+
90
+
fmt.Println("UNION REF", refname, name, needsCbor)
91
+
92
+
if needsCbor {
93
+
ed.Type.needsCbor = true
94
+
}
95
+
96
+
ed.Type.needsType = true
97
+
}
98
+
}
99
+
100
+
if ts.Type == "ref" {
101
+
refname := ts.Ref
102
+
if strings.HasPrefix(refname, "#") {
103
+
refname = s.ID + ts.Ref
104
+
}
105
+
106
+
sub, ok := defMap[refname]
107
+
if !ok {
108
+
panic(fmt.Sprintf("missing ref: %q", refname))
109
+
}
110
+
111
+
if needsCbor {
112
+
sub.Type.needsCbor = true
113
+
}
114
+
}
115
+
116
+
for childname, val := range ts.Properties {
117
+
walk(name+"_"+strings.Title(childname), val, ts.needsCbor)
118
+
}
119
+
120
+
if ts.Items != nil {
121
+
walk(name+"_Elem", ts.Items, ts.needsCbor)
122
+
}
123
+
124
+
if ts.Input != nil {
125
+
if ts.Input.Schema == nil {
126
+
if ts.Input.Encoding != EncodingCBOR &&
127
+
ts.Input.Encoding != EncodingANY &&
128
+
ts.Input.Encoding != EncodingCAR &&
129
+
ts.Input.Encoding != EncodingMP4 {
130
+
panic(fmt.Sprintf("strange input type def in %s", s.ID))
131
+
}
132
+
} else {
133
+
walk(name+"_Input", ts.Input.Schema, ts.needsCbor)
134
+
}
135
+
}
136
+
137
+
if ts.Output != nil {
138
+
if ts.Output.Schema == nil {
139
+
if ts.Output.Encoding != EncodingCBOR &&
140
+
ts.Output.Encoding != EncodingCAR &&
141
+
ts.Output.Encoding != EncodingANY &&
142
+
ts.Output.Encoding != EncodingJSONL &&
143
+
ts.Output.Encoding != EncodingMP4 {
144
+
panic(fmt.Sprintf("strange output type def in %s", s.ID))
145
+
}
146
+
} else {
147
+
walk(name+"_Output", ts.Output.Schema, ts.needsCbor)
148
+
}
149
+
}
150
+
151
+
if ts.Type == "record" {
152
+
ts.Record.needsType = true
153
+
walk(name, ts.Record, true)
154
+
}
155
+
156
+
}
157
+
158
+
tname := nameFromID(s.ID, prefix)
159
+
160
+
for name, def := range s.Defs {
161
+
n := tname + "_" + strings.Title(name)
162
+
if name == "main" {
163
+
n = tname
164
+
}
165
+
walk(n, def, def.needsCbor)
166
+
}
167
+
168
+
return out
169
+
}
+905
lex/type_schema.go
+905
lex/type_schema.go
···
1
+
package lex
2
+
3
+
import (
4
+
"fmt"
5
+
"io"
6
+
"strings"
7
+
)
8
+
9
+
type OutputType struct {
10
+
Encoding string `json:"encoding"`
11
+
Schema *TypeSchema `json:"schema"`
12
+
}
13
+
14
+
type InputType struct {
15
+
Encoding string `json:"encoding"`
16
+
Schema *TypeSchema `json:"schema"`
17
+
}
18
+
19
+
// TypeSchema is the content of a lexicon schema file "defs" section.
20
+
// https://atproto.com/specs/lexicon
21
+
type TypeSchema struct {
22
+
prefix string // prefix of a major package being processed, e.g. com.atproto
23
+
id string // parent Schema.ID
24
+
defName string // parent Schema.Defs[defName] points to this *TypeSchema
25
+
defMap map[string]*ExtDef
26
+
needsCbor bool
27
+
needsType bool
28
+
29
+
Type string `json:"type"`
30
+
Key string `json:"key"`
31
+
Description string `json:"description"`
32
+
Parameters *TypeSchema `json:"parameters"`
33
+
Input *InputType `json:"input"`
34
+
Output *OutputType `json:"output"`
35
+
Record *TypeSchema `json:"record"`
36
+
37
+
Ref string `json:"ref"`
38
+
Refs []string `json:"refs"`
39
+
Required []string `json:"required"`
40
+
Nullable []string `json:"nullable"`
41
+
Properties map[string]*TypeSchema `json:"properties"`
42
+
MaxLength int `json:"maxLength"`
43
+
Items *TypeSchema `json:"items"`
44
+
Const any `json:"const"`
45
+
Enum []string `json:"enum"`
46
+
Closed bool `json:"closed"`
47
+
48
+
Default any `json:"default"`
49
+
Minimum any `json:"minimum"`
50
+
Maximum any `json:"maximum"`
51
+
}
52
+
53
+
func (s *TypeSchema) WriteRPC(w io.Writer, typename string) error {
54
+
pf := printerf(w)
55
+
fname := typename
56
+
57
+
params := "ctx context.Context, c *xrpc.Client"
58
+
inpvar := "nil"
59
+
inpenc := ""
60
+
61
+
if s.Input != nil {
62
+
inpvar = "input"
63
+
inpenc = s.Input.Encoding
64
+
switch s.Input.Encoding {
65
+
case EncodingCBOR, EncodingCAR, EncodingANY, EncodingMP4:
66
+
params = fmt.Sprintf("%s, input io.Reader", params)
67
+
case EncodingJSON:
68
+
params = fmt.Sprintf("%s, input *%s_Input", params, fname)
69
+
70
+
default:
71
+
return fmt.Errorf("unsupported input encoding (RPC input): %q", s.Input.Encoding)
72
+
}
73
+
}
74
+
75
+
if s.Parameters != nil {
76
+
if err := orderedMapIter(s.Parameters.Properties, func(name string, t *TypeSchema) error {
77
+
tn, err := s.typeNameForField(name, "", *t)
78
+
if err != nil {
79
+
return err
80
+
}
81
+
82
+
// TODO: deal with optional params
83
+
params = params + fmt.Sprintf(", %s %s", name, tn)
84
+
return nil
85
+
}); err != nil {
86
+
return err
87
+
}
88
+
}
89
+
90
+
out := "error"
91
+
if s.Output != nil {
92
+
switch s.Output.Encoding {
93
+
case EncodingCBOR, EncodingCAR, EncodingANY, EncodingJSONL, EncodingMP4:
94
+
out = "([]byte, error)"
95
+
case EncodingJSON:
96
+
outname := fname + "_Output"
97
+
if s.Output.Schema.Type == "ref" {
98
+
_, outname = s.namesFromRef(s.Output.Schema.Ref)
99
+
}
100
+
101
+
out = fmt.Sprintf("(*%s, error)", outname)
102
+
default:
103
+
return fmt.Errorf("unrecognized encoding scheme (RPC output): %q", s.Output.Encoding)
104
+
}
105
+
}
106
+
107
+
pf("// %s calls the XRPC method %q.\n", fname, s.id)
108
+
if s.Parameters != nil && len(s.Parameters.Properties) > 0 {
109
+
pf("//\n")
110
+
if err := orderedMapIter(s.Parameters.Properties, func(name string, t *TypeSchema) error {
111
+
if t.Description != "" {
112
+
pf("// %s: %s\n", name, t.Description)
113
+
}
114
+
return nil
115
+
}); err != nil {
116
+
return err
117
+
}
118
+
}
119
+
pf("func %s(%s) %s {\n", fname, params, out)
120
+
121
+
outvar := "nil"
122
+
errRet := "err"
123
+
outRet := "nil"
124
+
if s.Output != nil {
125
+
switch s.Output.Encoding {
126
+
case EncodingCBOR, EncodingCAR, EncodingANY, EncodingJSONL, EncodingMP4:
127
+
pf("buf := new(bytes.Buffer)\n")
128
+
outvar = "buf"
129
+
errRet = "nil, err"
130
+
outRet = "buf.Bytes(), nil"
131
+
case EncodingJSON:
132
+
outname := fname + "_Output"
133
+
if s.Output.Schema.Type == "ref" {
134
+
_, outname = s.namesFromRef(s.Output.Schema.Ref)
135
+
}
136
+
pf("\tvar out %s\n", outname)
137
+
outvar = "&out"
138
+
errRet = "nil, err"
139
+
outRet = "&out, nil"
140
+
default:
141
+
return fmt.Errorf("unrecognized output encoding (func signature): %q", s.Output.Encoding)
142
+
}
143
+
}
144
+
145
+
queryparams := "nil"
146
+
if s.Parameters != nil {
147
+
queryparams = "params"
148
+
pf(`
149
+
params := map[string]interface{}{
150
+
`)
151
+
if err := orderedMapIter(s.Parameters.Properties, func(name string, t *TypeSchema) error {
152
+
pf(`"%s": %s,
153
+
`, name, name)
154
+
return nil
155
+
}); err != nil {
156
+
return err
157
+
}
158
+
pf("}\n")
159
+
}
160
+
161
+
var reqtype string
162
+
switch s.Type {
163
+
case "procedure":
164
+
reqtype = "xrpc.Procedure"
165
+
case "query":
166
+
reqtype = "xrpc.Query"
167
+
default:
168
+
return fmt.Errorf("can only generate RPC for Query or Procedure (got %s)", s.Type)
169
+
}
170
+
171
+
pf("\tif err := c.Do(ctx, %s, %q, \"%s\", %s, %s, %s); err != nil {\n", reqtype, inpenc, s.id, queryparams, inpvar, outvar)
172
+
pf("\t\treturn %s\n", errRet)
173
+
pf("\t}\n\n")
174
+
pf("\treturn %s\n", outRet)
175
+
pf("}\n\n")
176
+
177
+
return nil
178
+
}
179
+
180
+
func (s *TypeSchema) WriteHandlerStub(w io.Writer, fname, shortname, impname string) error {
181
+
pf := printerf(w)
182
+
paramtypes := []string{"ctx context.Context"}
183
+
if s.Type == "query" {
184
+
185
+
if s.Parameters != nil {
186
+
var required map[string]bool
187
+
if s.Parameters.Required != nil {
188
+
required = make(map[string]bool)
189
+
for _, r := range s.Required {
190
+
required[r] = true
191
+
}
192
+
}
193
+
orderedMapIter[*TypeSchema](s.Parameters.Properties, func(k string, t *TypeSchema) error {
194
+
switch t.Type {
195
+
case "string":
196
+
paramtypes = append(paramtypes, k+" string")
197
+
case "integer":
198
+
// TODO(bnewbold) could be handling "nullable" here
199
+
if required != nil && !required[k] {
200
+
paramtypes = append(paramtypes, k+" *int")
201
+
} else {
202
+
paramtypes = append(paramtypes, k+" int")
203
+
}
204
+
case "float":
205
+
return fmt.Errorf("non-integer numbers currently unsupported")
206
+
case "array":
207
+
paramtypes = append(paramtypes, k+"[]"+t.Items.Type)
208
+
default:
209
+
return fmt.Errorf("unsupported handler parameter type: %s", t.Type)
210
+
}
211
+
return nil
212
+
})
213
+
}
214
+
}
215
+
216
+
returndef := "error"
217
+
if s.Output != nil {
218
+
switch s.Output.Encoding {
219
+
case "application/json":
220
+
outname := shortname + "_Output"
221
+
if s.Output.Schema.Type == "ref" {
222
+
outname, _ = s.namesFromRef(s.Output.Schema.Ref)
223
+
}
224
+
returndef = fmt.Sprintf("(*%s.%s, error)", impname, outname)
225
+
case "application/cbor", "application/vnd.ipld.car", "*/*":
226
+
returndef = fmt.Sprintf("(io.Reader, error)")
227
+
default:
228
+
return fmt.Errorf("unrecognized output encoding (handler stub): %q", s.Output.Encoding)
229
+
}
230
+
}
231
+
232
+
if s.Input != nil {
233
+
switch s.Input.Encoding {
234
+
case "application/json":
235
+
paramtypes = append(paramtypes, fmt.Sprintf("input *%s.%s_Input", impname, shortname))
236
+
case "application/cbor":
237
+
paramtypes = append(paramtypes, "r io.Reader")
238
+
}
239
+
}
240
+
241
+
pf("func (s *Server) handle%s(%s) %s {\n", fname, strings.Join(paramtypes, ","), returndef)
242
+
pf("panic(\"not yet implemented\")\n}\n\n")
243
+
244
+
return nil
245
+
}
246
+
247
+
func (s *TypeSchema) WriteRPCHandler(w io.Writer, fname, shortname, impname string) error {
248
+
pf := printerf(w)
249
+
tname := shortname
250
+
251
+
pf("func (s *Server) Handle%s(c echo.Context) error {\n", fname)
252
+
253
+
pf("ctx, span := otel.Tracer(\"server\").Start(c.Request().Context(), %q)\n", "Handle"+fname)
254
+
pf("defer span.End()\n")
255
+
256
+
paramtypes := []string{"ctx context.Context"}
257
+
params := []string{"ctx"}
258
+
if s.Type == "query" {
259
+
if s.Parameters != nil {
260
+
// TODO(bnewbold): could be handling 'nullable' here
261
+
required := make(map[string]bool)
262
+
for _, r := range s.Parameters.Required {
263
+
required[r] = true
264
+
}
265
+
for k, v := range s.Parameters.Properties {
266
+
if v.Default != nil {
267
+
required[k] = true
268
+
}
269
+
}
270
+
if err := orderedMapIter(s.Parameters.Properties, func(k string, t *TypeSchema) error {
271
+
switch t.Type {
272
+
case "string":
273
+
params = append(params, k)
274
+
paramtypes = append(paramtypes, k+" string")
275
+
pf("%s := c.QueryParam(\"%s\")\n", k, k)
276
+
case "integer":
277
+
params = append(params, k)
278
+
279
+
if !required[k] {
280
+
paramtypes = append(paramtypes, k+" *int")
281
+
pf(`
282
+
var %s *int
283
+
if p := c.QueryParam("%s"); p != "" {
284
+
%s_val, err := strconv.Atoi(p)
285
+
if err != nil {
286
+
return err
287
+
}
288
+
%s = &%s_val
289
+
}
290
+
`, k, k, k, k, k)
291
+
} else if t.Default != nil {
292
+
paramtypes = append(paramtypes, k+" int")
293
+
pf(`
294
+
var %s int
295
+
if p := c.QueryParam("%s"); p != "" {
296
+
var err error
297
+
%s, err = strconv.Atoi(p)
298
+
if err != nil {
299
+
return err
300
+
}
301
+
} else {
302
+
%s = %d
303
+
}
304
+
`, k, k, k, k, int(t.Default.(float64)))
305
+
} else {
306
+
307
+
paramtypes = append(paramtypes, k+" int")
308
+
pf(`
309
+
%s, err := strconv.Atoi(c.QueryParam("%s"))
310
+
if err != nil {
311
+
return err
312
+
}
313
+
`, k, k)
314
+
}
315
+
316
+
case "float":
317
+
return fmt.Errorf("non-integer numbers currently unsupported")
318
+
case "boolean":
319
+
params = append(params, k)
320
+
if !required[k] {
321
+
paramtypes = append(paramtypes, k+" *bool")
322
+
pf(`
323
+
var %s *bool
324
+
if p := c.QueryParam("%s"); p != "" {
325
+
%s_val, err := strconv.ParseBool(p)
326
+
if err != nil {
327
+
return err
328
+
}
329
+
%s = &%s_val
330
+
}
331
+
`, k, k, k, k, k)
332
+
} else if t.Default != nil {
333
+
paramtypes = append(paramtypes, k+" bool")
334
+
pf(`
335
+
var %s bool
336
+
if p := c.QueryParam("%s"); p != "" {
337
+
var err error
338
+
%s, err = strconv.ParseBool(p)
339
+
if err != nil {
340
+
return err
341
+
}
342
+
} else {
343
+
%s = %v
344
+
}
345
+
`, k, k, k, k, t.Default.(bool))
346
+
} else {
347
+
348
+
paramtypes = append(paramtypes, k+" bool")
349
+
pf(`
350
+
%s, err := strconv.ParseBool(c.QueryParam("%s"))
351
+
if err != nil {
352
+
return err
353
+
}
354
+
`, k, k)
355
+
}
356
+
357
+
case "array":
358
+
if t.Items.Type != "string" {
359
+
return fmt.Errorf("currently only string arrays are supported in query params")
360
+
}
361
+
paramtypes = append(paramtypes, k+" []string")
362
+
params = append(params, k)
363
+
pf(`
364
+
%s := c.QueryParams()["%s"]
365
+
`, k, k)
366
+
367
+
default:
368
+
return fmt.Errorf("unsupported handler parameter type: %s", t.Type)
369
+
}
370
+
return nil
371
+
}); err != nil {
372
+
return err
373
+
}
374
+
}
375
+
} else if s.Type == "procedure" {
376
+
if s.Input != nil {
377
+
intname := impname + "." + tname + "_Input"
378
+
switch s.Input.Encoding {
379
+
case EncodingJSON:
380
+
pf(`
381
+
var body %s
382
+
if err := c.Bind(&body); err != nil {
383
+
return err
384
+
}
385
+
`, intname)
386
+
paramtypes = append(paramtypes, "body *"+intname)
387
+
params = append(params, "&body")
388
+
case EncodingCBOR:
389
+
pf("body := c.Request().Body\n")
390
+
paramtypes = append(paramtypes, "r io.Reader")
391
+
params = append(params, "body")
392
+
case EncodingANY:
393
+
pf("body := c.Request().Body\n")
394
+
pf("contentType := c.Request().Header.Get(\"Content-Type\")\n")
395
+
paramtypes = append(paramtypes, "r io.Reader", "contentType string")
396
+
params = append(params, "body", "contentType")
397
+
case EncodingMP4:
398
+
pf("body := c.Request().Body\n")
399
+
paramtypes = append(paramtypes, "r io.Reader")
400
+
params = append(params, "body")
401
+
default:
402
+
return fmt.Errorf("unrecognized input encoding: %q", s.Input.Encoding)
403
+
}
404
+
}
405
+
} else {
406
+
return fmt.Errorf("can only generate handlers for queries or procedures")
407
+
}
408
+
409
+
assign := "handleErr"
410
+
returndef := "error"
411
+
if s.Output != nil {
412
+
switch s.Output.Encoding {
413
+
case EncodingJSON:
414
+
assign = "out, handleErr"
415
+
outname := tname + "_Output"
416
+
if s.Output.Schema.Type == "ref" {
417
+
outname, _ = s.namesFromRef(s.Output.Schema.Ref)
418
+
}
419
+
pf("var out *%s.%s\n", impname, outname)
420
+
returndef = fmt.Sprintf("(*%s.%s, error)", impname, outname)
421
+
case EncodingCBOR, EncodingCAR, EncodingANY, EncodingJSONL, EncodingMP4:
422
+
assign = "out, handleErr"
423
+
pf("var out io.Reader\n")
424
+
returndef = "(io.Reader, error)"
425
+
default:
426
+
return fmt.Errorf("unrecognized output encoding (RPC output handler): %q", s.Output.Encoding)
427
+
}
428
+
}
429
+
pf("var handleErr error\n")
430
+
pf("// func (s *Server) handle%s(%s) %s\n", fname, strings.Join(paramtypes, ","), returndef)
431
+
pf("%s = s.handle%s(%s)\n", assign, fname, strings.Join(params, ","))
432
+
pf("if handleErr != nil {\nreturn handleErr\n}\n")
433
+
434
+
if s.Output != nil {
435
+
switch s.Output.Encoding {
436
+
case EncodingJSON:
437
+
pf("return c.JSON(200, out)\n}\n\n")
438
+
case EncodingANY:
439
+
pf("return c.Stream(200, \"application/octet-stream\", out)\n}\n\n")
440
+
case EncodingCBOR:
441
+
pf("return c.Stream(200, \"application/octet-stream\", out)\n}\n\n")
442
+
case EncodingCAR:
443
+
pf("return c.Stream(200, \"application/vnd.ipld.car\", out)\n}\n\n")
444
+
case EncodingJSONL:
445
+
pf("return c.Stream(200, \"application/jsonl\", out)\n}\n\n")
446
+
case EncodingMP4:
447
+
pf("return c.Stream(200, \"video/mp4\", out)\n}\n\n")
448
+
default:
449
+
return fmt.Errorf("unrecognized output encoding (RPC output handler return): %q", s.Output.Encoding)
450
+
}
451
+
} else {
452
+
pf("return nil\n}\n\n")
453
+
}
454
+
455
+
return nil
456
+
}
457
+
458
+
func (s *TypeSchema) namesFromRef(r string) (string, string) {
459
+
ts, err := s.lookupRef(r)
460
+
if err != nil {
461
+
panic(err)
462
+
}
463
+
464
+
if ts.prefix == "" {
465
+
panic(fmt.Sprintf("no prefix for referenced type: %s", ts.id))
466
+
}
467
+
468
+
if s.prefix == "" {
469
+
panic(fmt.Sprintf("no prefix for referencing type: %q %q", s.id, s.defName))
470
+
}
471
+
472
+
// TODO: probably not technically correct, but i'm kinda over how lexicon
473
+
// tries to enforce application logic in a schema language
474
+
if ts.Type == "string" {
475
+
return "INVALID", "string"
476
+
}
477
+
478
+
var pkg string
479
+
if ts.prefix != s.prefix {
480
+
pkg = importNameForPrefix(ts.prefix) + "."
481
+
}
482
+
483
+
tname := pkg + ts.TypeName()
484
+
vname := tname
485
+
if strings.Contains(vname, ".") {
486
+
// Trim the package name from the variable name
487
+
vname = strings.Split(vname, ".")[1]
488
+
}
489
+
490
+
return vname, tname
491
+
}
492
+
493
+
func (s *TypeSchema) TypeName() string {
494
+
if s.id == "" {
495
+
panic("type schema hint fields not set")
496
+
}
497
+
if s.prefix == "" {
498
+
panic("why no prefix?")
499
+
}
500
+
n := nameFromID(s.id, s.prefix)
501
+
if s.defName != "main" {
502
+
n += "_" + strings.Title(s.defName)
503
+
}
504
+
505
+
if s.Type == "array" {
506
+
n = "[]" + n
507
+
508
+
if s.Items.Type == "union" {
509
+
n = n + "_Elem"
510
+
}
511
+
}
512
+
513
+
return n
514
+
}
515
+
516
+
// name: enclosing type name
517
+
// k: field name
518
+
// v: field TypeSchema
519
+
func (s *TypeSchema) typeNameForField(name, k string, v TypeSchema) (string, error) {
520
+
switch v.Type {
521
+
case "string":
522
+
return "string", nil
523
+
case "float":
524
+
return "float64", nil
525
+
case "integer":
526
+
return "int64", nil
527
+
case "boolean":
528
+
return "bool", nil
529
+
case "object":
530
+
return "*" + name + "_" + strings.Title(k), nil
531
+
case "ref":
532
+
_, tn := s.namesFromRef(v.Ref)
533
+
if tn[0] == '[' {
534
+
return tn, nil
535
+
}
536
+
return "*" + tn, nil
537
+
case "datetime":
538
+
// TODO: maybe do a native type?
539
+
return "string", nil
540
+
case "unknown":
541
+
// NOTE: sometimes a record, for which we want LexiconTypeDecoder, sometimes any object
542
+
if k == "didDoc" || k == "plcOp" {
543
+
return "interface{}", nil
544
+
} else {
545
+
return "*util.LexiconTypeDecoder", nil
546
+
}
547
+
case "union":
548
+
return "*" + name + "_" + strings.Title(k), nil
549
+
case "blob":
550
+
return "*util.LexBlob", nil
551
+
case "array":
552
+
subt, err := s.typeNameForField(name+"_"+strings.Title(k), "Elem", *v.Items)
553
+
if err != nil {
554
+
return "", err
555
+
}
556
+
557
+
return "[]" + subt, nil
558
+
case "cid-link":
559
+
return "util.LexLink", nil
560
+
case "bytes":
561
+
return "util.LexBytes", nil
562
+
default:
563
+
return "", fmt.Errorf("field %q in %s has unsupported type name (%s)", k, name, v.Type)
564
+
}
565
+
}
566
+
567
+
func (ts *TypeSchema) lookupRef(ref string) (*TypeSchema, error) {
568
+
fqref := ref
569
+
if strings.HasPrefix(ref, "#") {
570
+
fmt.Println("updating fqref: ", ts.id)
571
+
fqref = ts.id + ref
572
+
}
573
+
rr, ok := ts.defMap[fqref]
574
+
if !ok {
575
+
fmt.Println(ts.defMap)
576
+
panic(fmt.Sprintf("no such ref: %q", fqref))
577
+
}
578
+
579
+
return rr.Type, nil
580
+
}
581
+
582
+
// name is the top level type name from outputType
583
+
// WriteType is only called on a top level TypeSchema
584
+
func (ts *TypeSchema) WriteType(name string, w io.Writer) error {
585
+
name = strings.Title(name)
586
+
if err := ts.writeTypeDefinition(name, w); err != nil {
587
+
return err
588
+
}
589
+
590
+
if err := ts.writeTypeMethods(name, w); err != nil {
591
+
return err
592
+
}
593
+
594
+
return nil
595
+
}
596
+
597
+
// name is the top level type name from outputType
598
+
// writeTypeDefinition is not called recursively, but only on a top level TypeSchema
599
+
func (ts *TypeSchema) writeTypeDefinition(name string, w io.Writer) error {
600
+
pf := printerf(w)
601
+
602
+
switch {
603
+
case strings.HasSuffix(name, "_Output"):
604
+
pf("// %s is the output of a %s call.\n", name, ts.id)
605
+
case strings.HasSuffix(name, "Input"):
606
+
pf("// %s is the input argument to a %s call.\n", name, ts.id)
607
+
case ts.defName != "":
608
+
pf("// %s is a %q in the %s schema.\n", name, ts.defName, ts.id)
609
+
}
610
+
if ts.Description != "" {
611
+
pf("//\n// %s\n", ts.Description)
612
+
}
613
+
614
+
switch ts.Type {
615
+
case "string":
616
+
// TODO: deal with max length
617
+
pf("type %s string\n", name)
618
+
case "float":
619
+
pf("type %s float64\n", name)
620
+
case "integer":
621
+
pf("type %s int64\n", name)
622
+
case "boolean":
623
+
pf("type %s bool\n", name)
624
+
case "object":
625
+
if ts.needsType {
626
+
pf("//\n// RECORDTYPE: %s\n", name)
627
+
}
628
+
629
+
pf("type %s struct {\n", name)
630
+
631
+
if ts.needsType {
632
+
var omit string
633
+
if ts.id == "com.atproto.repo.strongRef" { // TODO: hack
634
+
omit = ",omitempty"
635
+
}
636
+
cval := ts.id
637
+
if ts.defName != "" && ts.defName != "main" {
638
+
cval += "#" + ts.defName
639
+
}
640
+
pf("\tLexiconTypeID string `json:\"$type,const=%s%s\" cborgen:\"$type,const=%s%s\"`\n", cval, omit, cval, omit)
641
+
} else {
642
+
//pf("\tLexiconTypeID string `json:\"$type,omitempty\" cborgen:\"$type,omitempty\"`\n")
643
+
}
644
+
645
+
required := make(map[string]bool)
646
+
for _, req := range ts.Required {
647
+
required[req] = true
648
+
}
649
+
650
+
nullable := make(map[string]bool)
651
+
for _, req := range ts.Nullable {
652
+
nullable[req] = true
653
+
}
654
+
655
+
if err := orderedMapIter(ts.Properties, func(k string, v *TypeSchema) error {
656
+
goname := strings.Title(k)
657
+
658
+
tname, err := ts.typeNameForField(name, k, *v)
659
+
if err != nil {
660
+
return err
661
+
}
662
+
663
+
var ptr string
664
+
var omit string
665
+
if !required[k] {
666
+
omit = ",omitempty"
667
+
if !strings.HasPrefix(tname, "*") && !strings.HasPrefix(tname, "[]") {
668
+
ptr = "*"
669
+
}
670
+
}
671
+
if nullable[k] {
672
+
omit = ""
673
+
if !strings.HasPrefix(tname, "*") && !strings.HasPrefix(tname, "[]") {
674
+
ptr = "*"
675
+
}
676
+
}
677
+
678
+
jsonOmit, cborOmit := omit, omit
679
+
680
+
// Don't generate pointers to lexbytes, as it's already a pointer.
681
+
if ptr == "*" && tname == "util.LexBytes" {
682
+
ptr = ""
683
+
}
684
+
685
+
// TODO: hard-coded hacks for now, making this type (with underlying type []byte)
686
+
// be omitempty.
687
+
if ptr == "" && tname == "util.LexBytes" {
688
+
jsonOmit = ",omitempty"
689
+
cborOmit = ",omitempty"
690
+
}
691
+
692
+
if name == "LabelDefs_SelfLabels" && k == "values" {
693
+
// TODO: regularize weird hack?
694
+
cborOmit += ",preservenil"
695
+
}
696
+
697
+
if v.Description != "" {
698
+
pf("\t// %s: %s\n", k, v.Description)
699
+
}
700
+
pf("\t%s %s%s `json:\"%s%s\" cborgen:\"%s%s\"`\n", goname, ptr, tname, k, jsonOmit, k, cborOmit)
701
+
return nil
702
+
}); err != nil {
703
+
return err
704
+
}
705
+
706
+
pf("}\n\n")
707
+
708
+
case "array":
709
+
tname, err := ts.typeNameForField(name, "elem", *ts.Items)
710
+
if err != nil {
711
+
return err
712
+
}
713
+
714
+
pf("type %s []%s\n", name, tname)
715
+
716
+
case "union":
717
+
if len(ts.Refs) > 0 {
718
+
pf("type %s struct {\n", name)
719
+
for _, r := range ts.Refs {
720
+
vname, tname := ts.namesFromRef(r)
721
+
pf("\t%s *%s\n", vname, tname)
722
+
}
723
+
pf("}\n\n")
724
+
}
725
+
default:
726
+
return fmt.Errorf("%s has unrecognized type: %s", name, ts.Type)
727
+
}
728
+
729
+
return nil
730
+
}
731
+
732
+
func (ts *TypeSchema) writeTypeMethods(name string, w io.Writer) error {
733
+
switch ts.Type {
734
+
case "string", "float", "array", "boolean", "integer":
735
+
return nil
736
+
case "object":
737
+
if err := ts.writeJsonMarshalerObject(name, w); err != nil {
738
+
return err
739
+
}
740
+
741
+
if err := ts.writeJsonUnmarshalerObject(name, w); err != nil {
742
+
return err
743
+
}
744
+
745
+
return nil
746
+
case "union":
747
+
if len(ts.Refs) > 0 {
748
+
reft, err := ts.lookupRef(ts.Refs[0])
749
+
if err != nil {
750
+
return err
751
+
}
752
+
753
+
if reft.Type == "string" {
754
+
return nil
755
+
}
756
+
757
+
if err := ts.writeJsonMarshalerEnum(name, w); err != nil {
758
+
return err
759
+
}
760
+
761
+
if err := ts.writeJsonUnmarshalerEnum(name, w); err != nil {
762
+
return err
763
+
}
764
+
765
+
if ts.needsCbor {
766
+
if err := ts.writeCborMarshalerEnum(name, w); err != nil {
767
+
return err
768
+
}
769
+
770
+
if err := ts.writeCborUnmarshalerEnum(name, w); err != nil {
771
+
return err
772
+
}
773
+
}
774
+
775
+
return nil
776
+
}
777
+
778
+
return fmt.Errorf("%q unsupported for marshaling", name)
779
+
default:
780
+
return fmt.Errorf("%q has unrecognized type: %s", name, ts.Type)
781
+
}
782
+
}
783
+
784
+
func (ts *TypeSchema) writeJsonMarshalerObject(name string, w io.Writer) error {
785
+
return nil // no need for a special json marshaler right now
786
+
}
787
+
788
+
func (ts *TypeSchema) writeJsonMarshalerEnum(name string, w io.Writer) error {
789
+
pf := printerf(w)
790
+
pf("func (t *%s) MarshalJSON() ([]byte, error) {\n", name)
791
+
792
+
for _, e := range ts.Refs {
793
+
vname, _ := ts.namesFromRef(e)
794
+
if strings.HasPrefix(e, "#") {
795
+
e = ts.id + e
796
+
}
797
+
798
+
pf("\tif t.%s != nil {\n", vname)
799
+
pf("\tt.%s.LexiconTypeID = %q\n", vname, e)
800
+
pf("\t\treturn json.Marshal(t.%s)\n\t}\n", vname)
801
+
}
802
+
803
+
pf("\treturn nil, fmt.Errorf(\"cannot marshal empty enum\")\n}\n")
804
+
return nil
805
+
}
806
+
807
+
func (s *TypeSchema) writeJsonUnmarshalerObject(name string, w io.Writer) error {
808
+
// TODO: would be nice to add some validation...
809
+
return nil
810
+
//pf("func (t *%s) UnmarshalJSON(b []byte) (error) {\n", name)
811
+
}
812
+
813
+
func (ts *TypeSchema) writeJsonUnmarshalerEnum(name string, w io.Writer) error {
814
+
pf := printerf(w)
815
+
pf("func (t *%s) UnmarshalJSON(b []byte) (error) {\n", name)
816
+
pf("\ttyp, err := util.TypeExtract(b)\n")
817
+
pf("\tif err != nil {\n\t\treturn err\n\t}\n\n")
818
+
pf("\tswitch typ {\n")
819
+
for _, e := range ts.Refs {
820
+
if strings.HasPrefix(e, "#") {
821
+
e = ts.id + e
822
+
}
823
+
824
+
vname, goname := ts.namesFromRef(e)
825
+
826
+
pf("\t\tcase \"%s\":\n", e)
827
+
pf("\t\t\tt.%s = new(%s)\n", vname, goname)
828
+
pf("\t\t\treturn json.Unmarshal(b, t.%s)\n", vname)
829
+
}
830
+
831
+
if ts.Closed {
832
+
pf(`
833
+
default:
834
+
return fmt.Errorf("closed enums must have a matching value")
835
+
`)
836
+
} else {
837
+
pf(`
838
+
default:
839
+
return nil
840
+
`)
841
+
842
+
}
843
+
844
+
pf("\t}\n")
845
+
pf("}\n\n")
846
+
847
+
return nil
848
+
}
849
+
850
+
func (ts *TypeSchema) writeCborMarshalerEnum(name string, w io.Writer) error {
851
+
pf := printerf(w)
852
+
pf("func (t *%s) MarshalCBOR(w io.Writer) error {\n", name)
853
+
pf(`
854
+
if t == nil {
855
+
_, err := w.Write(cbg.CborNull)
856
+
return err
857
+
}
858
+
`)
859
+
860
+
for _, e := range ts.Refs {
861
+
vname, _ := ts.namesFromRef(e)
862
+
pf("\tif t.%s != nil {\n", vname)
863
+
pf("\t\treturn t.%s.MarshalCBOR(w)\n\t}\n", vname)
864
+
}
865
+
866
+
pf("\treturn fmt.Errorf(\"cannot cbor marshal empty enum\")\n}\n")
867
+
return nil
868
+
}
869
+
870
+
func (ts *TypeSchema) writeCborUnmarshalerEnum(name string, w io.Writer) error {
871
+
pf := printerf(w)
872
+
pf("func (t *%s) UnmarshalCBOR(r io.Reader) error {\n", name)
873
+
pf("\ttyp, b, err := util.CborTypeExtractReader(r)\n")
874
+
pf("\tif err != nil {\n\t\treturn err\n\t}\n\n")
875
+
pf("\tswitch typ {\n")
876
+
for _, e := range ts.Refs {
877
+
if strings.HasPrefix(e, "#") {
878
+
e = ts.id + e
879
+
}
880
+
881
+
vname, goname := ts.namesFromRef(e)
882
+
883
+
pf("\t\tcase \"%s\":\n", e)
884
+
pf("\t\t\tt.%s = new(%s)\n", vname, goname)
885
+
pf("\t\t\treturn t.%s.UnmarshalCBOR(bytes.NewReader(b))\n", vname)
886
+
}
887
+
888
+
if ts.Closed {
889
+
pf(`
890
+
default:
891
+
return fmt.Errorf("closed enums must have a matching value")
892
+
`)
893
+
} else {
894
+
pf(`
895
+
default:
896
+
return nil
897
+
`)
898
+
899
+
}
900
+
901
+
pf("\t}\n")
902
+
pf("}\n\n")
903
+
904
+
return nil
905
+
}