+182
-49
jetstream/jetstream.go
+182
-49
jetstream/jetstream.go
···
19
19
UpdateLastTimeUs(int64) error
20
20
}
21
21
22
+
type JetstreamSubscriber struct {
23
+
client *client.Client
24
+
cancel context.CancelFunc
25
+
dids []string
26
+
ident string
27
+
running bool
28
+
}
29
+
22
30
type JetstreamClient struct {
23
-
cfg *client.ClientConfig
24
-
client *client.Client
25
-
ident string
26
-
l *slog.Logger
27
-
28
-
db DB
29
-
waitForDid bool
30
-
mu sync.RWMutex
31
+
cfg *client.ClientConfig
32
+
baseIdent string
33
+
l *slog.Logger
34
+
db DB
35
+
waitForDid bool
36
+
maxDidsPerSubscriber int
31
37
32
-
cancel context.CancelFunc
33
-
cancelMu sync.Mutex
38
+
mu sync.RWMutex
39
+
subscribers []*JetstreamSubscriber
40
+
processFunc func(context.Context, *models.Event) error
41
+
subscriberWg sync.WaitGroup
34
42
}
35
43
36
44
func (j *JetstreamClient) AddDid(did string) {
···
38
46
return
39
47
}
40
48
j.mu.Lock()
49
+
defer j.mu.Unlock()
50
+
51
+
// Just add to the config for now, actual subscriber management happens in UpdateDids
41
52
j.cfg.WantedDids = append(j.cfg.WantedDids, did)
42
-
j.mu.Unlock()
43
53
}
44
54
45
55
func (j *JetstreamClient) UpdateDids(dids []string) {
···
49
59
j.cfg.WantedDids = append(j.cfg.WantedDids, did)
50
60
}
51
61
}
62
+
63
+
needRebalance := j.processFunc != nil
52
64
j.mu.Unlock()
53
65
54
-
j.cancelMu.Lock()
55
-
if j.cancel != nil {
56
-
j.cancel()
66
+
if needRebalance {
67
+
j.rebalanceSubscribers()
57
68
}
58
-
j.cancelMu.Unlock()
59
69
}
60
70
61
-
func NewJetstreamClient(ident string, collections []string, cfg *client.ClientConfig, logger *slog.Logger, db DB, waitForDid bool) (*JetstreamClient, error) {
71
+
func NewJetstreamClient(endpoint, ident string, collections []string, cfg *client.ClientConfig, logger *slog.Logger, db DB, waitForDid bool) (*JetstreamClient, error) {
62
72
if cfg == nil {
63
73
cfg = client.DefaultClientConfig()
64
-
cfg.WebsocketURL = "wss://jetstream1.us-west.bsky.network/subscribe"
74
+
cfg.WebsocketURL = endpoint
65
75
cfg.WantedCollections = collections
66
76
}
67
77
68
78
return &JetstreamClient{
69
-
cfg: cfg,
70
-
ident: ident,
71
-
db: db,
72
-
l: logger,
73
-
74
-
// This will make the goroutine in StartJetstream wait until
75
-
// cfg.WantedDids has been populated, typically using UpdateDids.
76
-
waitForDid: waitForDid,
79
+
cfg: cfg,
80
+
baseIdent: ident,
81
+
db: db,
82
+
l: logger,
83
+
waitForDid: waitForDid,
84
+
subscribers: make([]*JetstreamSubscriber, 0),
85
+
maxDidsPerSubscriber: 100,
77
86
}, nil
78
87
}
79
88
80
89
// StartJetstream starts the jetstream client and processes events using the provided processFunc.
81
90
// The caller is responsible for saving the last time_us to the database (just use your db.SaveLastTimeUs).
82
91
func (j *JetstreamClient) StartJetstream(ctx context.Context, processFunc func(context.Context, *models.Event) error) error {
83
-
logger := j.l
92
+
j.mu.Lock()
93
+
j.processFunc = processFunc
94
+
j.mu.Unlock()
95
+
96
+
if j.waitForDid {
97
+
// Start a goroutine to wait for DIDs and then start subscribers
98
+
go func() {
99
+
for {
100
+
j.mu.RLock()
101
+
hasDids := len(j.cfg.WantedDids) > 0
102
+
j.mu.RUnlock()
103
+
104
+
if hasDids {
105
+
j.l.Info("done waiting for did, starting subscribers")
106
+
j.rebalanceSubscribers()
107
+
return
108
+
}
109
+
time.Sleep(time.Second)
110
+
}
111
+
}()
112
+
} else {
113
+
// Start subscribers immediately
114
+
j.rebalanceSubscribers()
115
+
}
116
+
117
+
return nil
118
+
}
119
+
120
+
// rebalanceSubscribers creates, updates, or removes subscribers based on the current list of DIDs
121
+
func (j *JetstreamClient) rebalanceSubscribers() {
122
+
j.mu.Lock()
123
+
defer j.mu.Unlock()
124
+
125
+
if j.processFunc == nil {
126
+
j.l.Warn("cannot rebalance subscribers without a process function")
127
+
return
128
+
}
129
+
130
+
// stop all subscribers first
131
+
for _, sub := range j.subscribers {
132
+
if sub.running && sub.cancel != nil {
133
+
sub.cancel()
134
+
sub.running = false
135
+
}
136
+
}
137
+
138
+
// calculate how many subscribers we need
139
+
totalDids := len(j.cfg.WantedDids)
140
+
subscribersNeeded := (totalDids + j.maxDidsPerSubscriber - 1) / j.maxDidsPerSubscriber // ceiling division
141
+
142
+
// create or reuse subscribers as needed
143
+
j.subscribers = j.subscribers[:0]
144
+
145
+
for i := range subscribersNeeded {
146
+
startIdx := i * j.maxDidsPerSubscriber
147
+
endIdx := min((i+1)*j.maxDidsPerSubscriber, totalDids)
148
+
149
+
subscriberDids := j.cfg.WantedDids[startIdx:endIdx]
150
+
151
+
subCfg := *j.cfg
152
+
subCfg.WantedDids = subscriberDids
153
+
154
+
ident := fmt.Sprintf("%s-%d", j.baseIdent, i)
155
+
subscriber := &JetstreamSubscriber{
156
+
dids: subscriberDids,
157
+
ident: ident,
158
+
}
159
+
j.subscribers = append(j.subscribers, subscriber)
160
+
161
+
j.subscriberWg.Add(1)
162
+
go j.startSubscriber(subscriber, &subCfg)
163
+
}
164
+
}
84
165
85
-
sched := sequential.NewScheduler(j.ident, logger, processFunc)
166
+
// startSubscriber initializes and starts a single subscriber
167
+
func (j *JetstreamClient) startSubscriber(sub *JetstreamSubscriber, cfg *client.ClientConfig) {
168
+
defer j.subscriberWg.Done()
86
169
87
-
client, err := client.NewClient(j.cfg, log.New("jetstream"), sched)
170
+
logger := j.l.With("subscriber", sub.ident)
171
+
logger.Info("starting subscriber", "dids_count", len(sub.dids))
172
+
173
+
sched := sequential.NewScheduler(sub.ident, logger, j.processFunc)
174
+
175
+
client, err := client.NewClient(cfg, log.New("jetstream-"+sub.ident), sched)
88
176
if err != nil {
89
-
return fmt.Errorf("failed to create jetstream client: %w", err)
177
+
logger.Error("failed to create jetstream client", "error", err)
178
+
return
90
179
}
91
-
j.client = client
92
180
93
-
go func() {
94
-
if j.waitForDid {
95
-
for len(j.cfg.WantedDids) == 0 {
96
-
time.Sleep(time.Second)
97
-
}
98
-
}
99
-
logger.Info("done waiting for did")
100
-
j.connectAndRead(ctx)
101
-
}()
181
+
sub.client = client
182
+
183
+
j.mu.Lock()
184
+
sub.running = true
185
+
j.mu.Unlock()
102
186
103
-
return nil
187
+
j.connectAndReadForSubscriber(sub)
104
188
}
105
189
106
-
func (j *JetstreamClient) connectAndRead(ctx context.Context) {
107
-
l := log.FromContext(ctx)
190
+
func (j *JetstreamClient) connectAndReadForSubscriber(sub *JetstreamSubscriber) {
191
+
ctx := context.Background()
192
+
l := j.l.With("subscriber", sub.ident)
193
+
108
194
for {
195
+
// Check if this subscriber should still be running
196
+
j.mu.RLock()
197
+
running := sub.running
198
+
j.mu.RUnlock()
199
+
200
+
if !running {
201
+
l.Info("subscriber marked for shutdown")
202
+
return
203
+
}
204
+
109
205
cursor := j.getLastTimeUs(ctx)
110
206
111
207
connCtx, cancel := context.WithCancel(ctx)
112
-
j.cancelMu.Lock()
113
-
j.cancel = cancel
114
-
j.cancelMu.Unlock()
208
+
209
+
j.mu.Lock()
210
+
sub.cancel = cancel
211
+
j.mu.Unlock()
115
212
116
-
if err := j.client.ConnectAndRead(connCtx, cursor); err != nil {
213
+
l.Info("connecting subscriber to jetstream")
214
+
if err := sub.client.ConnectAndRead(connCtx, cursor); err != nil {
117
215
l.Error("error reading jetstream", "error", err)
118
216
cancel()
217
+
time.Sleep(time.Second) // Small backoff before retry
119
218
continue
120
219
}
121
220
122
221
select {
123
222
case <-ctx.Done():
124
-
l.Info("context done, stopping jetstream")
223
+
l.Info("context done, stopping subscriber")
125
224
return
126
225
case <-connCtx.Done():
127
226
l.Info("connection context done, reconnecting")
···
130
229
}
131
230
}
132
231
232
+
// GetRunningSubscribersCount returns the total number of currently running subscribers
233
+
func (j *JetstreamClient) GetRunningSubscribersCount() int {
234
+
j.mu.RLock()
235
+
defer j.mu.RUnlock()
236
+
237
+
runningCount := 0
238
+
for _, sub := range j.subscribers {
239
+
if sub.running {
240
+
runningCount++
241
+
}
242
+
}
243
+
244
+
return runningCount
245
+
}
246
+
247
+
// Shutdown gracefully stops all subscribers
248
+
func (j *JetstreamClient) Shutdown() {
249
+
j.mu.Lock()
250
+
251
+
// Cancel all subscribers
252
+
for _, sub := range j.subscribers {
253
+
if sub.running && sub.cancel != nil {
254
+
sub.cancel()
255
+
sub.running = false
256
+
}
257
+
}
258
+
259
+
j.mu.Unlock()
260
+
261
+
// Wait for all subscribers to complete
262
+
j.subscriberWg.Wait()
263
+
j.l.Info("all subscribers shut down", "total_subscribers", len(j.subscribers), "running_subscribers", j.GetRunningSubscribersCount())
264
+
}
265
+
133
266
func (j *JetstreamClient) getLastTimeUs(ctx context.Context) *int64 {
134
267
l := log.FromContext(ctx)
135
268
lastTimeUs, err := j.db.GetLastTimeUs()
···
142
275
}
143
276
}
144
277
145
-
// If last time is older than a week, start from now
278
+
// If last time is older than 2 days, start from now
146
279
if time.Now().UnixMicro()-lastTimeUs > 2*24*60*60*1000*1000 {
147
280
lastTimeUs = time.Now().UnixMicro()
148
281
l.Warn("last time us is older than 2 days; discarding that and starting from now")
···
152
285
}
153
286
}
154
287
155
-
l.Info("found last time_us", "time_us", lastTimeUs)
288
+
l.Info("found last time_us", "time_us", lastTimeUs, "running_subscribers", j.GetRunningSubscribersCount())
156
289
return &lastTimeUs
157
290
}