fork of indigo with slightly nicer lexgen
at main 24 kB view raw
1package lex 2 3import ( 4 "fmt" 5 "io" 6 "slices" 7 "strings" 8) 9 10type OutputType struct { 11 Encoding string `json:"encoding"` 12 Schema *TypeSchema `json:"schema"` 13} 14 15type InputType struct { 16 Encoding string `json:"encoding"` 17 Schema *TypeSchema `json:"schema"` 18} 19 20// TypeSchema is the content of a lexicon schema file "defs" section. 21// https://atproto.com/specs/lexicon 22type TypeSchema struct { 23 prefix string // prefix of a major package being processed, e.g. com.atproto 24 id string // parent Schema.ID 25 defName string // parent Schema.Defs[defName] points to this *TypeSchema 26 defMap map[string]*ExtDef 27 needsCbor bool 28 needsType bool 29 30 Type string `json:"type"` 31 Key string `json:"key"` 32 Description string `json:"description"` 33 Parameters *TypeSchema `json:"parameters"` 34 Input *InputType `json:"input"` 35 Output *OutputType `json:"output"` 36 Record *TypeSchema `json:"record"` 37 38 Ref string `json:"ref"` 39 Refs []string `json:"refs"` 40 Required []string `json:"required"` 41 Nullable []string `json:"nullable"` 42 Properties map[string]*TypeSchema `json:"properties"` 43 MaxLength int `json:"maxLength"` 44 Items *TypeSchema `json:"items"` 45 Const any `json:"const"` 46 Enum []string `json:"enum"` 47 Closed bool `json:"closed"` 48 49 Default any `json:"default"` 50 Minimum any `json:"minimum"` 51 Maximum any `json:"maximum"` 52} 53 54func (s *TypeSchema) WriteRPC(w io.Writer, typename, inputname string) error { 55 pf := printerf(w) 56 fname := typename 57 58 params := "ctx context.Context, c util.LexClient" 59 inpvar := "nil" 60 inpenc := "" 61 62 if s.Input != nil { 63 inpvar = "input" 64 inpenc = s.Input.Encoding 65 switch s.Input.Encoding { 66 case EncodingCBOR, EncodingCAR, EncodingANY, EncodingMP4: 67 params = fmt.Sprintf("%s, input io.Reader", params) 68 case EncodingJSON: 69 params = fmt.Sprintf("%s, input *%s", params, inputname) 70 71 default: 72 return fmt.Errorf("unsupported input encoding (RPC input): %q", s.Input.Encoding) 73 } 74 } 75 76 if s.Parameters != nil { 77 if err := orderedMapIter(s.Parameters.Properties, func(name string, t *TypeSchema) error { 78 tn, err := s.typeNameForField(name, "", *t) 79 if err != nil { 80 return err 81 } 82 83 // TODO: deal with optional params 84 params = params + fmt.Sprintf(", %s %s", name, tn) 85 return nil 86 }); err != nil { 87 return err 88 } 89 } 90 91 out := "error" 92 if s.Output != nil { 93 switch s.Output.Encoding { 94 case EncodingCBOR, EncodingCAR, EncodingANY, EncodingJSONL, EncodingMP4: 95 out = "([]byte, error)" 96 case EncodingJSON: 97 outname := fname + "_Output" 98 if s.Output.Schema.Type == "ref" { 99 _, outname = s.namesFromRef(s.Output.Schema.Ref) 100 } 101 102 out = fmt.Sprintf("(*%s, error)", outname) 103 default: 104 return fmt.Errorf("unrecognized encoding scheme (RPC output): %q", s.Output.Encoding) 105 } 106 } 107 108 pf("// %s calls the XRPC method %q.\n", fname, s.id) 109 if s.Parameters != nil && len(s.Parameters.Properties) > 0 { 110 pf("//\n") 111 if err := orderedMapIter(s.Parameters.Properties, func(name string, t *TypeSchema) error { 112 if t.Description != "" { 113 pf("// %s: %s\n", name, t.Description) 114 } 115 return nil 116 }); err != nil { 117 return err 118 } 119 } 120 pf("func %s(%s) %s {\n", fname, params, out) 121 122 outvar := "nil" 123 errRet := "err" 124 outRet := "nil" 125 if s.Output != nil { 126 switch s.Output.Encoding { 127 case EncodingCBOR, EncodingCAR, EncodingANY, EncodingJSONL, EncodingMP4: 128 pf("buf := new(bytes.Buffer)\n") 129 outvar = "buf" 130 errRet = "nil, err" 131 outRet = "buf.Bytes(), nil" 132 case EncodingJSON: 133 outname := fname + "_Output" 134 if s.Output.Schema.Type == "ref" { 135 _, outname = s.namesFromRef(s.Output.Schema.Ref) 136 } 137 pf("\tvar out %s\n", outname) 138 outvar = "&out" 139 errRet = "nil, err" 140 outRet = "&out, nil" 141 default: 142 return fmt.Errorf("unrecognized output encoding (func signature): %q", s.Output.Encoding) 143 } 144 } 145 146 queryparams := "nil" 147 if s.Parameters != nil { 148 queryparams = "params" 149 pf("\n\tparams := map[string]interface{}{}\n") 150 if err := orderedMapIter(s.Parameters.Properties, func(name string, t *TypeSchema) error { 151 if slices.Contains(s.Parameters.Required, name) || slices.Contains(s.Parameters.Nullable, name) { 152 pf("params[\"%s\"] = %s\n", name, name) 153 } else { 154 // if parameter isn't required, only include conditionally 155 switch t.Type { 156 case "integer": 157 pf("if %s != 0 { params[\"%s\"] = %s }\n", name, name, name) 158 case "string": 159 pf("if %s != \"\" { params[\"%s\"] = %s }\n", name, name, name) 160 case "array": 161 pf("if len(%s) != 0 { params[\"%s\"] = %s }\n", name, name, name) 162 case "boolean": 163 pf("if %s { params[\"%s\"] = %s }\n", name, name, name) 164 default: 165 return fmt.Errorf("unhandled query param type: %s", t.Type) 166 } 167 } 168 return nil 169 }); err != nil { 170 return err 171 } 172 } 173 174 var reqtype string 175 switch s.Type { 176 case "procedure": 177 reqtype = "util.Procedure" 178 case "query": 179 reqtype = "util.Query" 180 default: 181 return fmt.Errorf("can only generate RPC for Query or Procedure (got %s)", s.Type) 182 } 183 184 pf("\tif err := c.LexDo(ctx, %s, %q, \"%s\", %s, %s, %s); err != nil {\n", reqtype, inpenc, s.id, queryparams, inpvar, outvar) 185 pf("\t\treturn %s\n", errRet) 186 pf("\t}\n\n") 187 pf("\treturn %s\n", outRet) 188 pf("}\n\n") 189 190 return nil 191} 192 193func (s *TypeSchema) WriteHandlerStub(w io.Writer, fname, shortname, impname string) error { 194 pf := printerf(w) 195 paramtypes := []string{"ctx context.Context"} 196 if s.Type == "query" { 197 198 if s.Parameters != nil { 199 var required map[string]bool 200 if s.Parameters.Required != nil { 201 required = make(map[string]bool) 202 for _, r := range s.Required { 203 required[r] = true 204 } 205 } 206 orderedMapIter[*TypeSchema](s.Parameters.Properties, func(k string, t *TypeSchema) error { 207 switch t.Type { 208 case "string": 209 paramtypes = append(paramtypes, k+" string") 210 case "integer": 211 // TODO(bnewbold) could be handling "nullable" here 212 if required != nil && !required[k] { 213 paramtypes = append(paramtypes, k+" *int") 214 } else { 215 paramtypes = append(paramtypes, k+" int") 216 } 217 case "float": 218 return fmt.Errorf("non-integer numbers currently unsupported") 219 case "array": 220 paramtypes = append(paramtypes, k+"[]"+t.Items.Type) 221 default: 222 return fmt.Errorf("unsupported handler parameter type: %s", t.Type) 223 } 224 return nil 225 }) 226 } 227 } 228 229 returndef := "error" 230 if s.Output != nil { 231 switch s.Output.Encoding { 232 case "application/json": 233 outname := shortname + "_Output" 234 if s.Output.Schema.Type == "ref" { 235 outname, _ = s.namesFromRef(s.Output.Schema.Ref) 236 } 237 returndef = fmt.Sprintf("(*%s.%s, error)", impname, outname) 238 case "application/cbor", "application/vnd.ipld.car", "*/*": 239 returndef = "(io.Reader, error)" 240 default: 241 return fmt.Errorf("unrecognized output encoding (handler stub): %q", s.Output.Encoding) 242 } 243 } 244 245 if s.Input != nil { 246 switch s.Input.Encoding { 247 case "application/json": 248 paramtypes = append(paramtypes, fmt.Sprintf("input *%s.%s_Input", impname, shortname)) 249 case "application/cbor": 250 paramtypes = append(paramtypes, "r io.Reader") 251 } 252 } 253 254 pf("func (s *Server) handle%s(%s) %s {\n", fname, strings.Join(paramtypes, ","), returndef) 255 pf("panic(\"not yet implemented\")\n}\n\n") 256 257 return nil 258} 259 260func (s *TypeSchema) WriteRPCHandler(w io.Writer, fname, shortname, impname string) error { 261 pf := printerf(w) 262 tname := shortname 263 264 pf("func (s *Server) Handle%s(c echo.Context) error {\n", fname) 265 266 pf("ctx, span := otel.Tracer(\"server\").Start(c.Request().Context(), %q)\n", "Handle"+fname) 267 pf("defer span.End()\n") 268 269 paramtypes := []string{"ctx context.Context"} 270 params := []string{"ctx"} 271 if s.Type == "query" { 272 if s.Parameters != nil { 273 // TODO(bnewbold): could be handling 'nullable' here 274 required := make(map[string]bool) 275 for _, r := range s.Parameters.Required { 276 required[r] = true 277 } 278 for k, v := range s.Parameters.Properties { 279 if v.Default != nil { 280 required[k] = true 281 } 282 } 283 if err := orderedMapIter(s.Parameters.Properties, func(k string, t *TypeSchema) error { 284 switch t.Type { 285 case "string": 286 params = append(params, k) 287 paramtypes = append(paramtypes, k+" string") 288 pf("%s := c.QueryParam(\"%s\")\n", k, k) 289 case "integer": 290 params = append(params, k) 291 292 if !required[k] { 293 paramtypes = append(paramtypes, k+" *int") 294 pf(` 295var %s *int 296if p := c.QueryParam("%s"); p != "" { 297 %s_val, err := strconv.Atoi(p) 298 if err != nil { 299 return err 300 } 301 %s = &%s_val 302} 303`, k, k, k, k, k) 304 } else if t.Default != nil { 305 paramtypes = append(paramtypes, k+" int") 306 pf(` 307var %s int 308if p := c.QueryParam("%s"); p != "" { 309var err error 310%s, err = strconv.Atoi(p) 311if err != nil { 312 return err 313} 314} else { 315 %s = %d 316} 317`, k, k, k, k, int(t.Default.(float64))) 318 } else { 319 320 paramtypes = append(paramtypes, k+" int") 321 pf(` 322%s, err := strconv.Atoi(c.QueryParam("%s")) 323if err != nil { 324 return err 325} 326`, k, k) 327 } 328 329 case "float": 330 return fmt.Errorf("non-integer numbers currently unsupported") 331 case "boolean": 332 params = append(params, k) 333 if !required[k] { 334 paramtypes = append(paramtypes, k+" *bool") 335 pf(` 336var %s *bool 337if p := c.QueryParam("%s"); p != "" { 338 %s_val, err := strconv.ParseBool(p) 339 if err != nil { 340 return err 341 } 342 %s = &%s_val 343} 344`, k, k, k, k, k) 345 } else if t.Default != nil { 346 paramtypes = append(paramtypes, k+" bool") 347 pf(` 348var %s bool 349if p := c.QueryParam("%s"); p != "" { 350var err error 351%s, err = strconv.ParseBool(p) 352if err != nil { 353 return err 354} 355} else { 356 %s = %v 357} 358`, k, k, k, k, t.Default.(bool)) 359 } else { 360 361 paramtypes = append(paramtypes, k+" bool") 362 pf(` 363%s, err := strconv.ParseBool(c.QueryParam("%s")) 364if err != nil { 365 return err 366} 367`, k, k) 368 } 369 370 case "array": 371 if t.Items.Type != "string" { 372 return fmt.Errorf("currently only string arrays are supported in query params") 373 } 374 paramtypes = append(paramtypes, k+" []string") 375 params = append(params, k) 376 pf(` 377%s := c.QueryParams()["%s"] 378`, k, k) 379 380 default: 381 return fmt.Errorf("unsupported handler parameter type: %s", t.Type) 382 } 383 return nil 384 }); err != nil { 385 return err 386 } 387 } 388 } else if s.Type == "procedure" { 389 if s.Input != nil { 390 intname := impname + "." + tname + "_Input" 391 switch s.Input.Encoding { 392 case EncodingJSON: 393 pf(` 394var body %s 395if err := c.Bind(&body); err != nil { 396 return err 397} 398`, intname) 399 paramtypes = append(paramtypes, "body *"+intname) 400 params = append(params, "&body") 401 case EncodingCBOR: 402 pf("body := c.Request().Body\n") 403 paramtypes = append(paramtypes, "r io.Reader") 404 params = append(params, "body") 405 case EncodingANY: 406 pf("body := c.Request().Body\n") 407 pf("contentType := c.Request().Header.Get(\"Content-Type\")\n") 408 paramtypes = append(paramtypes, "r io.Reader", "contentType string") 409 params = append(params, "body", "contentType") 410 case EncodingMP4: 411 pf("body := c.Request().Body\n") 412 paramtypes = append(paramtypes, "r io.Reader") 413 params = append(params, "body") 414 default: 415 return fmt.Errorf("unrecognized input encoding: %q", s.Input.Encoding) 416 } 417 } 418 } else { 419 return fmt.Errorf("can only generate handlers for queries or procedures") 420 } 421 422 assign := "handleErr" 423 returndef := "error" 424 if s.Output != nil { 425 switch s.Output.Encoding { 426 case EncodingJSON: 427 assign = "out, handleErr" 428 outname := tname + "_Output" 429 if s.Output.Schema.Type == "ref" { 430 outname, _ = s.namesFromRef(s.Output.Schema.Ref) 431 } 432 pf("var out *%s.%s\n", impname, outname) 433 returndef = fmt.Sprintf("(*%s.%s, error)", impname, outname) 434 case EncodingCBOR, EncodingCAR, EncodingANY, EncodingJSONL, EncodingMP4: 435 assign = "out, handleErr" 436 pf("var out io.Reader\n") 437 returndef = "(io.Reader, error)" 438 default: 439 return fmt.Errorf("unrecognized output encoding (RPC output handler): %q", s.Output.Encoding) 440 } 441 } 442 pf("var handleErr error\n") 443 pf("// func (s *Server) handle%s(%s) %s\n", fname, strings.Join(paramtypes, ","), returndef) 444 pf("%s = s.handle%s(%s)\n", assign, fname, strings.Join(params, ",")) 445 pf("if handleErr != nil {\nreturn handleErr\n}\n") 446 447 if s.Output != nil { 448 switch s.Output.Encoding { 449 case EncodingJSON: 450 pf("return c.JSON(200, out)\n}\n\n") 451 case EncodingANY: 452 pf("return c.Stream(200, \"application/octet-stream\", out)\n}\n\n") 453 case EncodingCBOR: 454 pf("return c.Stream(200, \"application/octet-stream\", out)\n}\n\n") 455 case EncodingCAR: 456 pf("return c.Stream(200, \"application/vnd.ipld.car\", out)\n}\n\n") 457 case EncodingJSONL: 458 pf("return c.Stream(200, \"application/jsonl\", out)\n}\n\n") 459 case EncodingMP4: 460 pf("return c.Stream(200, \"video/mp4\", out)\n}\n\n") 461 default: 462 return fmt.Errorf("unrecognized output encoding (RPC output handler return): %q", s.Output.Encoding) 463 } 464 } else { 465 pf("return nil\n}\n\n") 466 } 467 468 return nil 469} 470 471func (s *TypeSchema) namesFromRef(r string) (string, string) { 472 ts, err := s.lookupRef(r) 473 if err != nil { 474 panic(err) 475 } 476 477 if ts.prefix == "" { 478 panic(fmt.Sprintf("no prefix for referenced type: %s", ts.id)) 479 } 480 481 if s.prefix == "" { 482 panic(fmt.Sprintf("no prefix for referencing type: %q %q", s.id, s.defName)) 483 } 484 485 // TODO: probably not technically correct, but i'm kinda over how lexicon 486 // tries to enforce application logic in a schema language 487 if ts.Type == "string" { 488 return "INVALID", "string" 489 } 490 491 var pkg string 492 if ts.prefix != s.prefix { 493 pkg = importNameForPrefix(ts.prefix) + "." 494 } 495 496 tname := pkg + ts.TypeName() 497 vname := tname 498 if strings.Contains(vname, ".") { 499 // Trim the package name from the variable name 500 vname = strings.Split(vname, ".")[1] 501 } 502 503 return vname, tname 504} 505 506func (s *TypeSchema) TypeName() string { 507 if s.id == "" { 508 panic("type schema hint fields not set") 509 } 510 if s.prefix == "" { 511 panic("why no prefix?") 512 } 513 n := nameFromID(s.id, s.prefix) 514 if s.defName != "main" { 515 n += "_" + strings.Title(s.defName) 516 } 517 518 if s.Type == "array" { 519 n = "[]" + n 520 521 if s.Items.Type == "union" { 522 n = n + "_Elem" 523 } 524 } 525 526 return n 527} 528 529// name: enclosing type name 530// k: field name 531// v: field TypeSchema 532func (s *TypeSchema) typeNameForField(name, k string, v TypeSchema) (string, error) { 533 switch v.Type { 534 case "string": 535 return "string", nil 536 case "float": 537 return "float64", nil 538 case "integer": 539 return "int64", nil 540 case "boolean": 541 return "bool", nil 542 case "object": 543 return "*" + name + "_" + strings.Title(k), nil 544 case "ref": 545 _, tn := s.namesFromRef(v.Ref) 546 if tn[0] == '[' { 547 return tn, nil 548 } 549 return "*" + tn, nil 550 case "datetime": 551 // TODO: maybe do a native type? 552 return "string", nil 553 case "unknown": 554 // NOTE: sometimes a record, for which we want LexiconTypeDecoder, sometimes any object 555 if k == "didDoc" || k == "plcOp" || k == "meta" { 556 return "interface{}", nil 557 } else { 558 return "*util.LexiconTypeDecoder", nil 559 } 560 case "union": 561 if len(v.Refs) > 0 { 562 return "*" + name + "_" + strings.Title(k), nil 563 } else { 564 // an empty union is effectively an 'unknown', but with mandatory type indicator 565 return "*util.LexiconTypeDecoder", nil 566 } 567 case "blob": 568 return "*util.LexBlob", nil 569 case "array": 570 subt, err := s.typeNameForField(name+"_"+strings.Title(k), "Elem", *v.Items) 571 if err != nil { 572 return "", err 573 } 574 575 return "[]" + subt, nil 576 case "cid-link": 577 return "util.LexLink", nil 578 case "bytes": 579 return "util.LexBytes", nil 580 default: 581 return "", fmt.Errorf("field %q in %s has unsupported type name (%s)", k, name, v.Type) 582 } 583} 584 585func (ts *TypeSchema) lookupRef(ref string) (*TypeSchema, error) { 586 fqref := ref 587 if strings.HasPrefix(ref, "#") { 588 fmt.Println("updating fqref: ", ts.id) 589 fqref = ts.id + ref 590 } 591 rr, ok := ts.defMap[fqref] 592 if !ok { 593 fmt.Println(ts.defMap) 594 panic(fmt.Sprintf("no such ref: %q", fqref)) 595 } 596 597 return rr.Type, nil 598} 599 600// name is the top level type name from outputType 601// WriteType is only called on a top level TypeSchema 602func (ts *TypeSchema) WriteType(name string, w io.Writer) error { 603 name = strings.Title(name) 604 if err := ts.writeTypeDefinition(name, w); err != nil { 605 return err 606 } 607 608 if err := ts.writeTypeMethods(name, w); err != nil { 609 return err 610 } 611 612 return nil 613} 614 615// name is the top level type name from outputType 616// writeTypeDefinition is not called recursively, but only on a top level TypeSchema 617func (ts *TypeSchema) writeTypeDefinition(name string, w io.Writer) error { 618 pf := printerf(w) 619 620 switch { 621 case strings.HasSuffix(name, "_Output"): 622 pf("// %s is the output of a %s call.\n", name, ts.id) 623 case strings.HasSuffix(name, "Input"): 624 pf("// %s is the input argument to a %s call.\n", name, ts.id) 625 case ts.defName != "": 626 pf("// %s is a %q in the %s schema.\n", name, ts.defName, ts.id) 627 } 628 if ts.Description != "" { 629 pf("//\n// %s\n", ts.Description) 630 } 631 632 switch ts.Type { 633 case "string": 634 // TODO: deal with max length 635 pf("type %s string\n", name) 636 case "float": 637 pf("type %s float64\n", name) 638 case "integer": 639 pf("type %s int64\n", name) 640 case "boolean": 641 pf("type %s bool\n", name) 642 case "object": 643 if ts.needsType { 644 pf("//\n// RECORDTYPE: %s\n", name) 645 } 646 647 pf("type %s struct {\n", name) 648 649 if ts.needsType { 650 var omit string 651 if ts.id == "com.atproto.repo.strongRef" { // TODO: hack 652 omit = ",omitempty" 653 } 654 cval := ts.id 655 if ts.defName != "" && ts.defName != "main" { 656 cval += "#" + ts.defName 657 } 658 pf("\tLexiconTypeID string `json:\"$type,const=%s%s\" cborgen:\"$type,const=%s%s\"`\n", cval, omit, cval, omit) 659 } else { 660 //pf("\tLexiconTypeID string `json:\"$type,omitempty\" cborgen:\"$type,omitempty\"`\n") 661 } 662 663 required := make(map[string]bool) 664 for _, req := range ts.Required { 665 required[req] = true 666 } 667 668 nullable := make(map[string]bool) 669 for _, req := range ts.Nullable { 670 nullable[req] = true 671 } 672 673 if err := orderedMapIter(ts.Properties, func(k string, v *TypeSchema) error { 674 goname := strings.Title(k) 675 676 tname, err := ts.typeNameForField(name, k, *v) 677 if err != nil { 678 return err 679 } 680 681 var ptr string 682 var omit string 683 if !required[k] { 684 omit = ",omitempty" 685 if !strings.HasPrefix(tname, "*") && !strings.HasPrefix(tname, "[]") { 686 ptr = "*" 687 } 688 } 689 if nullable[k] { 690 omit = "" 691 if !strings.HasPrefix(tname, "*") && !strings.HasPrefix(tname, "[]") { 692 ptr = "*" 693 } 694 } 695 696 jsonOmit, cborOmit := omit, omit 697 698 // Don't generate pointers to lexbytes, as it's already a pointer. 699 if ptr == "*" && tname == "util.LexBytes" { 700 ptr = "" 701 } 702 703 // TODO: hard-coded hacks for now, making this type (with underlying type []byte) 704 // be omitempty. 705 if ptr == "" && tname == "util.LexBytes" { 706 jsonOmit = ",omitempty" 707 cborOmit = ",omitempty" 708 } 709 710 if name == "LabelDefs_SelfLabels" && k == "values" { 711 // TODO: regularize weird hack? 712 cborOmit += ",preservenil" 713 } 714 715 if v.Description != "" { 716 pf("\t// %s: %s\n", k, v.Description) 717 } 718 pf("\t%s %s%s `json:\"%s%s\" cborgen:\"%s%s\"`\n", goname, ptr, tname, k, jsonOmit, k, cborOmit) 719 return nil 720 }); err != nil { 721 return err 722 } 723 724 pf("}\n\n") 725 726 case "array": 727 tname, err := ts.typeNameForField(name, "elem", *ts.Items) 728 if err != nil { 729 return err 730 } 731 732 pf("type %s []%s\n", name, tname) 733 734 case "union": 735 if len(ts.Refs) > 0 { 736 pf("type %s struct {\n", name) 737 for _, r := range ts.Refs { 738 vname, tname := ts.namesFromRef(r) 739 pf("\t%s *%s\n", vname, tname) 740 } 741 pf("}\n\n") 742 } 743 default: 744 return fmt.Errorf("%s has unrecognized type: %s", name, ts.Type) 745 } 746 747 return nil 748} 749 750func (ts *TypeSchema) writeTypeMethods(name string, w io.Writer) error { 751 switch ts.Type { 752 case "string", "float", "array", "boolean", "integer": 753 return nil 754 case "object": 755 if err := ts.writeJsonMarshalerObject(name, w); err != nil { 756 return err 757 } 758 759 if err := ts.writeJsonUnmarshalerObject(name, w); err != nil { 760 return err 761 } 762 763 return nil 764 case "union": 765 if len(ts.Refs) > 0 { 766 reft, err := ts.lookupRef(ts.Refs[0]) 767 if err != nil { 768 return err 769 } 770 771 if reft.Type == "string" { 772 return nil 773 } 774 775 if err := ts.writeJsonMarshalerEnum(name, w); err != nil { 776 return err 777 } 778 779 if err := ts.writeJsonUnmarshalerEnum(name, w); err != nil { 780 return err 781 } 782 783 if ts.needsCbor { 784 if err := ts.writeCborMarshalerEnum(name, w); err != nil { 785 return err 786 } 787 788 if err := ts.writeCborUnmarshalerEnum(name, w); err != nil { 789 return err 790 } 791 } 792 793 return nil 794 } 795 796 return fmt.Errorf("%q unsupported for marshaling", name) 797 default: 798 return fmt.Errorf("%q has unrecognized type: %s", name, ts.Type) 799 } 800} 801 802func (ts *TypeSchema) writeJsonMarshalerObject(name string, w io.Writer) error { 803 return nil // no need for a special json marshaler right now 804} 805 806func (ts *TypeSchema) writeJsonMarshalerEnum(name string, w io.Writer) error { 807 pf := printerf(w) 808 pf("func (t *%s) MarshalJSON() ([]byte, error) {\n", name) 809 810 for _, e := range ts.Refs { 811 vname, _ := ts.namesFromRef(e) 812 if strings.HasPrefix(e, "#") { 813 e = ts.id + e 814 } 815 816 pf("\tif t.%s != nil {\n", vname) 817 pf("\tt.%s.LexiconTypeID = %q\n", vname, e) 818 pf("\t\treturn json.Marshal(t.%s)\n\t}\n", vname) 819 } 820 821 pf("\treturn nil, fmt.Errorf(\"cannot marshal empty enum\")\n}\n") 822 return nil 823} 824 825func (s *TypeSchema) writeJsonUnmarshalerObject(name string, w io.Writer) error { 826 // TODO: would be nice to add some validation... 827 return nil 828 //pf("func (t *%s) UnmarshalJSON(b []byte) (error) {\n", name) 829} 830 831func (ts *TypeSchema) writeJsonUnmarshalerEnum(name string, w io.Writer) error { 832 pf := printerf(w) 833 pf("func (t *%s) UnmarshalJSON(b []byte) (error) {\n", name) 834 pf("\ttyp, err := util.TypeExtract(b)\n") 835 pf("\tif err != nil {\n\t\treturn err\n\t}\n\n") 836 pf("\tswitch typ {\n") 837 for _, e := range ts.Refs { 838 if strings.HasPrefix(e, "#") { 839 e = ts.id + e 840 } 841 842 vname, goname := ts.namesFromRef(e) 843 844 pf("\t\tcase \"%s\":\n", e) 845 pf("\t\t\tt.%s = new(%s)\n", vname, goname) 846 pf("\t\t\treturn json.Unmarshal(b, t.%s)\n", vname) 847 } 848 849 if ts.Closed { 850 pf(` 851 default: 852 return fmt.Errorf("closed enums must have a matching value") 853 `) 854 } else { 855 pf(` 856 default: 857 return nil 858 `) 859 860 } 861 862 pf("\t}\n") 863 pf("}\n\n") 864 865 return nil 866} 867 868func (ts *TypeSchema) writeCborMarshalerEnum(name string, w io.Writer) error { 869 pf := printerf(w) 870 pf("func (t *%s) MarshalCBOR(w io.Writer) error {\n", name) 871 pf(` 872 if t == nil { 873 _, err := w.Write(cbg.CborNull) 874 return err 875 } 876`) 877 878 for _, e := range ts.Refs { 879 vname, _ := ts.namesFromRef(e) 880 pf("\tif t.%s != nil {\n", vname) 881 pf("\t\treturn t.%s.MarshalCBOR(w)\n\t}\n", vname) 882 } 883 884 pf("\treturn fmt.Errorf(\"cannot cbor marshal empty enum\")\n}\n") 885 return nil 886} 887 888func (ts *TypeSchema) writeCborUnmarshalerEnum(name string, w io.Writer) error { 889 pf := printerf(w) 890 pf("func (t *%s) UnmarshalCBOR(r io.Reader) error {\n", name) 891 pf("\ttyp, b, err := util.CborTypeExtractReader(r)\n") 892 pf("\tif err != nil {\n\t\treturn err\n\t}\n\n") 893 pf("\tswitch typ {\n") 894 for _, e := range ts.Refs { 895 if strings.HasPrefix(e, "#") { 896 e = ts.id + e 897 } 898 899 vname, goname := ts.namesFromRef(e) 900 901 pf("\t\tcase \"%s\":\n", e) 902 pf("\t\t\tt.%s = new(%s)\n", vname, goname) 903 pf("\t\t\treturn t.%s.UnmarshalCBOR(bytes.NewReader(b))\n", vname) 904 } 905 906 if ts.Closed { 907 pf(` 908 default: 909 return fmt.Errorf("closed enums must have a matching value") 910 `) 911 } else { 912 pf(` 913 default: 914 return nil 915 `) 916 917 } 918 919 pf("\t}\n") 920 pf("}\n\n") 921 922 return nil 923}