package bgs

import (
	"context"
	"errors"
	"fmt"
	"log/slog"
	"math/rand"
	"strings"
	"sync"
	"time"

	"github.com/RussellLuo/slidingwindow"
	comatproto "github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/events"
	"github.com/bluesky-social/indigo/events/schedulers/parallel"
	"github.com/bluesky-social/indigo/models"
	"go.opentelemetry.io/otel"
	"golang.org/x/time/rate"

	"github.com/gorilla/websocket"
	pq "github.com/lib/pq"
	"gorm.io/gorm"
)

var log = slog.Default().With("system", "bgs")

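// IndexCallback is the hook the Slurper invokes for each event received from
// an upstream PDS; errors it returns are logged by the stream handlers but do
// not tear down the subscription.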
type IndexCallback func(context.Context, *models.PDS, *events.XRPCStreamEvent) error

// TODO: rename me
type Slurper struct {
	cb IndexCallback

	db *gorm.DB

	lk     sync.Mutex
	active map[string]*activeSub

	LimitMux              sync.RWMutex
	Limiters              map[uint]*Limiters
	DefaultPerSecondLimit int64
	DefaultPerHourLimit   int64
	DefaultPerDayLimit    int64

	DefaultCrawlLimit rate.Limit
	DefaultRepoLimit  int64
	ConcurrencyPerPDS int64
	MaxQueuePerPDS    int64

	NewPDSPerDayLimiter *slidingwindow.Limiter

	newSubsDisabled bool
	trustedDomains  []string

	shutdownChan   chan bool
	shutdownResult chan []error

	ssl bool
}

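// Limiters bundles the sliding-window event-rate limiters applied to a single
// PDS at three horizons: per second, per hour, and per day.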
type Limiters struct {
	PerSecond *slidingwindow.Limiter
	PerHour   *slidingwindow.Limiter
	PerDay    *slidingwindow.Limiter
}

type SlurperOptions struct {
	SSL                   bool
	DefaultPerSecondLimit int64
	DefaultPerHourLimit   int64
	DefaultPerDayLimit    int64
	DefaultCrawlLimit     rate.Limit
	DefaultRepoLimit      int64
	ConcurrencyPerPDS     int64
	MaxQueuePerPDS        int64
}

func DefaultSlurperOptions() *SlurperOptions {
	return &SlurperOptions{
		SSL:                   false,
		DefaultPerSecondLimit: 50,
		DefaultPerHourLimit:   2500,
		DefaultPerDayLimit:    20_000,
		DefaultCrawlLimit:     rate.Limit(5),
		DefaultRepoLimit:      100,
		ConcurrencyPerPDS:     100,
		MaxQueuePerPDS:        1_000,
	}
}

type activeSub struct {
	pds    *models.PDS
	lk     sync.RWMutex
	ctx    context.Context
	cancel func()
}

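// NewSlurper constructs a Slurper, runs the SlurpConfig migration, loads the
// persisted config, and starts a background goroutine that periodically
// flushes PDS cursors to the database.
//
// A minimal wiring sketch (illustrative only; "db" and "ix.HandleStreamEvent"
// stand in for the caller's database handle and indexing callback):
//
//	s, err := NewSlurper(db, ix.HandleStreamEvent, DefaultSlurperOptions())
//	if err != nil {
//		return err
//	}
//	if err := s.RestartAll(); err != nil {
//		return err
//	}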
func NewSlurper(db *gorm.DB, cb IndexCallback, opts *SlurperOptions) (*Slurper, error) {
	if opts == nil {
		opts = DefaultSlurperOptions()
	}
	if err := db.AutoMigrate(&SlurpConfig{}); err != nil {
		return nil, fmt.Errorf("migrating SlurpConfig: %w", err)
	}
	s := &Slurper{
		cb:                    cb,
		db:                    db,
		active:                make(map[string]*activeSub),
		Limiters:              make(map[uint]*Limiters),
		DefaultPerSecondLimit: opts.DefaultPerSecondLimit,
		DefaultPerHourLimit:   opts.DefaultPerHourLimit,
		DefaultPerDayLimit:    opts.DefaultPerDayLimit,
		DefaultCrawlLimit:     opts.DefaultCrawlLimit,
		DefaultRepoLimit:      opts.DefaultRepoLimit,
		ConcurrencyPerPDS:     opts.ConcurrencyPerPDS,
		MaxQueuePerPDS:        opts.MaxQueuePerPDS,
		ssl:                   opts.SSL,
		shutdownChan:          make(chan bool),
		shutdownResult:        make(chan []error),
	}
	if err := s.loadConfig(); err != nil {
		return nil, err
	}

	// Start a goroutine to flush cursors to the DB every 10s
	go func() {
		for {
			select {
			case <-s.shutdownChan:
				log.Info("flushing PDS cursors on shutdown")
				ctx := context.Background()
				ctx, span := otel.Tracer("feedmgr").Start(ctx, "CursorFlusherShutdown")
				defer span.End()
				var errs []error
				if errs = s.flushCursors(ctx); len(errs) > 0 {
					for _, err := range errs {
						log.Error("failed to flush cursors on shutdown", "err", err)
					}
				}
				log.Info("done flushing PDS cursors on shutdown")
				s.shutdownResult <- errs
				return
			case <-time.After(time.Second * 10):
				log.Debug("flushing PDS cursors")
				ctx := context.Background()
				ctx, span := otel.Tracer("feedmgr").Start(ctx, "CursorFlusher")
				if errs := s.flushCursors(ctx); len(errs) > 0 {
					for _, err := range errs {
						log.Error("failed to flush cursors", "err", err)
					}
				}
				// End the span explicitly; a defer here would pile up for the
				// lifetime of the goroutine.
				span.End()
				log.Debug("done flushing PDS cursors")
			}
		}
	}()

	return s, nil
}

// windowFunc builds the in-process window implementation shared by all
// sliding-window limiters below.
func windowFunc() (slidingwindow.Window, slidingwindow.StopFunc) {
	return slidingwindow.NewLocalWindow()
}

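// GetLimiters returns the limiters for a PDS, or nil if none have been
// created yet.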
func (s *Slurper) GetLimiters(pdsID uint) *Limiters {
	s.LimitMux.RLock()
	defer s.LimitMux.RUnlock()
	return s.Limiters[pdsID]
}

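// GetOrCreateLimiters returns the limiters for a PDS, lazily creating them
// with the supplied limits on first use.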
func (s *Slurper) GetOrCreateLimiters(pdsID uint, perSecLimit int64, perHourLimit int64, perDayLimit int64) *Limiters {
	// Take the write lock up front: the map may be mutated below, and
	// inserting under a read lock would be a data race.
	s.LimitMux.Lock()
	defer s.LimitMux.Unlock()
	lim, ok := s.Limiters[pdsID]
	if !ok {
		perSec, _ := slidingwindow.NewLimiter(time.Second, perSecLimit, windowFunc)
		perHour, _ := slidingwindow.NewLimiter(time.Hour, perHourLimit, windowFunc)
		perDay, _ := slidingwindow.NewLimiter(time.Hour*24, perDayLimit, windowFunc)
		lim = &Limiters{
			PerSecond: perSec,
			PerHour:   perHour,
			PerDay:    perDay,
		}
		s.Limiters[pdsID] = lim
	}

	return lim
}

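// SetLimits overwrites the per-second, per-hour, and per-day limits for a
// PDS, creating the limiters first if they do not yet exist.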
func (s *Slurper) SetLimits(pdsID uint, perSecLimit int64, perHourLimit int64, perDayLimit int64) {
	s.LimitMux.Lock()
	defer s.LimitMux.Unlock()
	lim, ok := s.Limiters[pdsID]
	if !ok {
		perSec, _ := slidingwindow.NewLimiter(time.Second, perSecLimit, windowFunc)
		perHour, _ := slidingwindow.NewLimiter(time.Hour, perHourLimit, windowFunc)
		perDay, _ := slidingwindow.NewLimiter(time.Hour*24, perDayLimit, windowFunc)
		lim = &Limiters{
			PerSecond: perSec,
			PerHour:   perHour,
			PerDay:    perDay,
		}
		s.Limiters[pdsID] = lim
	}

	lim.PerSecond.SetLimit(perSecLimit)
	lim.PerHour.SetLimit(perHourLimit)
	lim.PerDay.SetLimit(perDayLimit)
}

// Shutdown signals the cursor-flush goroutine to perform a final flush and
// blocks until it finishes, returning any errors encountered.
func (s *Slurper) Shutdown() []error {
	s.shutdownChan <- true
	log.Info("waiting for slurper shutdown")
	errs := <-s.shutdownResult
	if len(errs) > 0 {
		for _, err := range errs {
			log.Error("shutdown error", "err", err)
		}
	}
	log.Info("slurper shutdown complete")
	return errs
}

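// loadConfig reads the persisted SlurpConfig row, creating a default row on
// first run, and initializes the in-memory state derived from it.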
func (s *Slurper) loadConfig() error {
	var sc SlurpConfig
	if err := s.db.Find(&sc).Error; err != nil {
		return err
	}

	if sc.ID == 0 {
		if err := s.db.Create(&SlurpConfig{}).Error; err != nil {
			return err
		}
	}

	s.newSubsDisabled = sc.NewSubsDisabled
	s.trustedDomains = sc.TrustedDomains

	s.NewPDSPerDayLimiter, _ = slidingwindow.NewLimiter(time.Hour*24, sc.NewPDSPerDayLimit, windowFunc)

	return nil
}

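// SlurpConfig is the persisted, admin-tunable subscription configuration; a
// single row (id = 1) is created on first run and updated in place.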
type SlurpConfig struct {
	gorm.Model

	NewSubsDisabled   bool
	TrustedDomains    pq.StringArray `gorm:"type:text[]"`
	NewPDSPerDayLimit int64
}

func (s *Slurper) SetNewSubsDisabled(dis bool) error {
	s.lk.Lock()
	defer s.lk.Unlock()

	if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("new_subs_disabled", dis).Error; err != nil {
		return err
	}

	s.newSubsDisabled = dis
	return nil
}

func (s *Slurper) GetNewSubsDisabledState() bool {
	s.lk.Lock()
	defer s.lk.Unlock()
	return s.newSubsDisabled
}

func (s *Slurper) SetNewPDSPerDayLimit(limit int64) error {
	s.lk.Lock()
	defer s.lk.Unlock()

	if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("new_pds_per_day_limit", limit).Error; err != nil {
		return err
	}

	s.NewPDSPerDayLimiter.SetLimit(limit)
	return nil
}

func (s *Slurper) GetNewPDSPerDayLimit() int64 {
	s.lk.Lock()
	defer s.lk.Unlock()
	return s.NewPDSPerDayLimiter.Limit()
}

func (s *Slurper) AddTrustedDomain(domain string) error {
	s.lk.Lock()
	defer s.lk.Unlock()

	if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("trusted_domains", gorm.Expr("array_append(trusted_domains, ?)", domain)).Error; err != nil {
		return err
	}

	s.trustedDomains = append(s.trustedDomains, domain)
	return nil
}

func (s *Slurper) RemoveTrustedDomain(domain string) error {
	s.lk.Lock()
	defer s.lk.Unlock()

	if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("trusted_domains", gorm.Expr("array_remove(trusted_domains, ?)", domain)).Error; err != nil {
		if errors.Is(err, gorm.ErrRecordNotFound) {
			return nil
		}
		return err
	}

	for i, d := range s.trustedDomains {
		if d == domain {
			s.trustedDomains = append(s.trustedDomains[:i], s.trustedDomains[i+1:]...)
			break
		}
	}

	return nil
}

func (s *Slurper) SetTrustedDomains(domains []string) error {
	s.lk.Lock()
	defer s.lk.Unlock()

	if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("trusted_domains", domains).Error; err != nil {
		return err
	}

	s.trustedDomains = domains
	return nil
}

func (s *Slurper) GetTrustedDomains() []string {
	s.lk.Lock()
	defer s.lk.Unlock()
	return s.trustedDomains
}

var ErrNewSubsDisabled = fmt.Errorf("new subscriptions temporarily disabled")

// canSlurpHost checks whether a host is allowed to be subscribed to: the
// new-PDS-per-day limit is enforced first, then trusted domains (exact
// matches, or wildcard entries like "*.example.com" that match any
// subdomain) are accepted unconditionally, and any other host is accepted
// only while new subscriptions are enabled.
// Must be called with the slurper lock held.
func (s *Slurper) canSlurpHost(host string) bool {
	// Check if we're over the limit for new PDSs today
	if !s.NewPDSPerDayLimiter.Allow() {
		return false
	}

	// Check if the host is a trusted domain
	for _, d := range s.trustedDomains {
		// If the domain starts with a *., it's a wildcard
		if strings.HasPrefix(d, "*.") {
			// Cut off the * so we have .domain.com
			if strings.HasSuffix(host, strings.TrimPrefix(d, "*")) {
				return true
			}
		} else {
			if host == d {
				return true
			}
		}
	}

	return !s.newSubsDisabled
}

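// SubscribeToPds ensures there is an active firehose subscription to the
// given host, returning immediately if one already exists. Previously unseen
// hosts must pass canSlurpHost (unless adminOverride is set) and are
// persisted as new PDS records; a non-nil rateOverrides replaces the default
// per-PDS limits.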
func (s *Slurper) SubscribeToPds(ctx context.Context, host string, reg bool, adminOverride bool, rateOverrides *PDSRates) error {
	// TODO: for performance, lock on the hostname instead of global
	s.lk.Lock()
	defer s.lk.Unlock()

	_, ok := s.active[host]
	if ok {
		return nil
	}

	var peering models.PDS
	if err := s.db.Find(&peering, "host = ?", host).Error; err != nil {
		return err
	}

	if peering.Blocked {
		return fmt.Errorf("cannot subscribe to blocked pds")
	}

	if peering.ID == 0 {
		if !adminOverride && !s.canSlurpHost(host) {
			return ErrNewSubsDisabled
		}
		// New PDS!
		npds := models.PDS{
			Host:             host,
			SSL:              s.ssl,
			Registered:       reg,
			RateLimit:        float64(s.DefaultPerSecondLimit),
			HourlyEventLimit: s.DefaultPerHourLimit,
			DailyEventLimit:  s.DefaultPerDayLimit,
			CrawlRateLimit:   float64(s.DefaultCrawlLimit),
			RepoLimit:        s.DefaultRepoLimit,
		}
		if rateOverrides != nil {
			npds.RateLimit = float64(rateOverrides.PerSecond)
			npds.HourlyEventLimit = rateOverrides.PerHour
			npds.DailyEventLimit = rateOverrides.PerDay
			npds.CrawlRateLimit = float64(rateOverrides.CrawlRate)
			npds.RepoLimit = rateOverrides.RepoLimit
		}
		if err := s.db.Create(&npds).Error; err != nil {
			return err
		}

		peering = npds
	}

	if !peering.Registered && reg {
		peering.Registered = true
		if err := s.db.Model(models.PDS{}).Where("id = ?", peering.ID).Update("registered", true).Error; err != nil {
			return err
		}
	}

	ctx, cancel := context.WithCancel(context.Background())
	sub := activeSub{
		pds:    &peering,
		ctx:    ctx,
		cancel: cancel,
	}
	s.active[host] = &sub

	s.GetOrCreateLimiters(peering.ID, int64(peering.RateLimit), peering.HourlyEventLimit, peering.DailyEventLimit)

	go s.subscribeWithRedialer(ctx, &peering, &sub)

	return nil
}

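// RestartAll spawns subscriptions for every registered, unblocked PDS in the
// database; it is typically called once at startup to resume crawling.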
func (s *Slurper) RestartAll() error {
	s.lk.Lock()
	defer s.lk.Unlock()

	var all []models.PDS
	if err := s.db.Find(&all, "registered = true AND blocked = false").Error; err != nil {
		return err
	}

	for _, pds := range all {
		pds := pds // copy the loop variable; its address is retained by the sub

		ctx, cancel := context.WithCancel(context.Background())
		sub := activeSub{
			pds:    &pds,
			ctx:    ctx,
			cancel: cancel,
		}
		s.active[pds.Host] = &sub

		// Check if we've already got a limiter for this PDS
		s.GetOrCreateLimiters(pds.ID, int64(pds.RateLimit), pds.HourlyEventLimit, pds.DailyEventLimit)
		go s.subscribeWithRedialer(ctx, &pds, &sub)
	}

	return nil
}

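// subscribeWithRedialer runs the websocket subscription loop for a single
// PDS, redialing with increasing backoff on failure. After 16 consecutive
// failed dials the PDS is marked unregistered and the loop exits; the host is
// removed from the active set on the way out.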
func (s *Slurper) subscribeWithRedialer(ctx context.Context, host *models.PDS, sub *activeSub) {
	defer func() {
		s.lk.Lock()
		defer s.lk.Unlock()

		delete(s.active, host.Host)
	}()

	d := websocket.Dialer{
		HandshakeTimeout: time.Second * 5,
	}

	protocol := "ws"
	if s.ssl {
		protocol = "wss"
	}

	// Special case `.host.bsky.network` PDSs to rewind cursor by 200 events to smooth over unclean shutdowns
	if strings.HasSuffix(host.Host, ".host.bsky.network") && host.Cursor > 200 {
		host.Cursor -= 200
	}

	cursor := host.Cursor

	connectedInbound.Inc()
	defer connectedInbound.Dec()
	// TODO:? maybe keep a gauge of 'in retry backoff' sources?

	var backoff int
	for {
		select {
		case <-ctx.Done():
			return
		default:
		}

		url := fmt.Sprintf("%s://%s/xrpc/com.atproto.sync.subscribeRepos?cursor=%d", protocol, host.Host, cursor)
		con, res, err := d.DialContext(ctx, url, nil)
		if err != nil {
			log.Warn("dialing failed", "pdsHost", host.Host, "err", err, "backoff", backoff)
			time.Sleep(sleepForBackoff(backoff))
			backoff++

			if backoff > 15 {
				log.Warn("pds does not appear to be online, disabling for now", "pdsHost", host.Host)
				if err := s.db.Model(&models.PDS{}).Where("id = ?", host.ID).Update("registered", false).Error; err != nil {
					log.Error("failed to unregister failing pds", "err", err)
				}

				return
			}

			continue
		}

		log.Info("event subscription response", "code", res.StatusCode)

		curCursor := cursor
		if err := s.handleConnection(ctx, host, con, &cursor, sub); err != nil {
			if errors.Is(err, ErrTimeoutShutdown) {
				log.Info("shutting down pds subscription after timeout", "host", host.Host, "time", EventsTimeout)
				return
			}
			log.Warn("connection to host failed", "host", host.Host, "err", err)
		}

		// Reset the backoff if the connection made progress
		if cursor > curCursor {
			backoff = 0
		}
	}
}

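// sleepForBackoff converts a consecutive-failure count into a retry delay:
// zero for the first attempt, roughly 2s per failure plus up to 1s of jitter
// for the next nine, and a flat 30s from there on.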
func sleepForBackoff(b int) time.Duration {
	if b == 0 {
		return 0
	}

	if b < 10 {
		// scale by time.Second: a bare time.Duration(b) would be nanoseconds
		return (time.Duration(b) * 2 * time.Second) + (time.Millisecond * time.Duration(rand.Intn(1000)))
	}

	return time.Second * 30
}

var ErrTimeoutShutdown = fmt.Errorf("timed out waiting for new events")

var EventsTimeout = time.Minute

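// handleConnection consumes the repo event stream on an established
// websocket, feeding each event through the PDS's rate limiters into a
// parallel scheduler that invokes the index callback, and advancing the
// in-memory cursor as events arrive.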
func (s *Slurper) handleConnection(ctx context.Context, host *models.PDS, con *websocket.Conn, lastCursor *int64, sub *activeSub) error {
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()

	rsc := &events.RepoStreamCallbacks{
		RepoCommit: func(evt *comatproto.SyncSubscribeRepos_Commit) error {
			log.Debug("got remote repo event", "pdsHost", host.Host, "repo", evt.Repo, "seq", evt.Seq)
			if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{
				RepoCommit: evt,
			}); err != nil {
				log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err)
			}
			*lastCursor = evt.Seq

			if err := s.updateCursor(sub, *lastCursor); err != nil {
				return fmt.Errorf("updating cursor: %w", err)
			}

			return nil
		},
		RepoSync: func(evt *comatproto.SyncSubscribeRepos_Sync) error {
			log.Info("sync event", "did", evt.Did, "pdsHost", host.Host, "seq", evt.Seq)
			if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{
				RepoSync: evt,
			}); err != nil {
				log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err)
			}
			*lastCursor = evt.Seq

			if err := s.updateCursor(sub, *lastCursor); err != nil {
				return fmt.Errorf("updating cursor: %w", err)
			}

			return nil
		},
		RepoInfo: func(info *comatproto.SyncSubscribeRepos_Info) error {
			log.Info("info event", "name", info.Name, "message", info.Message, "pdsHost", host.Host)
			return nil
		},
		RepoIdentity: func(ident *comatproto.SyncSubscribeRepos_Identity) error {
			log.Info("identity event", "did", ident.Did)
			if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{
				RepoIdentity: ident,
			}); err != nil {
				log.Error("failed handling event", "host", host.Host, "seq", ident.Seq, "err", err)
			}
			*lastCursor = ident.Seq

			if err := s.updateCursor(sub, *lastCursor); err != nil {
				return fmt.Errorf("updating cursor: %w", err)
			}

			return nil
		},
		RepoAccount: func(acct *comatproto.SyncSubscribeRepos_Account) error {
			log.Info("account event", "did", acct.Did, "status", acct.Status)
			if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{
				RepoAccount: acct,
			}); err != nil {
				log.Error("failed handling event", "host", host.Host, "seq", acct.Seq, "err", err)
			}
			*lastCursor = acct.Seq

			if err := s.updateCursor(sub, *lastCursor); err != nil {
				return fmt.Errorf("updating cursor: %w", err)
			}

			return nil
		},
		// TODO: all the other event types (handle change, migration, etc)
		Error: func(errf *events.ErrorFrame) error {
			switch errf.Error {
			case "FutureCursor":
				// if we get a FutureCursor frame, reset our sequence number for this host
				if err := s.db.Table("pds").Where("id = ?", host.ID).Update("cursor", 0).Error; err != nil {
					return err
				}

				*lastCursor = 0
				return fmt.Errorf("got FutureCursor frame, reset cursor tracking for host")
			default:
				return fmt.Errorf("error frame: %s: %s", errf.Error, errf.Message)
			}
		},
	}

	lims := s.GetOrCreateLimiters(host.ID, int64(host.RateLimit), host.HourlyEventLimit, host.DailyEventLimit)

	limiters := []*slidingwindow.Limiter{
		lims.PerSecond,
		lims.PerHour,
		lims.PerDay,
	}

	instrumentedRSC := events.NewInstrumentedRepoStreamCallbacks(limiters, rsc.EventHandler)

	// Honor the configured per-PDS concurrency and queue depth rather than
	// hard-coding the defaults
	pool := parallel.NewScheduler(
		int(s.ConcurrencyPerPDS),
		int(s.MaxQueuePerPDS),
		con.RemoteAddr().String(),
		instrumentedRSC.EventHandler,
	)
	return events.HandleRepoStream(ctx, con, pool, nil)
}

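// updateCursor records the latest seen cursor in memory only; the periodic
// flusher (flushCursors) persists it to the database.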
func (s *Slurper) updateCursor(sub *activeSub, curs int64) error {
	sub.lk.Lock()
	defer sub.lk.Unlock()
	sub.pds.Cursor = curs
	return nil
}

type cursorSnapshot struct {
	id     uint
	cursor int64
}

// flushCursors updates the PDS cursors in the DB for all active subscriptions
func (s *Slurper) flushCursors(ctx context.Context) []error {
	ctx, span := otel.Tracer("feedmgr").Start(ctx, "flushCursors")
	defer span.End()

	var cursors []cursorSnapshot

	s.lk.Lock()
	// Iterate over active subs and copy the current cursor
	for _, sub := range s.active {
		sub.lk.RLock()
		cursors = append(cursors, cursorSnapshot{
			id:     sub.pds.ID,
			cursor: sub.pds.Cursor,
		})
		sub.lk.RUnlock()
	}
	s.lk.Unlock()

	errs := []error{}

	tx := s.db.WithContext(ctx).Begin()
	for _, cursor := range cursors {
		if err := tx.WithContext(ctx).Model(models.PDS{}).Where("id = ?", cursor.id).UpdateColumn("cursor", cursor.cursor).Error; err != nil {
			errs = append(errs, err)
		}
	}
	if err := tx.WithContext(ctx).Commit().Error; err != nil {
		errs = append(errs, err)
	}

	return errs
}

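// GetActiveList returns the hostnames of all currently active subscriptions.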
func (s *Slurper) GetActiveList() []string {
	s.lk.Lock()
	defer s.lk.Unlock()
	var out []string
	for k := range s.active {
		out = append(out, k)
	}

	return out
}

var ErrNoActiveConnection = fmt.Errorf("no active connection to host")

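// KillUpstreamConnection cancels the active subscription to a host, if any;
// when block is true the PDS is additionally marked blocked in the database
// so it will not be picked up again by RestartAll.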
func (s *Slurper) KillUpstreamConnection(host string, block bool) error {
	s.lk.Lock()
	defer s.lk.Unlock()

	ac, ok := s.active[host]
	if !ok {
		return fmt.Errorf("killing connection %q: %w", host, ErrNoActiveConnection)
	}
	ac.cancel()
	// subscribeWithRedialer() will remove the entry from s.active when it
	// winds down

	if block {
		if err := s.db.Model(models.PDS{}).Where("id = ?", ac.pds.ID).UpdateColumn("blocked", true).Error; err != nil {
			return fmt.Errorf("failed to set host as blocked: %w", err)
		}
	}

	return nil
}