package dbpersist

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"sync"
	"time"

	comatproto "github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/carstore"
	"github.com/bluesky-social/indigo/events"
	lexutil "github.com/bluesky-social/indigo/lex/util"
	"github.com/bluesky-social/indigo/models"
	"github.com/bluesky-social/indigo/util"
	arc "github.com/hashicorp/golang-lru/arc/v2"

	cid "github.com/ipfs/go-cid"
	"gorm.io/gorm"
)

var log = slog.Default().With("system", "dbpersist")

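// PersistenceBatchItem pairs a pending RepoEventRecord with the stream event
// it was derived from, so the event can be stamped with its assigned sequence
// number and broadcast once the batch is written.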
type PersistenceBatchItem struct {
	Record *RepoEventRecord
	Event  *events.XRPCStreamEvent
}

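// Options configures batching, flush timing, cache sizes, and playback
// hydration behavior for DbPersistence.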
type Options struct {
	MaxBatchSize         int
	MinBatchSize         int
	MaxTimeBetweenFlush  time.Duration
	CheckBatchInterval   time.Duration
	UIDCacheSize         int
	DIDCacheSize         int
	PlaybackBatchSize    int
	HydrationConcurrency int
}

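// DefaultOptions returns the default DbPersistence configuration.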
func DefaultOptions() *Options {
	return &Options{
		MaxBatchSize:         200,
		MinBatchSize:         10,
		MaxTimeBetweenFlush:  500 * time.Millisecond,
		CheckBatchInterval:   100 * time.Millisecond,
		UIDCacheSize:         10000,
		DIDCacheSize:         10000,
		PlaybackBatchSize:    500,
		HydrationConcurrency: 10,
	}
}

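// DbPersistence is an event persister backed by a gorm database. Incoming
// events are queued and written in batches, then broadcast with their
// database-assigned sequence numbers; persisted events can be replayed later
// via Playback, with commit/sync block data rehydrated from the carstore.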
type DbPersistence struct {
	db *gorm.DB

	cs carstore.CarStore

	lk sync.Mutex

	broadcast func(*events.XRPCStreamEvent)

	batch        []*PersistenceBatchItem
	batchOptions Options
	lastFlush    time.Time

	uidCache *arc.ARCCache[models.Uid, string]
	didCache *arc.ARCCache[string, models.Uid]
}

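// RepoEventRecord is the database row representing a single persisted
// firehose event. Blobs and Ops are stored as JSON-encoded byte slices.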
type RepoEventRecord struct {
	Seq       uint `gorm:"primarykey"`
	Rev       string
	Since     *string
	Commit    *models.DbCID
	Prev      *models.DbCID
	NewHandle *string // NewHandle is only set if this is a handle change event

	Time   time.Time
	Blobs  []byte
	Repo   models.Uid
	Type   string
	Rebase bool

	// Active and Status are only set on RepoAccount events
	Active bool
	Status *string

	Ops []byte
}

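// NewDbPersistence migrates the RepoEventRecord table, builds the DID/UID
// caches, and starts the background batch flusher. Passing nil options uses
// DefaultOptions.
//
// Illustrative wiring (a sketch; db and cs are assumed to already exist):
//
//	p, err := NewDbPersistence(db, cs, DefaultOptions())
//	if err != nil {
//		// handle error
//	}
//	p.SetEventBroadcaster(func(evt *events.XRPCStreamEvent) {
//		// forward evt to live subscribers
//	})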
func NewDbPersistence(db *gorm.DB, cs carstore.CarStore, options *Options) (*DbPersistence, error) {
	if err := db.AutoMigrate(&RepoEventRecord{}); err != nil {
		return nil, err
	}

	if options == nil {
		options = DefaultOptions()
	}

	uidCache, err := arc.NewARC[models.Uid, string](options.UIDCacheSize)
	if err != nil {
		return nil, fmt.Errorf("failed to create uid cache: %w", err)
	}

	didCache, err := arc.NewARC[string, models.Uid](options.DIDCacheSize)
	if err != nil {
		return nil, fmt.Errorf("failed to create did cache: %w", err)
	}

	p := DbPersistence{
		db:           db,
		cs:           cs,
		batchOptions: *options,
		batch:        []*PersistenceBatchItem{},
		uidCache:     uidCache,
		didCache:     didCache,
	}

	go p.batchFlusher()

	return &p, nil
}

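// batchFlusher runs in the background, checking every CheckBatchInterval and
// flushing the pending batch once it reaches MinBatchSize or once
// MaxTimeBetweenFlush has elapsed since the last flush.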
func (p *DbPersistence) batchFlusher() {
	for {
		time.Sleep(p.batchOptions.CheckBatchInterval)

		p.lk.Lock()
		needsFlush := len(p.batch) > 0 &&
			(len(p.batch) >= p.batchOptions.MinBatchSize ||
				time.Since(p.lastFlush) >= p.batchOptions.MaxTimeBetweenFlush)
		p.lk.Unlock()

		if needsFlush {
			if err := p.Flush(context.Background()); err != nil {
				log.Error("failed to flush batch", "err", err)
			}
		}
	}
}

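// SetEventBroadcaster sets the callback invoked for each event after it has
// been written and assigned a sequence number. It should be set before any
// events are persisted, since the flush path calls it unconditionally.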
func (p *DbPersistence) SetEventBroadcaster(brc func(*events.XRPCStreamEvent)) {
	p.broadcast = brc
}

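// Flush immediately writes and broadcasts any pending batch items.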
func (p *DbPersistence) Flush(ctx context.Context) error {
	p.lk.Lock()
	defer p.lk.Unlock()
	return p.flushBatchLocked(ctx)
}

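// flushBatchLocked writes the current batch to the database, stamps each
// event with its assigned sequence number, broadcasts it, and resets the
// batch. The caller must hold p.lk.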
func (p *DbPersistence) flushBatchLocked(ctx context.Context) error {
	// TODO: we technically don't need to hold the lock through the database
	// operation, all we need to do is swap the batch out, and ensure nobody
	// else tries to enter this function to flush another batch while we are
	// flushing. I'll leave that for a later optimization

	if len(p.batch) == 0 {
		// Nothing pending; avoid a no-op insert and a spurious lastFlush update.
		return nil
	}

	records := make([]*RepoEventRecord, len(p.batch))
	for i, item := range p.batch {
		records[i] = item.Record
	}

	if err := p.db.CreateInBatches(records, 50).Error; err != nil {
		return fmt.Errorf("failed to create records: %w", err)
	}

	for i, record := range records {
		e := p.batch[i].Event
		switch {
		case e.RepoCommit != nil:
			e.RepoCommit.Seq = int64(record.Seq)
		case e.RepoSync != nil:
			e.RepoSync.Seq = int64(record.Seq)
		case e.RepoIdentity != nil:
			e.RepoIdentity.Seq = int64(record.Seq)
		case e.RepoAccount != nil:
			e.RepoAccount.Seq = int64(record.Seq)
		default:
			return fmt.Errorf("unknown event type")
		}
		p.broadcast(e)
	}

	p.batch = []*PersistenceBatchItem{}
	p.lastFlush = time.Now()

	return nil
}

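// AddItemToBatch queues a record/event pair for persistence, flushing
// immediately if the batch has reached MaxBatchSize.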
func (p *DbPersistence) AddItemToBatch(ctx context.Context, rec *RepoEventRecord, evt *events.XRPCStreamEvent) error {
	p.lk.Lock()
	defer p.lk.Unlock()

	p.batch = append(p.batch, &PersistenceBatchItem{
		Record: rec,
		Event:  evt,
	})

	if len(p.batch) >= p.batchOptions.MaxBatchSize {
		if err := p.flushBatchLocked(ctx); err != nil {
			return fmt.Errorf("failed to flush batch at max size: %w", err)
		}
	}

	return nil
}

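// Persist converts a stream event into a RepoEventRecord and queues it for
// batched persistence. Event types this persister does not handle are
// silently ignored.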
func (p *DbPersistence) Persist(ctx context.Context, e *events.XRPCStreamEvent) error {
	var rer *RepoEventRecord
	var err error

	switch {
	case e.RepoCommit != nil:
		rer, err = p.RecordFromRepoCommit(ctx, e.RepoCommit)
	case e.RepoSync != nil:
		rer, err = p.RecordFromRepoSync(ctx, e.RepoSync)
	case e.RepoIdentity != nil:
		rer, err = p.RecordFromRepoIdentity(ctx, e.RepoIdentity)
	case e.RepoAccount != nil:
		rer, err = p.RecordFromRepoAccount(ctx, e.RepoAccount)
	default:
		return nil
	}
	if err != nil {
		return err
	}

	return p.AddItemToBatch(ctx, rer, e)
}

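// RecordFromRepoIdentity converts an #identity event into a RepoEventRecord.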
func (p *DbPersistence) RecordFromRepoIdentity(ctx context.Context, evt *comatproto.SyncSubscribeRepos_Identity) (*RepoEventRecord, error) {
	t, err := time.Parse(util.ISO8601, evt.Time)
	if err != nil {
		return nil, err
	}

	uid, err := p.uidForDid(ctx, evt.Did)
	if err != nil {
		return nil, err
	}

	return &RepoEventRecord{
		Repo: uid,
		Type: "repo_identity",
		Time: t,
	}, nil
}

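// RecordFromRepoAccount converts an #account event into a RepoEventRecord.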
func (p *DbPersistence) RecordFromRepoAccount(ctx context.Context, evt *comatproto.SyncSubscribeRepos_Account) (*RepoEventRecord, error) {
	t, err := time.Parse(util.ISO8601, evt.Time)
	if err != nil {
		return nil, err
	}

	uid, err := p.uidForDid(ctx, evt.Did)
	if err != nil {
		return nil, err
	}

	return &RepoEventRecord{
		Repo:   uid,
		Type:   "repo_account",
		Time:   t,
		Active: evt.Active,
		Status: evt.Status,
	}, nil
}

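// RecordFromRepoCommit converts a #commit event into a RepoEventRecord,
// JSON-encoding the blob CIDs and ops for storage.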
func (p *DbPersistence) RecordFromRepoCommit(ctx context.Context, evt *comatproto.SyncSubscribeRepos_Commit) (*RepoEventRecord, error) {
	// TODO: hack hack hack
	if len(evt.Ops) > 8192 {
		log.Error("(VERY BAD) truncating ops field in outgoing event", "len", len(evt.Ops))
		evt.Ops = evt.Ops[:8192]
	}

	uid, err := p.uidForDid(ctx, evt.Repo)
	if err != nil {
		return nil, err
	}

	var blobs []byte
	if len(evt.Blobs) > 0 {
		b, err := json.Marshal(evt.Blobs)
		if err != nil {
			return nil, err
		}
		blobs = b
	}

	t, err := time.Parse(util.ISO8601, evt.Time)
	if err != nil {
		return nil, err
	}

	rer := RepoEventRecord{
		Commit: &models.DbCID{CID: cid.Cid(evt.Commit)},
		//Prev
		Repo:   uid,
		Type:   "repo_append", // TODO: refactor to "#commit"? can "rebase" come through this path?
		Blobs:  blobs,
		Time:   t,
		Rebase: evt.Rebase,
		Rev:    evt.Rev,
		Since:  evt.Since,
	}

	opsb, err := json.Marshal(evt.Ops)
	if err != nil {
		return nil, err
	}
	rer.Ops = opsb

	return &rer, nil
}

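// RecordFromRepoSync converts a #sync event into a RepoEventRecord.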
func (p *DbPersistence) RecordFromRepoSync(ctx context.Context, evt *comatproto.SyncSubscribeRepos_Sync) (*RepoEventRecord, error) {
	uid, err := p.uidForDid(ctx, evt.Did)
	if err != nil {
		return nil, err
	}

	t, err := time.Parse(util.ISO8601, evt.Time)
	if err != nil {
		return nil, err
	}

	rer := RepoEventRecord{
		Repo: uid,
		Type: "repo_sync",
		Time: t,
		Rev:  evt.Rev,
	}

	return &rer, nil
}

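// Playback replays all persisted events with a sequence number greater than
// since, in order, invoking cb for each hydrated event. Records are read in
// pages and hydrated in concurrent batches.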
func (p *DbPersistence) Playback(ctx context.Context, since int64, cb func(*events.XRPCStreamEvent) error) error {
	pageSize := 1000

	for {
		rows, err := p.db.Model(&RepoEventRecord{}).Where("seq > ?", since).Order("seq asc").Limit(pageSize).Rows()
		if err != nil {
			return err
		}

		hasRows := false

		batch := make([]*RepoEventRecord, 0, p.batchOptions.PlaybackBatchSize)
		for rows.Next() {
			hasRows = true

			var evt RepoEventRecord
			if err := p.db.ScanRows(rows, &evt); err != nil {
				rows.Close()
				return err
			}

			// Advance the since cursor
			since = int64(evt.Seq)

			batch = append(batch, &evt)

			if len(batch) >= p.batchOptions.PlaybackBatchSize {
				if err := p.hydrateBatch(ctx, batch, cb); err != nil {
					rows.Close()
					return err
				}

				batch = batch[:0]
			}
		}

		// Close this page's rows before issuing the next query; a deferred
		// Close here would accumulate for every page until Playback returns.
		rows.Close()

		if len(batch) > 0 {
			if err := p.hydrateBatch(ctx, batch, cb); err != nil {
				return err
			}
		}

		if !hasRows {
			break
		}
	}

	return nil
}

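// hydrateBatch hydrates a batch of records into stream events concurrently,
// bounded by HydrationConcurrency, and delivers them to cb strictly in the
// original batch order.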
func (p *DbPersistence) hydrateBatch(ctx context.Context, batch []*RepoEventRecord, cb func(*events.XRPCStreamEvent) error) error {
	evts := make([]*events.XRPCStreamEvent, len(batch))

	type Result struct {
		Event *events.XRPCStreamEvent
		Index int
		Err   error
	}

	resultChan := make(chan Result, len(batch))

	// Semaphore pattern for limiting concurrent goroutines
	sem := make(chan struct{}, p.batchOptions.HydrationConcurrency)
	var wg sync.WaitGroup

	for i, record := range batch {
		wg.Add(1)
		go func(i int, record *RepoEventRecord) {
			defer wg.Done()
			sem <- struct{}{}
			// release the semaphore at the end of the goroutine
			defer func() { <-sem }()

			var streamEvent *events.XRPCStreamEvent
			var err error

			switch {
			case record.Commit != nil:
				streamEvent, err = p.hydrateCommit(ctx, record)
			case record.Type == "repo_sync":
				streamEvent, err = p.hydrateSyncEvent(ctx, record)
			case record.Type == "repo_identity":
				streamEvent, err = p.hydrateIdentityEvent(ctx, record)
			case record.Type == "repo_account":
				streamEvent, err = p.hydrateAccountEvent(ctx, record)
			default:
				err = fmt.Errorf("unknown event type: %s", record.Type)
			}

			resultChan <- Result{Event: streamEvent, Index: i, Err: err}
		}(i, record)
	}

	go func() {
		wg.Wait()
		close(resultChan)
	}()

	// Results arrive out of order; buffer them and advance the cursor so the
	// callback always sees events in batch order.
	cur := 0
	for result := range resultChan {
		if result.Err != nil {
			return result.Err
		}

		evts[result.Index] = result.Event

		for ; cur < len(evts) && evts[cur] != nil; cur++ {
			if err := cb(evts[cur]); err != nil {
				return err
			}
		}
	}

	return nil
}

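// uidForDid resolves a DID to its internal Uid, consulting the cache before
// falling back to a database lookup.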
func (p *DbPersistence) uidForDid(ctx context.Context, did string) (models.Uid, error) {
	if uid, ok := p.didCache.Get(did); ok {
		return uid, nil
	}

	var u models.ActorInfo
	if err := p.db.First(&u, "did = ?", did).Error; err != nil {
		return 0, err
	}

	p.didCache.Add(did, u.Uid)

	return u.Uid, nil
}

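// didForUid resolves an internal Uid back to its DID, consulting the cache
// before falling back to a database lookup.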
func (p *DbPersistence) didForUid(ctx context.Context, uid models.Uid) (string, error) {
	if did, ok := p.uidCache.Get(uid); ok {
		return did, nil
	}

	var u models.ActorInfo
	if err := p.db.First(&u, "uid = ?", uid).Error; err != nil {
		return "", err
	}

	p.uidCache.Add(uid, u.Did)

	return u.Did, nil
}

func (p *DbPersistence) hydrateIdentityEvent(ctx context.Context, rer *RepoEventRecord) (*events.XRPCStreamEvent, error) {
	did, err := p.didForUid(ctx, rer.Repo)
	if err != nil {
		return nil, err
	}

	return &events.XRPCStreamEvent{
		RepoIdentity: &comatproto.SyncSubscribeRepos_Identity{
			Seq:  int64(rer.Seq),
			Did:  did,
			Time: rer.Time.Format(util.ISO8601),
		},
	}, nil
}

func (p *DbPersistence) hydrateAccountEvent(ctx context.Context, rer *RepoEventRecord) (*events.XRPCStreamEvent, error) {
	did, err := p.didForUid(ctx, rer.Repo)
	if err != nil {
		return nil, err
	}

	return &events.XRPCStreamEvent{
		RepoAccount: &comatproto.SyncSubscribeRepos_Account{
			Seq:    int64(rer.Seq),
			Did:    did,
			Time:   rer.Time.Format(util.ISO8601),
			Active: rer.Active,
			Status: rer.Status,
		},
	}, nil
}

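// hydrateCommit rebuilds a #commit stream event from its stored record,
// decoding blob CIDs and ops and reading the CAR slice from the carstore.
// Oversized slices are flagged TooBig and sent without blocks.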
func (p *DbPersistence) hydrateCommit(ctx context.Context, rer *RepoEventRecord) (*events.XRPCStreamEvent, error) {
	if rer.Commit == nil {
		return nil, fmt.Errorf("commit is nil")
	}

	var blobs []string
	if len(rer.Blobs) > 0 {
		if err := json.Unmarshal(rer.Blobs, &blobs); err != nil {
			return nil, err
		}
	}

	var blobCIDs []lexutil.LexLink
	for _, b := range blobs {
		c, err := cid.Decode(b)
		if err != nil {
			return nil, err
		}
		blobCIDs = append(blobCIDs, lexutil.LexLink(c))
	}

	did, err := p.didForUid(ctx, rer.Repo)
	if err != nil {
		return nil, err
	}

	var ops []*comatproto.SyncSubscribeRepos_RepoOp
	if err := json.Unmarshal(rer.Ops, &ops); err != nil {
		return nil, err
	}

	out := &comatproto.SyncSubscribeRepos_Commit{
		Seq:    int64(rer.Seq),
		Repo:   did,
		Commit: lexutil.LexLink(rer.Commit.CID),
		Time:   rer.Time.Format(util.ISO8601),
		Blobs:  blobCIDs,
		Rebase: rer.Rebase,
		Ops:    ops,
		Rev:    rer.Rev,
		Since:  rer.Since,
	}

	cs, err := p.readCarSlice(ctx, rer)
	if err != nil {
		return nil, fmt.Errorf("read car slice (%s): %w", rer.Commit.CID, err)
	}

	if len(cs) > carstore.MaxSliceLength {
		out.TooBig = true
		out.Blocks = []byte{}
	} else {
		out.Blocks = cs
	}

	return &events.XRPCStreamEvent{RepoCommit: out}, nil
}

func (p *DbPersistence) hydrateSyncEvent(ctx context.Context, rer *RepoEventRecord) (*events.XRPCStreamEvent, error) {
	did, err := p.didForUid(ctx, rer.Repo)
	if err != nil {
		return nil, err
	}

	evt := &comatproto.SyncSubscribeRepos_Sync{
		Seq:  int64(rer.Seq),
		Did:  did,
		Time: rer.Time.Format(util.ISO8601),
		Rev:  rer.Rev,
	}

	cs, err := p.readCarSlice(ctx, rer)
	if err != nil {
		return nil, fmt.Errorf("read car slice: %w", err)
	}
	evt.Blocks = cs

	return &events.XRPCStreamEvent{RepoSync: evt}, nil
}

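// readCarSlice reads the CAR slice for this record's repo and rev from the
// carstore.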
func (p *DbPersistence) readCarSlice(ctx context.Context, rer *RepoEventRecord) ([]byte, error) {
	buf := new(bytes.Buffer)
	if err := p.cs.ReadUserCar(ctx, rer.Repo, rer.Rev, true, buf); err != nil {
		return nil, err
	}

	return buf.Bytes(), nil
}

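// TakeDownRepo removes all persisted events for the given user.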
func (p *DbPersistence) TakeDownRepo(ctx context.Context, usr models.Uid) error {
	return p.deleteAllEventsForUser(ctx, usr)
}

func (p *DbPersistence) deleteAllEventsForUser(ctx context.Context, usr models.Uid) error {
	if err := p.db.Where("repo = ?", usr).Delete(&RepoEventRecord{}).Error; err != nil {
		return err
	}

	return nil
}

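// Shutdown implements the persister shutdown hook; it is currently a no-op.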
func (p *DbPersistence) Shutdown(context.Context) error {
	return nil
}