// Forked from hailey.at/cocoon.
// An atproto PDS written in Go.
1package server
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "fmt"
8 "io"
9 "sync"
10 "time"
11
12 "github.com/Azure/go-autorest/autorest/to"
13 "github.com/bluesky-social/indigo/api/atproto"
14 "github.com/bluesky-social/indigo/atproto/atcrypto"
15 "github.com/bluesky-social/indigo/atproto/atdata"
16 atp "github.com/bluesky-social/indigo/atproto/repo"
17 "github.com/bluesky-social/indigo/atproto/repo/mst"
18 "github.com/bluesky-social/indigo/atproto/syntax"
19 "github.com/bluesky-social/indigo/carstore"
20 "github.com/bluesky-social/indigo/events"
21 lexutil "github.com/bluesky-social/indigo/lex/util"
22 "github.com/haileyok/cocoon/internal/db"
23 "github.com/haileyok/cocoon/metrics"
24 "github.com/haileyok/cocoon/models"
25 "github.com/haileyok/cocoon/recording_blockstore"
26 blocks "github.com/ipfs/go-block-format"
27 "github.com/ipfs/go-cid"
28 blockstore "github.com/ipfs/go-ipfs-blockstore"
29 cbor "github.com/ipfs/go-ipld-cbor"
30 "github.com/ipld/go-car"
31 "github.com/multiformats/go-multihash"
32 "gorm.io/gorm/clause"
33)
34
// cachedRepo is a per-DID cache entry: an open in-memory repo plus the
// commit root it was loaded from. mu serializes all work on a single
// DID's repo so concurrent writes cannot interleave on the same MST.
type cachedRepo struct {
	mu   sync.Mutex
	repo *atp.Repo
	root cid.Cid
}
40
// RepoMan manages atproto repositories: applying record writes, committing
// signed MST roots, tracking blob references, and emitting firehose events
// through the owning Server.
type RepoMan struct {
	db    *db.DB
	s     *Server
	clock *syntax.TIDClock // monotonic TID source for generated rkeys and timestamps

	// cacheMu guards the cache map itself; each cachedRepo carries its own
	// lock for operations on that DID's repo.
	cacheMu sync.Mutex
	cache   map[string]*cachedRepo
}
49
50func NewRepoMan(s *Server) *RepoMan {
51 clock := syntax.NewTIDClock(0)
52
53 return &RepoMan{
54 s: s,
55 db: s.db,
56 clock: clock,
57 cache: make(map[string]*cachedRepo),
58 }
59}
60
61func (rm *RepoMan) withRepo(ctx context.Context, did string, rootCid cid.Cid, fn func(r *atp.Repo) (newRoot cid.Cid, err error)) error {
62 rm.cacheMu.Lock()
63 cr, ok := rm.cache[did]
64 if !ok {
65 cr = &cachedRepo{}
66 rm.cache[did] = cr
67 }
68 rm.cacheMu.Unlock()
69
70 cr.mu.Lock()
71 defer cr.mu.Unlock()
72
73 if cr.repo == nil || cr.root != rootCid {
74 bs := rm.s.getBlockstore(did)
75 r, err := openRepo(ctx, bs, rootCid, did)
76 if err != nil {
77 return err
78 }
79 cr.repo = r
80 cr.root = rootCid
81 }
82
83 newRoot, err := fn(cr.repo)
84 if err != nil {
85 // invalidate on error since the tree may be partially mutated
86 cr.repo = nil
87 cr.root = cid.Undef
88 return err
89 }
90
91 cr.root = newRoot
92 return nil
93}
94
// OpType identifies an applyWrites operation variant, using the fully
// qualified lexicon union type names.
type OpType string

// These are fixed lexicon identifiers; declare them as constants rather
// than mutable package-level vars.
const (
	OpTypeCreate OpType = "com.atproto.repo.applyWrites#create"
	OpTypeUpdate OpType = "com.atproto.repo.applyWrites#update"
	OpTypeDelete OpType = "com.atproto.repo.applyWrites#delete"
)

