forked from tangled.org/core
Monorepo for Tangled

jetstream: subscriber pool to support many watched dids

Introduces a pool of subscribers: the watched DIDs are split across
multiple subscribers, each carrying at most 100 wantedDids, and are
rebalanced whenever the DID list is updated.

It turns out Jetstream has a limit that we hit when requesting our 300+
watched DIDs in a single subscribe query.

Changed files
+182 -49
jetstream
+182 -49
jetstream/jetstream.go
··· 19 UpdateLastTimeUs(int64) error 20 } 21 22 type JetstreamClient struct { 23 - cfg *client.ClientConfig 24 - client *client.Client 25 - ident string 26 - l *slog.Logger 27 - 28 - db DB 29 - waitForDid bool 30 - mu sync.RWMutex 31 32 - cancel context.CancelFunc 33 - cancelMu sync.Mutex 34 } 35 36 func (j *JetstreamClient) AddDid(did string) { ··· 38 return 39 } 40 j.mu.Lock() 41 j.cfg.WantedDids = append(j.cfg.WantedDids, did) 42 - j.mu.Unlock() 43 } 44 45 func (j *JetstreamClient) UpdateDids(dids []string) { ··· 49 j.cfg.WantedDids = append(j.cfg.WantedDids, did) 50 } 51 } 52 j.mu.Unlock() 53 54 - j.cancelMu.Lock() 55 - if j.cancel != nil { 56 - j.cancel() 57 } 58 - j.cancelMu.Unlock() 59 } 60 61 - func NewJetstreamClient(ident string, collections []string, cfg *client.ClientConfig, logger *slog.Logger, db DB, waitForDid bool) (*JetstreamClient, error) { 62 if cfg == nil { 63 cfg = client.DefaultClientConfig() 64 - cfg.WebsocketURL = "wss://jetstream1.us-west.bsky.network/subscribe" 65 cfg.WantedCollections = collections 66 } 67 68 return &JetstreamClient{ 69 - cfg: cfg, 70 - ident: ident, 71 - db: db, 72 - l: logger, 73 - 74 - // This will make the goroutine in StartJetstream wait until 75 - // cfg.WantedDids has been populated, typically using UpdateDids. 76 - waitForDid: waitForDid, 77 }, nil 78 } 79 80 // StartJetstream starts the jetstream client and processes events using the provided processFunc. 81 // The caller is responsible for saving the last time_us to the database (just use your db.SaveLastTimeUs). 
82 func (j *JetstreamClient) StartJetstream(ctx context.Context, processFunc func(context.Context, *models.Event) error) error { 83 - logger := j.l 84 85 - sched := sequential.NewScheduler(j.ident, logger, processFunc) 86 87 - client, err := client.NewClient(j.cfg, log.New("jetstream"), sched) 88 if err != nil { 89 - return fmt.Errorf("failed to create jetstream client: %w", err) 90 } 91 - j.client = client 92 93 - go func() { 94 - if j.waitForDid { 95 - for len(j.cfg.WantedDids) == 0 { 96 - time.Sleep(time.Second) 97 - } 98 - } 99 - logger.Info("done waiting for did") 100 - j.connectAndRead(ctx) 101 - }() 102 103 - return nil 104 } 105 106 - func (j *JetstreamClient) connectAndRead(ctx context.Context) { 107 - l := log.FromContext(ctx) 108 for { 109 cursor := j.getLastTimeUs(ctx) 110 111 connCtx, cancel := context.WithCancel(ctx) 112 - j.cancelMu.Lock() 113 - j.cancel = cancel 114 - j.cancelMu.Unlock() 115 116 - if err := j.client.ConnectAndRead(connCtx, cursor); err != nil { 117 l.Error("error reading jetstream", "error", err) 118 cancel() 119 continue 120 } 121 122 select { 123 case <-ctx.Done(): 124 - l.Info("context done, stopping jetstream") 125 return 126 case <-connCtx.Done(): 127 l.Info("connection context done, reconnecting") ··· 130 } 131 } 132 133 func (j *JetstreamClient) getLastTimeUs(ctx context.Context) *int64 { 134 l := log.FromContext(ctx) 135 lastTimeUs, err := j.db.GetLastTimeUs() ··· 142 } 143 } 144 145 - // If last time is older than a week, start from now 146 if time.Now().UnixMicro()-lastTimeUs > 2*24*60*60*1000*1000 { 147 lastTimeUs = time.Now().UnixMicro() 148 l.Warn("last time us is older than 2 days; discarding that and starting from now") ··· 152 } 153 } 154 155 - l.Info("found last time_us", "time_us", lastTimeUs) 156 return &lastTimeUs 157 }
··· 19 UpdateLastTimeUs(int64) error 20 } 21 22 + type JetstreamSubscriber struct { 23 + client *client.Client 24 + cancel context.CancelFunc 25 + dids []string 26 + ident string 27 + running bool 28 + } 29 + 30 type JetstreamClient struct { 31 + cfg *client.ClientConfig 32 + baseIdent string 33 + l *slog.Logger 34 + db DB 35 + waitForDid bool 36 + maxDidsPerSubscriber int 37 38 + mu sync.RWMutex 39 + subscribers []*JetstreamSubscriber 40 + processFunc func(context.Context, *models.Event) error 41 + subscriberWg sync.WaitGroup 42 } 43 44 func (j *JetstreamClient) AddDid(did string) { ··· 46 return 47 } 48 j.mu.Lock() 49 + defer j.mu.Unlock() 50 + 51 + // Just add to the config for now, actual subscriber management happens in UpdateDids 52 j.cfg.WantedDids = append(j.cfg.WantedDids, did) 53 } 54 55 func (j *JetstreamClient) UpdateDids(dids []string) { ··· 59 j.cfg.WantedDids = append(j.cfg.WantedDids, did) 60 } 61 } 62 + 63 + needRebalance := j.processFunc != nil 64 j.mu.Unlock() 65 66 + if needRebalance { 67 + j.rebalanceSubscribers() 68 } 69 } 70 71 + func NewJetstreamClient(endpoint, ident string, collections []string, cfg *client.ClientConfig, logger *slog.Logger, db DB, waitForDid bool) (*JetstreamClient, error) { 72 if cfg == nil { 73 cfg = client.DefaultClientConfig() 74 + cfg.WebsocketURL = endpoint 75 cfg.WantedCollections = collections 76 } 77 78 return &JetstreamClient{ 79 + cfg: cfg, 80 + baseIdent: ident, 81 + db: db, 82 + l: logger, 83 + waitForDid: waitForDid, 84 + subscribers: make([]*JetstreamSubscriber, 0), 85 + maxDidsPerSubscriber: 100, 86 }, nil 87 } 88 89 // StartJetstream starts the jetstream client and processes events using the provided processFunc. 90 // The caller is responsible for saving the last time_us to the database (just use your db.SaveLastTimeUs). 
91 func (j *JetstreamClient) StartJetstream(ctx context.Context, processFunc func(context.Context, *models.Event) error) error { 92 + j.mu.Lock() 93 + j.processFunc = processFunc 94 + j.mu.Unlock() 95 + 96 + if j.waitForDid { 97 + // Start a goroutine to wait for DIDs and then start subscribers 98 + go func() { 99 + for { 100 + j.mu.RLock() 101 + hasDids := len(j.cfg.WantedDids) > 0 102 + j.mu.RUnlock() 103 + 104 + if hasDids { 105 + j.l.Info("done waiting for did, starting subscribers") 106 + j.rebalanceSubscribers() 107 + return 108 + } 109 + time.Sleep(time.Second) 110 + } 111 + }() 112 + } else { 113 + // Start subscribers immediately 114 + j.rebalanceSubscribers() 115 + } 116 + 117 + return nil 118 + } 119 + 120 + // rebalanceSubscribers creates, updates, or removes subscribers based on the current list of DIDs 121 + func (j *JetstreamClient) rebalanceSubscribers() { 122 + j.mu.Lock() 123 + defer j.mu.Unlock() 124 + 125 + if j.processFunc == nil { 126 + j.l.Warn("cannot rebalance subscribers without a process function") 127 + return 128 + } 129 + 130 + // stop all subscribers first 131 + for _, sub := range j.subscribers { 132 + if sub.running && sub.cancel != nil { 133 + sub.cancel() 134 + sub.running = false 135 + } 136 + } 137 + 138 + // calculate how many subscribers we need 139 + totalDids := len(j.cfg.WantedDids) 140 + subscribersNeeded := (totalDids + j.maxDidsPerSubscriber - 1) / j.maxDidsPerSubscriber // ceiling division 141 + 142 + // create or reuse subscribers as needed 143 + j.subscribers = j.subscribers[:0] 144 + 145 + for i := range subscribersNeeded { 146 + startIdx := i * j.maxDidsPerSubscriber 147 + endIdx := min((i+1)*j.maxDidsPerSubscriber, totalDids) 148 + 149 + subscriberDids := j.cfg.WantedDids[startIdx:endIdx] 150 + 151 + subCfg := *j.cfg 152 + subCfg.WantedDids = subscriberDids 153 + 154 + ident := fmt.Sprintf("%s-%d", j.baseIdent, i) 155 + subscriber := &JetstreamSubscriber{ 156 + dids: subscriberDids, 157 + ident: ident, 158 + } 159 
+ j.subscribers = append(j.subscribers, subscriber) 160 + 161 + j.subscriberWg.Add(1) 162 + go j.startSubscriber(subscriber, &subCfg) 163 + } 164 + } 165 166 + // startSubscriber initializes and starts a single subscriber 167 + func (j *JetstreamClient) startSubscriber(sub *JetstreamSubscriber, cfg *client.ClientConfig) { 168 + defer j.subscriberWg.Done() 169 170 + logger := j.l.With("subscriber", sub.ident) 171 + logger.Info("starting subscriber", "dids_count", len(sub.dids)) 172 + 173 + sched := sequential.NewScheduler(sub.ident, logger, j.processFunc) 174 + 175 + client, err := client.NewClient(cfg, log.New("jetstream-"+sub.ident), sched) 176 if err != nil { 177 + logger.Error("failed to create jetstream client", "error", err) 178 + return 179 } 180 181 + sub.client = client 182 + 183 + j.mu.Lock() 184 + sub.running = true 185 + j.mu.Unlock() 186 187 + j.connectAndReadForSubscriber(sub) 188 } 189 190 + func (j *JetstreamClient) connectAndReadForSubscriber(sub *JetstreamSubscriber) { 191 + ctx := context.Background() 192 + l := j.l.With("subscriber", sub.ident) 193 + 194 for { 195 + // Check if this subscriber should still be running 196 + j.mu.RLock() 197 + running := sub.running 198 + j.mu.RUnlock() 199 + 200 + if !running { 201 + l.Info("subscriber marked for shutdown") 202 + return 203 + } 204 + 205 cursor := j.getLastTimeUs(ctx) 206 207 connCtx, cancel := context.WithCancel(ctx) 208 + 209 + j.mu.Lock() 210 + sub.cancel = cancel 211 + j.mu.Unlock() 212 213 + l.Info("connecting subscriber to jetstream") 214 + if err := sub.client.ConnectAndRead(connCtx, cursor); err != nil { 215 l.Error("error reading jetstream", "error", err) 216 cancel() 217 + time.Sleep(time.Second) // Small backoff before retry 218 continue 219 } 220 221 select { 222 case <-ctx.Done(): 223 + l.Info("context done, stopping subscriber") 224 return 225 case <-connCtx.Done(): 226 l.Info("connection context done, reconnecting") ··· 229 } 230 } 231 232 + // GetRunningSubscribersCount returns the 
total number of currently running subscribers 233 + func (j *JetstreamClient) GetRunningSubscribersCount() int { 234 + j.mu.RLock() 235 + defer j.mu.RUnlock() 236 + 237 + runningCount := 0 238 + for _, sub := range j.subscribers { 239 + if sub.running { 240 + runningCount++ 241 + } 242 + } 243 + 244 + return runningCount 245 + } 246 + 247 + // Shutdown gracefully stops all subscribers 248 + func (j *JetstreamClient) Shutdown() { 249 + j.mu.Lock() 250 + 251 + // Cancel all subscribers 252 + for _, sub := range j.subscribers { 253 + if sub.running && sub.cancel != nil { 254 + sub.cancel() 255 + sub.running = false 256 + } 257 + } 258 + 259 + j.mu.Unlock() 260 + 261 + // Wait for all subscribers to complete 262 + j.subscriberWg.Wait() 263 + j.l.Info("all subscribers shut down", "total_subscribers", len(j.subscribers), "running_subscribers", j.GetRunningSubscribersCount()) 264 + } 265 + 266 func (j *JetstreamClient) getLastTimeUs(ctx context.Context) *int64 { 267 l := log.FromContext(ctx) 268 lastTimeUs, err := j.db.GetLastTimeUs() ··· 275 } 276 } 277 278 + // If last time is older than 2 days, start from now 279 if time.Now().UnixMicro()-lastTimeUs > 2*24*60*60*1000*1000 { 280 lastTimeUs = time.Now().UnixMicro() 281 l.Warn("last time us is older than 2 days; discarding that and starting from now") ··· 285 } 286 } 287 288 + l.Info("found last time_us", "time_us", lastTimeUs, "running_subscribers", j.GetRunningSubscribersCount()) 289 return &lastTimeUs 290 }