Margin is an open annotation layer for the internet. Powered by the AT Protocol.
margin.at
extension
web
atproto
comments
1package crypto
2
import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"sort"
	"strings"

	"github.com/fxamacker/cbor/v2"
	"github.com/ipfs/go-cid"
	"github.com/multiformats/go-multihash"
)
14
15const (
16 DagCBORCodec = 0x71
17 SHA256Code = multihash.SHA2_256
18)
19
// CIDVerificationError reports that a record's recomputed CID differs from
// the CID it was published with.
type CIDVerificationError struct {
	ExpectedCID string // CID string supplied by the caller
	ComputedCID string // CID computed from the record's DAG-CBOR encoding
	RecordURI   string // identifier of the record, used only in the message
}

// Error implements the error interface, naming the record and both CIDs.
func (e *CIDVerificationError) Error() string {
	uri, want, got := e.RecordURI, e.ExpectedCID, e.ComputedCID
	return fmt.Sprintf("CID verification failed for %s: expected %s, computed %s", uri, want, got)
}
30
31func VerifyRecordCID(recordJSON json.RawMessage, expectedCID string, recordURI string) error {
32 if expectedCID == "" {
33 return nil
34 }
35
36 expectedC, err := cid.Decode(expectedCID)
37 if err != nil {
38 return fmt.Errorf("invalid CID format: %w", err)
39 }
40
41 cborBytes, err := jsonToDAGCBOR(recordJSON)
42 if err != nil {
43 return fmt.Errorf("failed to encode as DAG-CBOR: %w", err)
44 }
45
46 mh, err := multihash.Sum(cborBytes, SHA256Code, -1)
47 if err != nil {
48 return fmt.Errorf("failed to compute hash: %w", err)
49 }
50
51 computedC := cid.NewCidV1(DagCBORCodec, mh)
52
53 if !expectedC.Equals(computedC) {
54 return &CIDVerificationError{
55 ExpectedCID: expectedCID,
56 ComputedCID: computedC.String(),
57 RecordURI: recordURI,
58 }
59 }
60
61 return nil
62}
63
64func jsonToDAGCBOR(jsonData json.RawMessage) ([]byte, error) {
65 var data interface{}
66 if err := json.Unmarshal(jsonData, &data); err != nil {
67 return nil, err
68 }
69
70 processed := processValue(data)
71
72 encMode, err := cbor.CanonicalEncOptions().EncMode()
73 if err != nil {
74 return nil, err
75 }
76
77 return encMode.Marshal(processed)
78}
79
80func processValue(v interface{}) interface{} {
81 switch val := v.(type) {
82 case map[string]interface{}:
83 return processMap(val)
84 case []interface{}:
85 result := make([]interface{}, len(val))
86 for i, item := range val {
87 result[i] = processValue(item)
88 }
89 return result
90 case float64:
91 if val == float64(int64(val)) {
92 return int64(val)
93 }
94 return val
95 case string:
96 return val
97 default:
98 return val
99 }
100}
101
102func processMap(m map[string]interface{}) interface{} {
103 if link, ok := m["$link"].(string); ok && len(m) == 1 {
104 c, err := cid.Decode(link)
105 if err == nil {
106 return cbor.Tag{
107 Number: 42,
108 Content: append([]byte{0x00}, c.Bytes()...),
109 }
110 }
111 }
112
113 if bytesStr, ok := m["$bytes"].(string); ok && len(m) == 1 {
114 bytesStr = strings.TrimRight(bytesStr, "=")
115 decoded := decodeBase64(bytesStr)
116 if decoded != nil {
117 return decoded
118 }
119 }
120
121 keys := make([]string, 0, len(m))
122 for k := range m {
123 keys = append(keys, k)
124 }
125 sort.Strings(keys)
126
127 result := make(map[string]interface{}, len(m))
128 for _, k := range keys {
129 result[k] = processValue(m[k])
130 }
131
132 return result
133}
134
// decodeBase64 decodes the base64 payload of an atproto {"$bytes": ...} value.
//
// Padding is optional: trailing '=' characters are ignored. The standard
// alphabet is tried first, then the URL-safe one; a single string must use
// one alphabet consistently.
//
// Malformed input yields nil, which callers treat as "not decodable". This
// covers an invalid character, '=' anywhere but the end, or a dangling final
// character (an unpadded length of the form 4k+1). The previous hand-rolled
// decoder silently dropped that dangling character.
//
// As with encoding/base64 in general, embedded '\r' and '\n' are skipped.
// An empty s yields a non-nil empty slice.
func decodeBase64(s string) []byte {
	s = strings.TrimRight(s, "=")
	if decoded, err := base64.RawStdEncoding.DecodeString(s); err == nil {
		return decoded
	}
	if decoded, err := base64.RawURLEncoding.DecodeString(s); err == nil {
		return decoded
	}
	return nil
}
184
185func VerifyRecordCIDBatch(records []struct {
186 JSON json.RawMessage
187 CID string
188 URI string
189}) []error {
190 var errors []error
191 for _, r := range records {
192 if err := VerifyRecordCID(r.JSON, r.CID, r.URI); err != nil {
193 errors = append(errors, err)
194 }
195 }
196 return errors
197}
198
199func MustVerifyRecordCID(recordJSON json.RawMessage, expectedCID string, recordURI string) bool {
200 return VerifyRecordCID(recordJSON, expectedCID, recordURI) == nil
201}
202
203func ComputeRecordCID(recordJSON json.RawMessage) (string, error) {
204 cborBytes, err := jsonToDAGCBOR(recordJSON)
205 if err != nil {
206 return "", fmt.Errorf("failed to encode as DAG-CBOR: %w", err)
207 }
208
209 mh, err := multihash.Sum(cborBytes, SHA256Code, -1)
210 if err != nil {
211 return "", fmt.Errorf("failed to compute hash: %w", err)
212 }
213
214 c := cid.NewCidV1(DagCBORCodec, mh)
215 return c.String(), nil
216}
217
218func CompareRecordBytes(a, b json.RawMessage) (bool, error) {
219 cborA, err := jsonToDAGCBOR(a)
220 if err != nil {
221 return false, err
222 }
223 cborB, err := jsonToDAGCBOR(b)
224 if err != nil {
225 return false, err
226 }
227 return bytes.Equal(cborA, cborB), nil
228}