Margin is an open annotation layer for the internet. Powered by the AT Protocol. margin.at
extension web atproto comments
at main 228 lines 4.6 kB view raw
1package crypto 2 3import ( 4 "bytes" 5 "encoding/json" 6 "fmt" 7 "sort" 8 "strings" 9 10 "github.com/fxamacker/cbor/v2" 11 "github.com/ipfs/go-cid" 12 "github.com/multiformats/go-multihash" 13) 14 15const ( 16 DagCBORCodec = 0x71 17 SHA256Code = multihash.SHA2_256 18) 19 20type CIDVerificationError struct { 21 ExpectedCID string 22 ComputedCID string 23 RecordURI string 24} 25 26func (e *CIDVerificationError) Error() string { 27 return fmt.Sprintf("CID verification failed for %s: expected %s, computed %s", 28 e.RecordURI, e.ExpectedCID, e.ComputedCID) 29} 30 31func VerifyRecordCID(recordJSON json.RawMessage, expectedCID string, recordURI string) error { 32 if expectedCID == "" { 33 return nil 34 } 35 36 expectedC, err := cid.Decode(expectedCID) 37 if err != nil { 38 return fmt.Errorf("invalid CID format: %w", err) 39 } 40 41 cborBytes, err := jsonToDAGCBOR(recordJSON) 42 if err != nil { 43 return fmt.Errorf("failed to encode as DAG-CBOR: %w", err) 44 } 45 46 mh, err := multihash.Sum(cborBytes, SHA256Code, -1) 47 if err != nil { 48 return fmt.Errorf("failed to compute hash: %w", err) 49 } 50 51 computedC := cid.NewCidV1(DagCBORCodec, mh) 52 53 if !expectedC.Equals(computedC) { 54 return &CIDVerificationError{ 55 ExpectedCID: expectedCID, 56 ComputedCID: computedC.String(), 57 RecordURI: recordURI, 58 } 59 } 60 61 return nil 62} 63 64func jsonToDAGCBOR(jsonData json.RawMessage) ([]byte, error) { 65 var data interface{} 66 if err := json.Unmarshal(jsonData, &data); err != nil { 67 return nil, err 68 } 69 70 processed := processValue(data) 71 72 encMode, err := cbor.CanonicalEncOptions().EncMode() 73 if err != nil { 74 return nil, err 75 } 76 77 return encMode.Marshal(processed) 78} 79 80func processValue(v interface{}) interface{} { 81 switch val := v.(type) { 82 case map[string]interface{}: 83 return processMap(val) 84 case []interface{}: 85 result := make([]interface{}, len(val)) 86 for i, item := range val { 87 result[i] = processValue(item) 88 } 89 return result 90 case float64: 91 if val == float64(int64(val)) { 92 return int64(val) 93 } 94 return val 95 case string: 96 return val 97 default: 98 return val 99 } 100} 101 102func processMap(m map[string]interface{}) interface{} { 103 if link, ok := m["$link"].(string); ok && len(m) == 1 { 104 c, err := cid.Decode(link) 105 if err == nil { 106 return cbor.Tag{ 107 Number: 42, 108 Content: append([]byte{0x00}, c.Bytes()...), 109 } 110 } 111 } 112 113 if bytesStr, ok := m["$bytes"].(string); ok && len(m) == 1 { 114 bytesStr = strings.TrimRight(bytesStr, "=") 115 decoded := decodeBase64(bytesStr) 116 if decoded != nil { 117 return decoded 118 } 119 } 120 121 keys := make([]string, 0, len(m)) 122 for k := range m { 123 keys = append(keys, k) 124 } 125 sort.Strings(keys) 126 127 result := make(map[string]interface{}, len(m)) 128 for _, k := range keys { 129 result[k] = processValue(m[k]) 130 } 131 132 return result 133} 134 135func decodeBase64(s string) []byte { 136 switch len(s) % 4 { 137 case 2: 138 s += "==" 139 case 3: 140 s += "=" 141 } 142 143 decoded := make([]byte, len(s)) 144 n := 0 145 for i := 0; i < len(s); i += 4 { 146 if i+4 > len(s) { 147 break 148 } 149 chunk := s[i : i+4] 150 val := uint32(0) 151 for _, c := range chunk { 152 var v byte 153 switch { 154 case c >= 'A' && c <= 'Z': 155 v = byte(c - 'A') 156 case c >= 'a' && c <= 'z': 157 v = byte(c - 'a' + 26) 158 case c >= '0' && c <= '9': 159 v = byte(c - '0' + 52) 160 case c == '+' || c == '-': 161 v = 62 162 case c == '/' || c == '_': 163 v = 63 164 case c == '=': 165 v = 0 166 default: 167 return nil 168 } 169 val = val<<6 | uint32(v) 170 } 171 decoded[n] = byte(val >> 16) 172 n++ 173 if chunk[2] != '=' { 174 decoded[n] = byte(val >> 8) 175 n++ 176 } 177 if chunk[3] != '=' { 178 decoded[n] = byte(val) 179 n++ 180 } 181 } 182 return decoded[:n] 183} 184 185func VerifyRecordCIDBatch(records []struct { 186 JSON json.RawMessage 187 CID string 188 URI string 189}) []error { 190 var errors []error 191 for _, r := range records { 192 if err := VerifyRecordCID(r.JSON, r.CID, r.URI); err != nil { 193 errors = append(errors, err) 194 } 195 } 196 return errors 197} 198 199func MustVerifyRecordCID(recordJSON json.RawMessage, expectedCID string, recordURI string) bool { 200 return VerifyRecordCID(recordJSON, expectedCID, recordURI) == nil 201} 202 203func ComputeRecordCID(recordJSON json.RawMessage) (string, error) { 204 cborBytes, err := jsonToDAGCBOR(recordJSON) 205 if err != nil { 206 return "", fmt.Errorf("failed to encode as DAG-CBOR: %w", err) 207 } 208 209 mh, err := multihash.Sum(cborBytes, SHA256Code, -1) 210 if err != nil { 211 return "", fmt.Errorf("failed to compute hash: %w", err) 212 } 213 214 c := cid.NewCidV1(DagCBORCodec, mh) 215 return c.String(), nil 216} 217 218func CompareRecordBytes(a, b json.RawMessage) (bool, error) { 219 cborA, err := jsonToDAGCBOR(a) 220 if err != nil { 221 return false, err 222 } 223 cborB, err := jsonToDAGCBOR(b) 224 if err != nil { 225 return false, err 226 } 227 return bytes.Equal(cborA, cborB), nil 228}