1package util
2
3import (
4 "bytes"
5 "encoding/json"
6 "fmt"
7 "reflect"
8 "strings"
9
10 cbg "github.com/whyrusleeping/cbor-gen"
11)
12
13var lexTypesMap map[string]reflect.Type
14
15func init() {
16 lexTypesMap = make(map[string]reflect.Type)
17 RegisterType("blob", &LexBlob{})
18}
19
20func RegisterType(id string, val cbg.CBORMarshaler) {
21 t := reflect.TypeOf(val)
22
23 if t.Kind() == reflect.Pointer {
24 t = t.Elem()
25 }
26
27 if _, ok := lexTypesMap[id]; ok {
28 panic(fmt.Sprintf("already registered type for %q", id))
29 }
30
31 lexTypesMap[id] = t
32}
33
34func NewFromType(typ string) (interface{}, error) {
35 t, ok := lexTypesMap[typ]
36 if !ok {
37 return nil, fmt.Errorf("%w: %q", ErrUnrecognizedType, typ)
38 }
39 v := reflect.New(t)
40 return v.Interface(), nil
41}
42
43func JsonDecodeValue(b []byte) (any, error) {
44 tstr, err := TypeExtract(b)
45 if err != nil {
46 return nil, err
47 }
48
49 t, ok := lexTypesMap[tstr]
50 if !ok {
51 return nil, fmt.Errorf("%w: %q", ErrUnrecognizedType, tstr)
52 }
53
54 val := reflect.New(t)
55
56 ival := val.Interface()
57 if err := json.Unmarshal(b, ival); err != nil {
58 return nil, err
59 }
60
61 return ival, nil
62}
63
// CBOR is the combination of cbor-gen's unmarshaler and marshaler interfaces.
// It is the type returned by CborDecodeValue, which requires the registered
// type to satisfy both halves.
type CBOR interface {
	cbg.CBORUnmarshaler
	cbg.CBORMarshaler
}
68
// ErrUnrecognizedType is the sentinel error wrapped (via %w) by NewFromType,
// JsonDecodeValue and CborDecodeValue when the requested type identifier is
// not present in the registry. Callers can detect it with errors.Is.
var ErrUnrecognizedType = fmt.Errorf("unrecognized lexicon type")
70
71func CborDecodeValue(b []byte) (CBOR, error) {
72 tstr, err := CborTypeExtract(b)
73 if err != nil {
74 return nil, fmt.Errorf("cbor type extract: %w", err)
75 }
76
77 t, ok := lexTypesMap[tstr]
78 if !ok {
79 return nil, fmt.Errorf("%w: %q", ErrUnrecognizedType, tstr)
80 }
81
82 val := reflect.New(t)
83
84 ival, ok := val.Interface().(CBOR)
85 if !ok {
86 return nil, fmt.Errorf("registered type did not have proper cbor hooks")
87 }
88
89 if err := ival.UnmarshalCBOR(bytes.NewReader(b)); err != nil {
90 return nil, err
91 }
92
93 return ival, nil
94}
95
// LexiconTypeDecoder wraps a lexicon value whose concrete Go type is chosen
// at decode time. Its UnmarshalJSON fills Val using JsonDecodeValue, and its
// MarshalJSON sets Val's LexiconTypeID field before encoding Val as JSON.
type LexiconTypeDecoder struct {
	// Val is the wrapped value; it is nil until set or decoded.
	Val cbg.CBORMarshaler
}
99
100func (ltd *LexiconTypeDecoder) UnmarshalJSON(b []byte) error {
101 val, err := JsonDecodeValue(b)
102 if err != nil {
103 return err
104 }
105
106 ltd.Val = val.(cbg.CBORMarshaler)
107
108 return nil
109}
110
111func (ltd *LexiconTypeDecoder) MarshalJSON() ([]byte, error) {
112 if ltd == nil || ltd.Val == nil {
113 return nil, fmt.Errorf("LexiconTypeDecoder MarshalJSON called on a nil")
114 }
115 v := reflect.ValueOf(ltd.Val)
116 t := v.Type()
117 sf, ok := t.Elem().FieldByName("LexiconTypeID")
118 if !ok {
119 return nil, fmt.Errorf("lexicon type decoder can only handle record fields")
120 }
121
122 tag, ok := sf.Tag.Lookup("cborgen")
123 if !ok {
124 return nil, fmt.Errorf("lexicon type decoder can only handle record fields with const $type")
125 }
126
127 parts := strings.Split(tag, ",")
128
129 var cval string
130 for _, p := range parts {
131 if strings.HasPrefix(p, "const=") {
132 cval = strings.TrimPrefix(p, "const=")
133 break
134 }
135 }
136 if cval == "" {
137 return nil, fmt.Errorf("must have const $type field")
138 }
139
140 v.Elem().FieldByName("LexiconTypeID").SetString(cval)
141
142 return json.Marshal(ltd.Val)
143}