forked from hailey.at/cocoon
An atproto PDS written in Go
1package server 2 3import ( 4 "bytes" 5 "context" 6 "encoding/json" 7 "fmt" 8 "io" 9 "time" 10 11 "github.com/Azure/go-autorest/autorest/to" 12 "github.com/bluesky-social/indigo/api/atproto" 13 "github.com/bluesky-social/indigo/atproto/atdata" 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "github.com/bluesky-social/indigo/carstore" 16 "github.com/bluesky-social/indigo/events" 17 lexutil "github.com/bluesky-social/indigo/lex/util" 18 "github.com/bluesky-social/indigo/repo" 19 "github.com/haileyok/cocoon/internal/db" 20 "github.com/haileyok/cocoon/metrics" 21 "github.com/haileyok/cocoon/models" 22 "github.com/haileyok/cocoon/recording_blockstore" 23 blocks "github.com/ipfs/go-block-format" 24 "github.com/ipfs/go-cid" 25 cbor "github.com/ipfs/go-ipld-cbor" 26 "github.com/ipld/go-car" 27 "gorm.io/gorm/clause" 28) 29 30type RepoMan struct { 31 db *db.DB 32 s *Server 33 clock *syntax.TIDClock 34} 35 36func NewRepoMan(s *Server) *RepoMan { 37 clock := syntax.NewTIDClock(0) 38 39 return &RepoMan{ 40 s: s, 41 db: s.db, 42 clock: &clock, 43 } 44} 45 46type OpType string 47 48var ( 49 OpTypeCreate = OpType("com.atproto.repo.applyWrites#create") 50 OpTypeUpdate = OpType("com.atproto.repo.applyWrites#update") 51 OpTypeDelete = OpType("com.atproto.repo.applyWrites#delete") 52) 53 54func (ot OpType) String() string { 55 return string(ot) 56} 57 58type Op struct { 59 Type OpType `json:"$type"` 60 Collection string `json:"collection"` 61 Rkey *string `json:"rkey,omitempty"` 62 Validate *bool `json:"validate,omitempty"` 63 SwapRecord *string `json:"swapRecord,omitempty"` 64 Record *MarshalableMap `json:"record,omitempty"` 65} 66 67type MarshalableMap map[string]any 68 69type FirehoseOp struct { 70 Cid cid.Cid 71 Path string 72 Action string 73} 74 75func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error { 76 data, err := atdata.MarshalCBOR(*mm) 77 if err != nil { 78 return err 79 } 80 81 w.Write(data) 82 83 return nil 84} 85 86type ApplyWriteResult struct { 87 Type *string `json:"$type,omitempty"` 88 Uri *string `json:"uri,omitempty"` 89 Cid *string `json:"cid,omitempty"` 90 Commit *RepoCommit `json:"commit,omitempty"` 91 ValidationStatus *string `json:"validationStatus,omitempty"` 92} 93 94type RepoCommit struct { 95 Cid string `json:"cid"` 96 Rev string `json:"rev"` 97} 98 99// TODO make use of swap commit 100func (rm *RepoMan) applyWrites(ctx context.Context, urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) { 101 rootcid, err := cid.Cast(urepo.Root) 102 if err != nil { 103 return nil, err 104 } 105 106 dbs := rm.s.getBlockstore(urepo.Did) 107 bs := recording_blockstore.New(dbs) 108 r, err := repo.OpenRepo(ctx, bs, rootcid) 109 110 var results []ApplyWriteResult 111 112 entries := make([]models.Record, 0, len(writes)) 113 for i, op := range writes { 114 // updates or deletes must supply an rkey 115 if op.Type != OpTypeCreate && op.Rkey == nil { 116 return nil, fmt.Errorf("invalid rkey") 117 } else if op.Type == OpTypeCreate && op.Rkey != nil { 118 // we should conver this op to an update if the rkey already exists 119 _, _, err := r.GetRecord(ctx, fmt.Sprintf("%s/%s", op.Collection, *op.Rkey)) 120 if err == nil { 121 op.Type = OpTypeUpdate 122 } 123 } else if op.Rkey == nil { 124 // creates that don't supply an rkey will have one generated for them 125 op.Rkey = to.StringPtr(rm.clock.Next().String()) 126 writes[i].Rkey = op.Rkey 127 } 128 129 // validate the record key is actually valid 130 _, err := syntax.ParseRecordKey(*op.Rkey) 131 if err != nil { 132 return nil, err 133 } 134 135 switch op.Type { 136 case OpTypeCreate: 137 // HACK: this fixes some type conversions, mainly around integers 138 // first we convert to json bytes 139 b, err := json.Marshal(*op.Record) 140 if err != nil { 141 return nil, err 142 } 143 // then we use atdata.UnmarshalJSON to convert it back to a map 144 out, err := atdata.UnmarshalJSON(b) 145 if err != nil { 146 return nil, err 147 } 148 // finally we can cast to a MarshalableMap 149 mm := MarshalableMap(out) 150 151 // HACK: if a record doesn't contain a $type, we can manually set it here based on the op's collection 152 // i forget why this is actually necessary? 153 if mm["$type"] == "" { 154 mm["$type"] = op.Collection 155 } 156 157 nc, err := r.PutRecord(ctx, fmt.Sprintf("%s/%s", op.Collection, *op.Rkey), &mm) 158 if err != nil { 159 return nil, err 160 } 161 162 d, err := atdata.MarshalCBOR(mm) 163 if err != nil { 164 return nil, err 165 } 166 167 entries = append(entries, models.Record{ 168 Did: urepo.Did, 169 CreatedAt: rm.clock.Next().String(), 170 Nsid: op.Collection, 171 Rkey: *op.Rkey, 172 Cid: nc.String(), 173 Value: d, 174 }) 175 176 results = append(results, ApplyWriteResult{ 177 Type: to.StringPtr(OpTypeCreate.String()), 178 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey), 179 Cid: to.StringPtr(nc.String()), 180 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol 181 }) 182 case OpTypeDelete: 183 // try to find the old record in the database 184 var old models.Record 185 if err := rm.db.Raw(ctx, "SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil { 186 return nil, err 187 } 188 189 // TODO: this is really confusing, and looking at it i have no idea why i did this. below when we are doing deletes, we 190 // check if `cid` here is nil to indicate if we should delete. that really doesn't make much sense and its super illogical 191 // when reading this code. i dont feel like fixing right now though so 192 entries = append(entries, models.Record{ 193 Did: urepo.Did, 194 Nsid: op.Collection, 195 Rkey: *op.Rkey, 196 Value: old.Value, 197 }) 198 199 // delete the record from the repo 200 err := r.DeleteRecord(ctx, fmt.Sprintf("%s/%s", op.Collection, *op.Rkey)) 201 if err != nil { 202 return nil, err 203 } 204 205 // add a result for the delete 206 results = append(results, ApplyWriteResult{ 207 Type: to.StringPtr(OpTypeDelete.String()), 208 }) 209 case OpTypeUpdate: 210 // HACK: same hack as above for type fixes 211 b, err := json.Marshal(*op.Record) 212 if err != nil { 213 return nil, err 214 } 215 out, err := atdata.UnmarshalJSON(b) 216 if err != nil { 217 return nil, err 218 } 219 mm := MarshalableMap(out) 220 221 nc, err := r.UpdateRecord(ctx, fmt.Sprintf("%s/%s", op.Collection, *op.Rkey), &mm) 222 if err != nil { 223 return nil, err 224 } 225 226 d, err := atdata.MarshalCBOR(mm) 227 if err != nil { 228 return nil, err 229 } 230 231 entries = append(entries, models.Record{ 232 Did: urepo.Did, 233 CreatedAt: rm.clock.Next().String(), 234 Nsid: op.Collection, 235 Rkey: *op.Rkey, 236 Cid: nc.String(), 237 Value: d, 238 }) 239 240 results = append(results, ApplyWriteResult{ 241 Type: to.StringPtr(OpTypeUpdate.String()), 242 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey), 243 Cid: to.StringPtr(nc.String()), 244 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol 245 }) 246 } 247 } 248 249 // commit and get the new root 250 newroot, rev, err := r.Commit(ctx, urepo.SignFor) 251 if err != nil { 252 return nil, err 253 } 254 255 for _, result := range results { 256 if result.Type != nil { 257 metrics.RepoOperations.WithLabelValues(*result.Type).Inc() 258 } 259 } 260 261 // create a buffer for dumping our new cbor into 262 buf := new(bytes.Buffer) 263 264 // first write the car header to the buffer 265 hb, err := cbor.DumpObject(&car.CarHeader{ 266 Roots: []cid.Cid{newroot}, 267 Version: 1, 268 }) 269 if _, err := carstore.LdWrite(buf, hb); err != nil { 270 return nil, err 271 } 272 273 // get a diff of the changes to the repo 274 diffops, err := r.DiffSince(ctx, rootcid) 275 if err != nil { 276 return nil, err 277 } 278 279 // create the repo ops for the given diff 280 ops := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(diffops)) 281 for _, op := range diffops { 282 var c cid.Cid 283 switch op.Op { 284 case "add", "mut": 285 kind := "create" 286 if op.Op == "mut" { 287 kind = "update" 288 } 289 290 c = op.NewCid 291 ll := lexutil.LexLink(op.NewCid) 292 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{ 293 Action: kind, 294 Path: op.Rpath, 295 Cid: &ll, 296 }) 297 298 case "del": 299 c = op.OldCid 300 ll := lexutil.LexLink(op.OldCid) 301 ops = append(ops, &atproto.SyncSubscribeRepos_RepoOp{ 302 Action: "delete", 303 Path: op.Rpath, 304 Cid: nil, 305 Prev: &ll, 306 }) 307 } 308 309 blk, err := dbs.Get(ctx, c) 310 if err != nil { 311 return nil, err 312 } 313 314 // write the block to the buffer 315 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil { 316 return nil, err 317 } 318 } 319 320 // write the writelog to the buffer 321 for _, op := range bs.GetWriteLog() { 322 if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil { 323 return nil, err 324 } 325 } 326 327 // blob blob blob blob blob :3 328 var blobs []lexutil.LexLink 329 for _, entry := range entries { 330 var cids []cid.Cid 331 // whenever there is cid present, we know it's a create (dumb) 332 if entry.Cid != "" { 333 if err := rm.s.db.Create(ctx, &entry, []clause.Expression{clause.OnConflict{ 334 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}}, 335 UpdateAll: true, 336 }}).Error; err != nil { 337 return nil, err 338 } 339 340 // increment the given blob refs, yay 341 cids, err = rm.incrementBlobRefs(ctx, urepo, entry.Value) 342 if err != nil { 343 return nil, err 344 } 345 } else { 346 // as i noted above this is dumb. but we delete whenever the cid is nil. it works solely becaue the pkey 347 // is did + collection + rkey. i still really want to separate that out, or use a different type to make 348 // this less confusing/easy to read. alas, its 2 am and yea no 349 if err := rm.s.db.Delete(ctx, &entry, nil).Error; err != nil { 350 return nil, err 351 } 352 353 // TODO: 354 cids, err = rm.decrementBlobRefs(ctx, urepo, entry.Value) 355 if err != nil { 356 return nil, err 357 } 358 } 359 360 // add all the relevant blobs to the blobs list of blobs. blob ^.^ 361 for _, c := range cids { 362 blobs = append(blobs, lexutil.LexLink(c)) 363 } 364 } 365 366 // NOTE: using the request ctx seems a bit suss here, so using a background context. i'm not sure if this 367 // runs sync or not 368 rm.s.evtman.AddEvent(context.Background(), &events.XRPCStreamEvent{ 369 RepoCommit: &atproto.SyncSubscribeRepos_Commit{ 370 Repo: urepo.Did, 371 Blocks: buf.Bytes(), 372 Blobs: blobs, 373 Rev: rev, 374 Since: &urepo.Rev, 375 Commit: lexutil.LexLink(newroot), 376 Time: time.Now().Format(time.RFC3339Nano), 377 Ops: ops, 378 TooBig: false, 379 }, 380 }) 381 382 if err := rm.s.UpdateRepo(ctx, urepo.Did, newroot, rev); err != nil { 383 return nil, err 384 } 385 386 for i := range results { 387 results[i].Type = to.StringPtr(*results[i].Type + "Result") 388 results[i].Commit = &RepoCommit{ 389 Cid: newroot.String(), 390 Rev: rev, 391 } 392 } 393 394 return results, nil 395} 396 397// this is a fun little guy. to get a proof, we need to read the record out of the blockstore and record how we actually 398// got to the guy. we'll wrap a new blockstore in a recording blockstore, then return the log for proof 399func (rm *RepoMan) getRecordProof(ctx context.Context, urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) { 400 c, err := cid.Cast(urepo.Root) 401 if err != nil { 402 return cid.Undef, nil, err 403 } 404 405 dbs := rm.s.getBlockstore(urepo.Did) 406 bs := recording_blockstore.New(dbs) 407 408 r, err := repo.OpenRepo(ctx, bs, c) 409 if err != nil { 410 return cid.Undef, nil, err 411 } 412 413 _, _, err = r.GetRecordBytes(ctx, fmt.Sprintf("%s/%s", collection, rkey)) 414 if err != nil { 415 return cid.Undef, nil, err 416 } 417 418 return c, bs.GetReadLog(), nil 419} 420 421func (rm *RepoMan) incrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 422 cids, err := getBlobCidsFromCbor(cbor) 423 if err != nil { 424 return nil, err 425 } 426 427 for _, c := range cids { 428 if err := rm.db.Exec(ctx, "UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil { 429 return nil, err 430 } 431 } 432 433 return cids, nil 434} 435 436func (rm *RepoMan) decrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 437 cids, err := getBlobCidsFromCbor(cbor) 438 if err != nil { 439 return nil, err 440 } 441 442 for _, c := range cids { 443 var res struct { 444 ID uint 445 Count int 446 } 447 if err := rm.db.Raw(ctx, "UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil { 448 return nil, err 449 } 450 451 // TODO: this does _not_ handle deletions of blobs that are on s3 storage!!!! we need to get the blob, see what 452 // storage it is in, and clean up s3!!!! 453 if res.Count == 0 { 454 if err := rm.db.Exec(ctx, "DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil { 455 return nil, err 456 } 457 if err := rm.db.Exec(ctx, "DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil { 458 return nil, err 459 } 460 } 461 } 462 463 return cids, nil 464} 465 466// to be honest, we could just store both the cbor and non-cbor in []entries above to avoid an additional 467// unmarshal here. this will work for now though 468func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) { 469 var cids []cid.Cid 470 471 decoded, err := atdata.UnmarshalCBOR(cbor) 472 if err != nil { 473 return nil, fmt.Errorf("error unmarshaling cbor: %w", err) 474 } 475 476 var deepiter func(any) error 477 deepiter = func(item any) error { 478 switch val := item.(type) { 479 case map[string]any: 480 if val["$type"] == "blob" { 481 if ref, ok := val["ref"].(string); ok { 482 c, err := cid.Parse(ref) 483 if err != nil { 484 return err 485 } 486 cids = append(cids, c) 487 } 488 for _, v := range val { 489 return deepiter(v) 490 } 491 } 492 case []any: 493 for _, v := range val { 494 deepiter(v) 495 } 496 } 497 498 return nil 499 } 500 501 if err := deepiter(decoded); err != nil { 502 return nil, err 503 } 504 505 return cids, nil 506}