// String returns the lexicon type name of the op.
func (ot OpType) String() string {
	return string(ot)
}
106
// Op is a single write in a com.atproto.repo.applyWrites request.
type Op struct {
	Type       OpType          `json:"$type"`
	Collection string          `json:"collection"`
	Rkey       *string         `json:"rkey,omitempty"`       // required for update/delete; generated for create when absent
	Validate   *bool           `json:"validate,omitempty"`   // NOTE(review): not enforced by applyWrites in this file
	SwapRecord *string         `json:"swapRecord,omitempty"` // NOTE(review): not enforced by applyWrites in this file
	Record     *MarshalableMap `json:"record,omitempty"`     // record body for create/update
}
115
116type MarshalableMap map[string]any
117
// FirehoseOp describes a single repo mutation.
// NOTE(review): appears unused within this file — confirm callers before
// removing or changing.
type FirehoseOp struct {
	Cid    cid.Cid
	Path   string
	Action string
}
123
124func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error {
125 data, err := atdata.MarshalCBOR(*mm)
126 if err != nil {
127 return err
128 }
129
130 w.Write(data)
131
132 return nil
133}
134
// ApplyWriteResult is the per-op result returned by applyWrites, mirroring
// the com.atproto.repo.applyWrites result union.
type ApplyWriteResult struct {
	Type             *string     `json:"$type,omitempty"`
	Uri              *string     `json:"uri,omitempty"`
	Cid              *string     `json:"cid,omitempty"`
	Commit           *RepoCommit `json:"commit,omitempty"`
	ValidationStatus *string     `json:"validationStatus,omitempty"`
}
142
// RepoCommit identifies a signed commit by CID and revision.
type RepoCommit struct {
	Cid string `json:"cid"`
	Rev string `json:"rev"`
}
147
148func openRepo(ctx context.Context, bs blockstore.Blockstore, rootCid cid.Cid, did string) (*atp.Repo, error) {
149 commitBlock, err := bs.Get(ctx, rootCid)
150 if err != nil {
151 return nil, fmt.Errorf("reading commit block: %w", err)
152 }
153
154 var commit atp.Commit
155 if err := commit.UnmarshalCBOR(bytes.NewReader(commitBlock.RawData())); err != nil {
156 return nil, fmt.Errorf("parsing commit block: %w", err)
157 }
158
159 tree, err := mst.LoadTreeFromStore(ctx, bs, commit.Data)
160 if err != nil {
161 return nil, fmt.Errorf("loading MST: %w", err)
162 }
163
164 clk := syntax.ClockFromTID(syntax.TID(commit.Rev))
165 return &atp.Repo{
166 DID: syntax.DID(did),
167 Clock: &clk,
168 MST: *tree,
169 RecordStore: bs,
170 }, nil
171}
172
173func commitRepo(ctx context.Context, bs blockstore.Blockstore, r *atp.Repo, signingKey []byte) (cid.Cid, string, error) {
174 if _, err := r.MST.WriteDiffBlocks(ctx, bs); err != nil {
175 return cid.Undef, "", fmt.Errorf("writing MST blocks: %w", err)
176 }
177
178 commit, err := r.Commit()
179 if err != nil {
180 return cid.Undef, "", fmt.Errorf("creating commit: %w", err)
181 }
182
183 privkey, err := atcrypto.ParsePrivateBytesK256(signingKey)
184 if err != nil {
185 return cid.Undef, "", fmt.Errorf("parsing signing key: %w", err)
186 }
187 if err := commit.Sign(privkey); err != nil {
188 return cid.Undef, "", fmt.Errorf("signing commit: %w", err)
189 }
190
191 buf := new(bytes.Buffer)
192 if err := commit.MarshalCBOR(buf); err != nil {
193 return cid.Undef, "", fmt.Errorf("marshaling commit: %w", err)
194 }
195
196 pref := cid.NewPrefixV1(cid.DagCBOR, multihash.SHA2_256)
197 commitCid, err := pref.Sum(buf.Bytes())
198 if err != nil {
199 return cid.Undef, "", fmt.Errorf("computing commit CID: %w", err)
200 }
201
202 blk, err := blocks.NewBlockWithCid(buf.Bytes(), commitCid)
203 if err != nil {
204 return cid.Undef, "", fmt.Errorf("creating commit block: %w", err)
205 }
206 if err := bs.Put(ctx, blk); err != nil {
207 return cid.Undef, "", fmt.Errorf("writing commit block: %w", err)
208 }
209
210 return commitCid, commit.Rev, nil
211}
212
213func putRecordBlock(ctx context.Context, bs blockstore.Blockstore, rec *MarshalableMap) (cid.Cid, error) {
214 buf := new(bytes.Buffer)
215 if err := rec.MarshalCBOR(buf); err != nil {
216 return cid.Undef, err
217 }
218
219 pref := cid.NewPrefixV1(cid.DagCBOR, multihash.SHA2_256)
220 c, err := pref.Sum(buf.Bytes())
221 if err != nil {
222 return cid.Undef, err
223 }
224
225 blk, err := blocks.NewBlockWithCid(buf.Bytes(), c)
226 if err != nil {
227 return cid.Undef, err
228 }
229 if err := bs.Put(ctx, blk); err != nil {
230 return cid.Undef, err
231 }
232
233 return c, nil
234}
235
236// TODO make use of swap commit
237func (rm *RepoMan) applyWrites(ctx context.Context, urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) {
238 rootcid, err := cid.Cast(urepo.Root)
239 if err != nil {
240 return nil, err
241 }
242
243 dbs := rm.s.getBlockstore(urepo.Did)
244 bs := recording_blockstore.New(dbs)
245
246 var results []ApplyWriteResult
247 var ops []*atp.Operation
248 var entries []models.Record
249 var newroot cid.Cid
250 var rev string
251
252 if err := rm.withRepo(ctx, urepo.Did, rootcid, func(r *atp.Repo) (cid.Cid, error) {
253 entries = make([]models.Record, 0, len(writes))
254 for i, op := range writes {
255 // updates or deletes must supply an rkey
256 if op.Type != OpTypeCreate && op.Rkey == nil {
257 return cid.Undef, fmt.Errorf("invalid rkey")
258 } else if op.Type == OpTypeCreate && op.Rkey != nil {
259 // we should convert this op to an update if the rkey already exists
260 path := fmt.Sprintf("%s/%s", op.Collection, *op.Rkey)
261 existing, _ := r.MST.Get([]byte(path))
262 if existing != nil {
263 op.Type = OpTypeUpdate
264 }
265 } else if op.Rkey == nil {
266 // creates that don't supply an rkey will have one generated for them
267 op.Rkey = to.StringPtr(rm.clock.Next().String())
268 writes[i].Rkey = op.Rkey
269 }
270
271 path := fmt.Sprintf("%s/%s", op.Collection, *op.Rkey)
272
273 // validate the record key is actually valid
274 _, err := syntax.ParseRecordKey(*op.Rkey)
275 if err != nil {
276 return cid.Undef, err
277 }
278
279 switch op.Type {
280 case OpTypeCreate:
281 // HACK: this fixes some type conversions, mainly around integers
282 b, err := json.Marshal(*op.Record)
283 if err != nil {
284 return cid.Undef, err
285 }
286 out, err := atdata.UnmarshalJSON(b)
287 if err != nil {
288 return cid.Undef, err
289 }
290 mm := MarshalableMap(out)
291
292 // HACK: if a record doesn't contain a $type, we can manually set it here based on the op's collection
293 if mm["$type"] == "" {
294 mm["$type"] = op.Collection
295 }
296
297 nc, err := putRecordBlock(ctx, bs, &mm)
298 if err != nil {
299 return cid.Undef, err
300 }
301
302 atpOp, err := atp.ApplyOp(&r.MST, path, &nc)
303 if err != nil {
304 return cid.Undef, err
305 }
306 ops = append(ops, atpOp)
307
308 d, err := atdata.MarshalCBOR(mm)
309 if err != nil {
310 return cid.Undef, err
311 }
312
313 entries = append(entries, models.Record{
314 Did: urepo.Did,
315 CreatedAt: rm.clock.Next().String(),
316 Nsid: op.Collection,
317 Rkey: *op.Rkey,
318 Cid: nc.String(),
319 Value: d,
320 })
321
322 results = append(results, ApplyWriteResult{
323 Type: to.StringPtr(OpTypeCreate.String()),
324 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
325 Cid: to.StringPtr(nc.String()),
326 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
327 })
328 case OpTypeDelete:
329 // try to find the old record in the database
330 var old models.Record
331 if err := rm.db.Raw(ctx, "SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
332 return cid.Undef, err
333 }
334
335 // TODO: this is really confusing, and looking at it i have no idea why i did this. below when we are doing deletes, we
336 // check if `cid` here is nil to indicate if we should delete. that really doesn't make much sense and its super illogical
337 // when reading this code. i dont feel like fixing right now though so
338 entries = append(entries, models.Record{
339 Did: urepo.Did,
340 Nsid: op.Collection,
341 Rkey: *op.Rkey,
342 Value: old.Value,
343 })
344
345 atpOp, err := atp.ApplyOp(&r.MST, path, nil)
346 if err != nil {
347 return cid.Undef, err
348 }
349 ops = append(ops, atpOp)
350
351 results = append(results, ApplyWriteResult{
352 Type: to.StringPtr(OpTypeDelete.String()),
353 })
354 case OpTypeUpdate:
355 // HACK: same hack as above for type fixes
356 b, err := json.Marshal(*op.Record)
357 if err != nil {
358 return cid.Undef, err
359 }
360 out, err := atdata.UnmarshalJSON(b)
361 if err != nil {
362 return cid.Undef, err
363 }
364 mm := MarshalableMap(out)
365
366 nc, err := putRecordBlock(ctx, bs, &mm)
367 if err != nil {
368 return cid.Undef, err
369 }
370
371 atpOp, err := atp.ApplyOp(&r.MST, path, &nc)
372 if err != nil {
373 return cid.Undef, err
374 }
375 ops = append(ops, atpOp)
376
377 d, err := atdata.MarshalCBOR(mm)
378 if err != nil {
379 return cid.Undef, err
380 }
381
382 entries = append(entries, models.Record{
383 Did: urepo.Did,
384 CreatedAt: rm.clock.Next().String(),
385 Nsid: op.Collection,
386 Rkey: *op.Rkey,
387 Cid: nc.String(),
388 Value: d,
389 })
390
391 results = append(results, ApplyWriteResult{
392 Type: to.StringPtr(OpTypeUpdate.String()),
393 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey),
394 Cid: to.StringPtr(nc.String()),
395 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol
396 })
397 }
398 }
399
400 // commit and get the new root
401 var commitErr error
402 newroot, rev, commitErr = commitRepo(ctx, bs, r, urepo.SigningKey)
403 if commitErr != nil {
404 return cid.Undef, commitErr
405 }
406
407 return newroot, nil
408 }); err != nil {
409 return nil, err
410 }
411
412 for _, result := range results {
413 if result.Type != nil {
414 metrics.RepoOperations.WithLabelValues(*result.Type).Inc()
415 }
416 }
417
418 // create a buffer for dumping our new cbor into
419 buf := new(bytes.Buffer)
420
421 // first write the car header to the buffer
422 hb, err := cbor.DumpObject(&car.CarHeader{
423 Roots: []cid.Cid{newroot},
424 Version: 1,
425 })
426 if _, err := carstore.LdWrite(buf, hb); err != nil {
427 return nil, err
428 }
429
430 // create the repo ops for the firehose from the tracked operations
431 repoOps := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(ops))
432 for _, op := range ops {
433 if op.IsCreate() || op.IsUpdate() {
434 kind := "create"
435 if op.IsUpdate() {
436 kind = "update"
437 }
438
439 ll := lexutil.LexLink(*op.Value)
440 repoOps = append(repoOps, &atproto.SyncSubscribeRepos_RepoOp{
441 Action: kind,
442 Path: op.Path,
443 Cid: &ll,
444 })
445
446 blk, err := dbs.Get(ctx, *op.Value)
447 if err != nil {
448 return nil, err
449 }
450 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil {
451 return nil, err
452 }
453 } else if op.IsDelete() {
454 ll := lexutil.LexLink(*op.Prev)
455 repoOps = append(repoOps, &atproto.SyncSubscribeRepos_RepoOp{
456 Action: "delete",
457 Path: op.Path,
458 Cid: nil,
459 Prev: &ll,
460 })
461
462 blk, err := dbs.Get(ctx, *op.Prev)
463 if err != nil {
464 return nil, err
465 }
466 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil {
467 return nil, err
468 }
469 }
470 }
471
472 // write the writelog to the buffer
473 for _, blk := range bs.GetWriteLog() {
474 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil {
475 return nil, err
476 }
477 }
478
479 // blob blob blob blob blob :3
480 var blobs []lexutil.LexLink
481 for _, entry := range entries {
482 var cids []cid.Cid
483 // whenever there is cid present, we know it's a create (dumb)
484 if entry.Cid != "" {
485 if err := rm.s.db.Create(ctx, &entry, []clause.Expression{clause.OnConflict{
486 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}},
487 UpdateAll: true,
488 }}).Error; err != nil {
489 return nil, err
490 }
491
492 // increment the given blob refs, yay
493 cids, err = rm.incrementBlobRefs(ctx, urepo, entry.Value)
494 if err != nil {
495 return nil, err
496 }
497 } else {
498 // as i noted above this is dumb. but we delete whenever the cid is nil. it works solely becaue the pkey
499 // is did + collection + rkey. i still really want to separate that out, or use a different type to make
500 // this less confusing/easy to read. alas, its 2 am and yea no
501 if err := rm.s.db.Delete(ctx, &entry, nil).Error; err != nil {
502 return nil, err
503 }
504
505 // TODO:
506 cids, err = rm.decrementBlobRefs(ctx, urepo, entry.Value)
507 if err != nil {
508 return nil, err
509 }
510 }
511
512 // add all the relevant blobs to the blobs list of blobs. blob ^.^
513 for _, c := range cids {
514 blobs = append(blobs, lexutil.LexLink(c))
515 }
516 }
517
518 // NOTE: using the request ctx seems a bit suss here, so using a background context. i'm not sure if this
519 // runs sync or not
520 rm.s.evtman.AddEvent(context.Background(), &events.XRPCStreamEvent{
521 RepoCommit: &atproto.SyncSubscribeRepos_Commit{
522 Repo: urepo.Did,
523 Blocks: buf.Bytes(),
524 Blobs: blobs,
525 Rev: rev,
526 Since: &urepo.Rev,
527 Commit: lexutil.LexLink(newroot),
528 Time: time.Now().Format(time.RFC3339Nano),
529 Ops: repoOps,
530 TooBig: false,
531 },
532 })
533
534 if err := rm.s.UpdateRepo(ctx, urepo.Did, newroot, rev); err != nil {
535 return nil, err
536 }
537
538 for i := range results {
539 results[i].Type = to.StringPtr(*results[i].Type + "Result")
540 results[i].Commit = &RepoCommit{
541 Cid: newroot.String(),
542 Rev: rev,
543 }
544 }
545
546 return results, nil
547}
548
// getRecordProof returns the repo's current commit CID together with the
// blocks required to prove the record at collection/rkey: the commit
// block, every MST node on the path to the key, and the record block
// itself. The repo is accessed read-only; the withRepo callback returns
// the same commit CID so the cache root does not move.
func (rm *RepoMan) getRecordProof(ctx context.Context, urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) {
	commitCid, err := cid.Cast(urepo.Root)
	if err != nil {
		return cid.Undef, nil, err
	}

	dbs := rm.s.getBlockstore(urepo.Did)

	var proofBlocks []blocks.Block
	var recordCid *cid.Cid

	if err := rm.withRepo(ctx, urepo.Did, commitCid, func(r *atp.Repo) (cid.Cid, error) {
		path := collection + "/" + rkey

		// walk the cached in-memory tree to find the record and collect MST node CIDs on the path
		nodeCIDs := collectPathNodeCIDs(r.MST.Root, []byte(path))

		rc, getErr := r.MST.Get([]byte(path))
		if getErr != nil {
			return cid.Undef, getErr
		}
		if rc == nil {
			return cid.Undef, fmt.Errorf("record not found: %s", path)
		}
		recordCid = rc

		// read the commit block
		commitBlk, err := dbs.Get(ctx, commitCid)
		if err != nil {
			return cid.Undef, fmt.Errorf("reading commit block for proof: %w", err)
		}
		proofBlocks = append(proofBlocks, commitBlk)

		// read the MST nodes on the path
		for _, nc := range nodeCIDs {
			blk, err := dbs.Get(ctx, nc)
			if err != nil {
				return cid.Undef, fmt.Errorf("reading MST node for proof: %w", err)
			}
			proofBlocks = append(proofBlocks, blk)
		}

		// read the record block
		recordBlk, err := dbs.Get(ctx, *recordCid)
		if err != nil {
			return cid.Undef, fmt.Errorf("reading record block for proof: %w", err)
		}
		proofBlocks = append(proofBlocks, recordBlk)

		// read-only, return same root
		return commitCid, nil
	}); err != nil {
		return cid.Undef, nil, err
	}

	return commitCid, proofBlocks, nil
}
606
// collectPathNodeCIDs walks the in-memory MST from n toward key and
// returns the CIDs of the nodes visited, outermost first. Nodes without a
// computed CID are skipped (their subtrees are still descended into via
// the selected child).
func collectPathNodeCIDs(n *mst.Node, key []byte) []cid.Cid {
	if n == nil {
		return nil
	}

	var cids []cid.Cid
	if n.CID != nil {
		cids = append(cids, *n.CID)
	}

	height := mst.HeightForKey(key)
	if height >= n.Height {
		// key is at or above this level, no need to descend
		return cids
	}

	// find the child node that covers this key: remember the most recent
	// child entry, stopping at the first value entry whose key sorts at or
	// after ours
	childIdx := -1
	for i, e := range n.Entries {
		if e.IsChild() {
			childIdx = i
			continue
		}
		if e.IsValue() {
			if bytes.Compare(key, e.Key) <= 0 {
				break
			}
			// this value sorts before our key, so any earlier child cannot
			// cover it; forget it and keep scanning
			childIdx = -1
		}
	}

	if childIdx >= 0 && n.Entries[childIdx].Child != nil {
		cids = append(cids, collectPathNodeCIDs(n.Entries[childIdx].Child, key)...)
	}

	return cids
}
644
645func (rm *RepoMan) incrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
646 cids, err := getBlobCidsFromCbor(cbor)
647 if err != nil {
648 return nil, err
649 }
650
651 for _, c := range cids {
652 if err := rm.db.Exec(ctx, "UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil {
653 return nil, err
654 }
655 }
656
657 return cids, nil
658}
659
660func (rm *RepoMan) decrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
661 cids, err := getBlobCidsFromCbor(cbor)
662 if err != nil {
663 return nil, err
664 }
665
666 for _, c := range cids {
667 var res struct {
668 ID uint
669 Count int
670 }
671 if err := rm.db.Raw(ctx, "UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
672 return nil, err
673 }
674
675 // TODO: this does _not_ handle deletions of blobs that are on s3 storage!!!! we need to get the blob, see what
676 // storage it is in, and clean up s3!!!!
677 if res.Count == 0 {
678 if err := rm.db.Exec(ctx, "DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil {
679 return nil, err
680 }
681 if err := rm.db.Exec(ctx, "DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil {
682 return nil, err
683 }
684 }
685 }
686
687 return cids, nil
688}
689
690// to be honest, we could just store both the cbor and non-cbor in []entries above to avoid an additional
691// unmarshal here. this will work for now though
692func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) {
693 var cids []cid.Cid
694
695 decoded, err := atdata.UnmarshalCBOR(cbor)
696 if err != nil {
697 return nil, fmt.Errorf("error unmarshaling cbor: %w", err)
698 }
699
700 var deepiter func(any) error
701 deepiter = func(item any) error {
702 switch val := item.(type) {
703 case map[string]any:
704 if val["$type"] == "blob" {
705 if ref, ok := val["ref"].(string); ok {
706 c, err := cid.Parse(ref)
707 if err != nil {
708 return err
709 }
710 cids = append(cids, c)
711 }
712 for _, v := range val {
713 return deepiter(v)
714 }
715 }
716 case []any:
717 for _, v := range val {
718 deepiter(v)
719 }
720 }
721
722 return nil
723 }
724
725 if err := deepiter(decoded); err != nil {
726 return nil, err
727 }
728
729 return cids, nil
730}