fork of indigo with slightly nicer lexgen
package dbpersist

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"sync"
	"time"

	comatproto "github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/carstore"
	"github.com/bluesky-social/indigo/events"
	lexutil "github.com/bluesky-social/indigo/lex/util"
	"github.com/bluesky-social/indigo/models"
	"github.com/bluesky-social/indigo/util"
	arc "github.com/hashicorp/golang-lru/arc/v2"

	cid "github.com/ipfs/go-cid"
	"gorm.io/gorm"
)

var log = slog.Default().With("system", "dbpersist")

type PersistenceBatchItem struct {
	Record *RepoEventRecord
	Event  *events.XRPCStreamEvent
}

type Options struct {
	MaxBatchSize         int
	MinBatchSize         int
	MaxTimeBetweenFlush  time.Duration
	CheckBatchInterval   time.Duration
	UIDCacheSize         int
	DIDCacheSize         int
	PlaybackBatchSize    int
	HydrationConcurrency int
}

func DefaultOptions() *Options {
	return &Options{
		MaxBatchSize:         200,
		MinBatchSize:         10,
		MaxTimeBetweenFlush:  500 * time.Millisecond,
		CheckBatchInterval:   100 * time.Millisecond,
		UIDCacheSize:         10000,
		DIDCacheSize:         10000,
		PlaybackBatchSize:    500,
		HydrationConcurrency: 10,
	}
}

type DbPersistence struct {
	db *gorm.DB

	cs carstore.CarStore

	lk sync.Mutex

	broadcast func(*events.XRPCStreamEvent)

	batch        []*PersistenceBatchItem
	batchOptions Options
	lastFlush    time.Time

	uidCache *arc.ARCCache[models.Uid, string]
	didCache *arc.ARCCache[string, models.Uid]
}

type RepoEventRecord struct {
	Seq       uint `gorm:"primarykey"`
	Rev       string
	Since     *string
	Commit    *models.DbCID
	Prev      *models.DbCID
	NewHandle *string // NewHandle is only set if this is a handle change event

	Time   time.Time
	Blobs  []byte
	Repo   models.Uid
	Type   string
	Rebase bool

	// Active and Status are only set on RepoAccount events
	Active bool
	Status *string

	Ops []byte
}

func NewDbPersistence(db *gorm.DB, cs carstore.CarStore, options *Options) (*DbPersistence, error) {
	if err := db.AutoMigrate(&RepoEventRecord{}); err != nil {
		return nil, err
	}

	if options == nil {
		options = DefaultOptions()
	}

	uidCache, err := arc.NewARC[models.Uid, string](options.UIDCacheSize)
	if err != nil {
		return nil, fmt.Errorf("failed to create uid cache: %w", err)
	}

	didCache, err := arc.NewARC[string, models.Uid](options.DIDCacheSize)
	if err != nil {
		return nil, fmt.Errorf("failed to create did cache: %w", err)
	}

	p := DbPersistence{
		db:           db,
		cs:           cs,
		batchOptions: *options,
		batch:        []*PersistenceBatchItem{},
		uidCache:     uidCache,
		didCache:     didCache,
	}

	go p.batchFlusher()

	return &p, nil
}

func (p *DbPersistence) batchFlusher() {
	for {
		time.Sleep(p.batchOptions.CheckBatchInterval)

		p.lk.Lock()
		needsFlush := len(p.batch) > 0 &&
			(len(p.batch) >= p.batchOptions.MinBatchSize ||
				time.Since(p.lastFlush) >= p.batchOptions.MaxTimeBetweenFlush)
		p.lk.Unlock()

		if needsFlush {
			if err := p.Flush(context.Background()); err != nil {
				log.Error("failed to flush batch", "err", err)
			}
		}
	}
}

func (p *DbPersistence) SetEventBroadcaster(brc func(*events.XRPCStreamEvent)) {
	p.broadcast = brc
}

func (p *DbPersistence) Flush(ctx context.Context) error {
	p.lk.Lock()
	defer p.lk.Unlock()
	return p.flushBatchLocked(ctx)
}

func (p *DbPersistence) flushBatchLocked(ctx context.Context) error {
	// TODO: we technically don't need to hold the lock through the database
	// operation, all we need to do is swap the batch out, and ensure nobody
	// else tries to enter this function to flush another batch while we are
	// flushing. I'll leave that for a later optimization

	records := make([]*RepoEventRecord, len(p.batch))
	for i, item := range p.batch {
		records[i] = item.Record
	}

	if err := p.db.CreateInBatches(records, 50).Error; err != nil {
		return fmt.Errorf("failed to create records: %w", err)
	}

	for i, item := range records {
		e := p.batch[i].Event
		switch {
		case e.RepoCommit != nil:
			e.RepoCommit.Seq = int64(item.Seq)
		case e.RepoSync != nil:
			e.RepoSync.Seq = int64(item.Seq)
		case e.RepoIdentity != nil:
			e.RepoIdentity.Seq = int64(item.Seq)
		case e.RepoAccount != nil:
			e.RepoAccount.Seq = int64(item.Seq)
		default:
			return fmt.Errorf("unknown event type")
		}
		p.broadcast(e)
	}

	p.batch = []*PersistenceBatchItem{}
	p.lastFlush = time.Now()

	return nil
}

func (p *DbPersistence) AddItemToBatch(ctx context.Context, rec *RepoEventRecord, evt *events.XRPCStreamEvent) error {
	p.lk.Lock()
	defer p.lk.Unlock()
	p.batch = append(p.batch, &PersistenceBatchItem{
		Record: rec,
		Event:  evt,
	})

	if len(p.batch) >= p.batchOptions.MaxBatchSize {
		if err := p.flushBatchLocked(ctx); err != nil {
			return fmt.Errorf("failed to flush batch at max size: %w", err)
		}
	}

	return nil
}

func (p *DbPersistence) Persist(ctx context.Context, e *events.XRPCStreamEvent) error {
	var rer *RepoEventRecord
	var err error

	switch {
	case e.RepoCommit != nil:
		rer, err = p.RecordFromRepoCommit(ctx, e.RepoCommit)
		if err != nil {
			return err
		}
	case e.RepoSync != nil:
		rer, err = p.RecordFromRepoSync(ctx, e.RepoSync)
		if err != nil {
			return err
		}
	case e.RepoIdentity != nil:
		rer, err = p.RecordFromRepoIdentity(ctx, e.RepoIdentity)
		if err != nil {
			return err
		}
	case e.RepoAccount != nil:
		rer, err = p.RecordFromRepoAccount(ctx, e.RepoAccount)
		if err != nil {
			return err
		}
	default:
		return nil
	}

	if err := p.AddItemToBatch(ctx, rer, e); err != nil {
		return err
	}

	return nil
}

func (p *DbPersistence) RecordFromRepoIdentity(ctx context.Context, evt *comatproto.SyncSubscribeRepos_Identity) (*RepoEventRecord, error) {
	t, err := time.Parse(util.ISO8601, evt.Time)
	if err != nil {
		return nil, err
	}

	uid, err := p.uidForDid(ctx, evt.Did)
	if err != nil {
		return nil, err
	}

	return &RepoEventRecord{
		Repo: uid,
		Type: "repo_identity",
		Time: t,
	}, nil
}

func (p *DbPersistence) RecordFromRepoAccount(ctx context.Context, evt *comatproto.SyncSubscribeRepos_Account) (*RepoEventRecord, error) {
	t, err := time.Parse(util.ISO8601, evt.Time)
	if err != nil {
		return nil, err
	}

	uid, err := p.uidForDid(ctx, evt.Did)
	if err != nil {
		return nil, err
	}

	return &RepoEventRecord{
		Repo:   uid,
		Type:   "repo_account",
		Time:   t,
		Active: evt.Active,
		Status: evt.Status,
	}, nil
}

func (p *DbPersistence) RecordFromRepoCommit(ctx context.Context, evt *comatproto.SyncSubscribeRepos_Commit) (*RepoEventRecord, error) {
	// TODO: hack hack hack
	if len(evt.Ops) > 8192 {
		log.Error("(VERY BAD) truncating ops field in outgoing event", "len", len(evt.Ops))
		evt.Ops = evt.Ops[:8192]
	}

	uid, err := p.uidForDid(ctx, evt.Repo)
	if err != nil {
		return nil, err
	}

	var blobs []byte
	if len(evt.Blobs) > 0 {
		b, err := json.Marshal(evt.Blobs)
		if err != nil {
			return nil, err
		}
		blobs = b
	}

	t, err := time.Parse(util.ISO8601, evt.Time)
	if err != nil {
		return nil, err
	}

	rer := RepoEventRecord{
		Commit: &models.DbCID{CID: cid.Cid(evt.Commit)},
		//Prev
		Repo:   uid,
		Type:   "repo_append", // TODO: refactor to "#commit"? can "rebase" come through this path?
		Blobs:  blobs,
		Time:   t,
		Rebase: evt.Rebase,
		Rev:    evt.Rev,
		Since:  evt.Since,
	}

	opsb, err := json.Marshal(evt.Ops)
	if err != nil {
		return nil, err
	}
	rer.Ops = opsb

	return &rer, nil
}

func (p *DbPersistence) RecordFromRepoSync(ctx context.Context, evt *comatproto.SyncSubscribeRepos_Sync) (*RepoEventRecord, error) {

	uid, err := p.uidForDid(ctx, evt.Did)
	if err != nil {
		return nil, err
	}

	t, err := time.Parse(util.ISO8601, evt.Time)
	if err != nil {
		return nil, err
	}

	rer := RepoEventRecord{
		Repo: uid,
		Type: "repo_sync",
		Time: t,
		Rev:  evt.Rev,
	}

	return &rer, nil
}

func (p *DbPersistence) Playback(ctx context.Context, since int64, cb func(*events.XRPCStreamEvent) error) error {
	pageSize := 1000

	for {
		rows, err := p.db.Model(&RepoEventRecord{}).Where("seq > ?", since).Order("seq asc").Limit(pageSize).Rows()
		if err != nil {
			return err
		}
		defer rows.Close()

		hasRows := false

		batch := make([]*RepoEventRecord, 0, p.batchOptions.PlaybackBatchSize)
		for rows.Next() {
			hasRows = true

			var evt RepoEventRecord
			if err := p.db.ScanRows(rows, &evt); err != nil {
				return err
			}

			// Advance the since cursor
			since = int64(evt.Seq)

			batch = append(batch, &evt)

			if len(batch) >= p.batchOptions.PlaybackBatchSize {
				if err := p.hydrateBatch(ctx, batch, cb); err != nil {
					return err
				}

				batch = batch[:0]
			}
		}

		if len(batch) > 0 {
			if err := p.hydrateBatch(ctx, batch, cb); err != nil {
				return err
			}
		}

		if !hasRows {
			break
		}
	}

	return nil
}

func (p *DbPersistence) hydrateBatch(ctx context.Context, batch []*RepoEventRecord, cb func(*events.XRPCStreamEvent) error) error {
	evts := make([]*events.XRPCStreamEvent, len(batch))

	type Result struct {
		Event *events.XRPCStreamEvent
		Index int
		Err   error
	}

	resultChan := make(chan Result, len(batch))

	// Semaphore pattern for limiting concurrent goroutines
	sem := make(chan struct{}, p.batchOptions.HydrationConcurrency)
	var wg sync.WaitGroup

	for i, record := range batch {
		wg.Add(1)
		go func(i int, record *RepoEventRecord) {
			defer wg.Done()
			sem <- struct{}{}
			// release the semaphore at the end of the goroutine
			defer func() { <-sem }()

			var streamEvent *events.XRPCStreamEvent
			var err error

			switch {
			case record.Commit != nil:
				streamEvent, err = p.hydrateCommit(ctx, record)
			case record.Type == "repo_sync":
				streamEvent, err = p.hydrateSyncEvent(ctx, record)
			case record.Type == "repo_identity":
				streamEvent, err = p.hydrateIdentityEvent(ctx, record)
			case record.Type == "repo_account":
				streamEvent, err = p.hydrateAccountEvent(ctx, record)
			default:
				err = fmt.Errorf("unknown event type: %s", record.Type)
			}

			resultChan <- Result{Event: streamEvent, Index: i, Err: err}

		}(i, record)
	}

	go func() {
		wg.Wait()
		close(resultChan)
	}()

	cur := 0
	for result := range resultChan {
		if result.Err != nil {
			return result.Err
		}

		evts[result.Index] = result.Event

		for ; cur < len(evts) && evts[cur] != nil; cur++ {
			if err := cb(evts[cur]); err != nil {
				return err
			}
		}
	}

	return nil
}

func (p *DbPersistence) uidForDid(ctx context.Context, did string) (models.Uid, error) {
	if uid, ok := p.didCache.Get(did); ok {
		return uid, nil
	}

	var u models.ActorInfo
	if err := p.db.First(&u, "did = ?", did).Error; err != nil {
		return 0, err
	}

	p.didCache.Add(did, u.Uid)

	return u.Uid, nil
}

func (p *DbPersistence) didForUid(ctx context.Context, uid models.Uid) (string, error) {
	if did, ok := p.uidCache.Get(uid); ok {
		return did, nil
	}

	var u models.ActorInfo
	if err := p.db.First(&u, "uid = ?", uid).Error; err != nil {
		return "", err
	}

	p.uidCache.Add(uid, u.Did)

	return u.Did, nil
}

func (p *DbPersistence) hydrateIdentityEvent(ctx context.Context, rer *RepoEventRecord) (*events.XRPCStreamEvent, error) {
	did, err := p.didForUid(ctx, rer.Repo)
	if err != nil {
		return nil, err
	}

	return &events.XRPCStreamEvent{
		RepoIdentity: &comatproto.SyncSubscribeRepos_Identity{
			Did:  did,
			Time: rer.Time.Format(util.ISO8601),
		},
	}, nil
}

func (p *DbPersistence) hydrateAccountEvent(ctx context.Context, rer *RepoEventRecord) (*events.XRPCStreamEvent, error) {
	did, err := p.didForUid(ctx, rer.Repo)
	if err != nil {
		return nil, err
	}

	return &events.XRPCStreamEvent{
		RepoAccount: &comatproto.SyncSubscribeRepos_Account{
			Did:    did,
			Time:   rer.Time.Format(util.ISO8601),
			Active: rer.Active,
			Status: rer.Status,
		},
	}, nil
}

func (p *DbPersistence) hydrateCommit(ctx context.Context, rer *RepoEventRecord) (*events.XRPCStreamEvent, error) {
	if rer.Commit == nil {
		return nil, fmt.Errorf("commit is nil")
	}

	var blobs []string
	if len(rer.Blobs) > 0 {
		if err := json.Unmarshal(rer.Blobs, &blobs); err != nil {
			return nil, err
		}
	}
	var blobCIDs []lexutil.LexLink
	for _, b := range blobs {
		c, err := cid.Decode(b)
		if err != nil {
			return nil, err
		}
		blobCIDs = append(blobCIDs, lexutil.LexLink(c))
	}

	did, err := p.didForUid(ctx, rer.Repo)
	if err != nil {
		return nil, err
	}

	var ops []*comatproto.SyncSubscribeRepos_RepoOp
	if err := json.Unmarshal(rer.Ops, &ops); err != nil {
		return nil, err
	}

	out := &comatproto.SyncSubscribeRepos_Commit{
		Seq:    int64(rer.Seq),
		Repo:   did,
		Commit: lexutil.LexLink(rer.Commit.CID),
		Time:   rer.Time.Format(util.ISO8601),
		Blobs:  blobCIDs,
		Rebase: rer.Rebase,
		Ops:    ops,
		Rev:    rer.Rev,
		Since:  rer.Since,
	}

	cs, err := p.readCarSlice(ctx, rer)
	if err != nil {
		return nil, fmt.Errorf("read car slice (%s): %w", rer.Commit.CID, err)
	}

	if len(cs) > carstore.MaxSliceLength {
		out.TooBig = true
		out.Blocks = []byte{}
	} else {
		out.Blocks = cs
	}

	return &events.XRPCStreamEvent{RepoCommit: out}, nil
}

func (p *DbPersistence) hydrateSyncEvent(ctx context.Context, rer *RepoEventRecord) (*events.XRPCStreamEvent, error) {

	did, err := p.didForUid(ctx, rer.Repo)
	if err != nil {
		return nil, err
	}

	evt := &comatproto.SyncSubscribeRepos_Sync{
		Seq:  int64(rer.Seq),
		Did:  did,
		Time: rer.Time.Format(util.ISO8601),
		Rev:  rer.Rev,
	}

	cs, err := p.readCarSlice(ctx, rer)
	if err != nil {
		return nil, fmt.Errorf("read car slice: %w", err)
	}
	evt.Blocks = cs

	return &events.XRPCStreamEvent{RepoSync: evt}, nil
}

func (p *DbPersistence) readCarSlice(ctx context.Context, rer *RepoEventRecord) ([]byte, error) {

	buf := new(bytes.Buffer)
	if err := p.cs.ReadUserCar(ctx, rer.Repo, rer.Rev, true, buf); err != nil {
		return nil, err
	}

	return buf.Bytes(), nil
}

func (p *DbPersistence) TakeDownRepo(ctx context.Context, usr models.Uid) error {
	return p.deleteAllEventsForUser(ctx, usr)
}

func (p *DbPersistence) deleteAllEventsForUser(ctx context.Context, usr models.Uid) error {
	if err := p.db.Where("repo = ?", usr).Delete(&RepoEventRecord{}).Error; err != nil {
		return err
	}

	return nil
}

func (p *DbPersistence) Shutdown(context.Context) error {
	return nil
}
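A minimal wiring sketch, not part of this file: it shows the intended call order (construct, attach a broadcaster, persist live events, replay on demand), written as if it sat alongside the package in an example or test file to avoid guessing the fork's module path. It assumes a *gorm.DB and a carstore.CarStore are already set up elsewhere, and that any DID appearing in a persisted event already has a models.ActorInfo row, since uidForDid resolves DIDs through that table.

package dbpersist

import (
	"context"
	"fmt"
	"time"

	comatproto "github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/carstore"
	"github.com/bluesky-social/indigo/events"
	"github.com/bluesky-social/indigo/util"
	"gorm.io/gorm"
)

// runPersisterExample is a hypothetical helper, not part of the package API.
func runPersisterExample(ctx context.Context, db *gorm.DB, cs carstore.CarStore) error {
	p, err := NewDbPersistence(db, cs, DefaultOptions())
	if err != nil {
		return err
	}
	defer p.Shutdown(ctx)

	// Events are broadcast only after the batch flush assigns their Seq.
	p.SetEventBroadcaster(func(e *events.XRPCStreamEvent) {
		fmt.Println("broadcast event")
	})

	// Persist buffers the event; the background flusher (or hitting
	// MaxBatchSize) writes it to the database.
	evt := &events.XRPCStreamEvent{
		RepoIdentity: &comatproto.SyncSubscribeRepos_Identity{
			Did:  "did:plc:example", // assumed to already exist in the ActorInfo table
			Time: time.Now().UTC().Format(util.ISO8601),
		},
	}
	if err := p.Persist(ctx, evt); err != nil {
		return err
	}
	// Force a flush instead of waiting for the CheckBatchInterval ticker.
	if err := p.Flush(ctx); err != nil {
		return err
	}

	// Replay everything after seq 0 to a consumer callback.
	return p.Playback(ctx, 0, func(e *events.XRPCStreamEvent) error {
		fmt.Println("replayed event")
		return nil
	})
}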