1package events
2
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"net"
	"time"

	"github.com/RussellLuo/slidingwindow"
	comatproto "github.com/bluesky-social/indigo/api/atproto"
	"github.com/gorilla/websocket"
	"github.com/prometheus/client_golang/prometheus"
)
17
18type RepoStreamCallbacks struct {
19 RepoCommit func(evt *comatproto.SyncSubscribeRepos_Commit) error
20 RepoSync func(evt *comatproto.SyncSubscribeRepos_Sync) error
21 RepoIdentity func(evt *comatproto.SyncSubscribeRepos_Identity) error
22 RepoAccount func(evt *comatproto.SyncSubscribeRepos_Account) error
23 RepoInfo func(evt *comatproto.SyncSubscribeRepos_Info) error
24 LabelLabels func(evt *comatproto.LabelSubscribeLabels_Labels) error
25 LabelInfo func(evt *comatproto.LabelSubscribeLabels_Info) error
26 Error func(evt *ErrorFrame) error
27}
28
29func (rsc *RepoStreamCallbacks) EventHandler(ctx context.Context, xev *XRPCStreamEvent) error {
30 switch {
31 case xev.RepoCommit != nil && rsc.RepoCommit != nil:
32 return rsc.RepoCommit(xev.RepoCommit)
33 case xev.RepoSync != nil && rsc.RepoSync != nil:
34 return rsc.RepoSync(xev.RepoSync)
35 case xev.RepoInfo != nil && rsc.RepoInfo != nil:
36 return rsc.RepoInfo(xev.RepoInfo)
37 case xev.RepoIdentity != nil && rsc.RepoIdentity != nil:
38 return rsc.RepoIdentity(xev.RepoIdentity)
39 case xev.RepoAccount != nil && rsc.RepoAccount != nil:
40 return rsc.RepoAccount(xev.RepoAccount)
41 case xev.LabelLabels != nil && rsc.LabelLabels != nil:
42 return rsc.LabelLabels(xev.LabelLabels)
43 case xev.LabelInfo != nil && rsc.LabelInfo != nil:
44 return rsc.LabelInfo(xev.LabelInfo)
45 case xev.Error != nil && rsc.Error != nil:
46 return rsc.Error(xev.Error)
47 default:
48 return nil
49 }
50}
51
52type InstrumentedRepoStreamCallbacks struct {
53 limiters []*slidingwindow.Limiter
54 Next func(ctx context.Context, xev *XRPCStreamEvent) error
55}
56
57func NewInstrumentedRepoStreamCallbacks(limiters []*slidingwindow.Limiter, next func(ctx context.Context, xev *XRPCStreamEvent) error) *InstrumentedRepoStreamCallbacks {
58 return &InstrumentedRepoStreamCallbacks{
59 limiters: limiters,
60 Next: next,
61 }
62}
63
64func waitForLimiter(ctx context.Context, lim *slidingwindow.Limiter) error {
65 if lim.Allow() {
66 return nil
67 }
68
69 // wait until the limiter is ready (check every 100ms)
70 t := time.NewTicker(100 * time.Millisecond)
71 defer t.Stop()
72
73 for !lim.Allow() {
74 select {
75 case <-ctx.Done():
76 return ctx.Err()
77 case <-t.C:
78 }
79 }
80
81 return nil
82}
83
84func (rsc *InstrumentedRepoStreamCallbacks) EventHandler(ctx context.Context, xev *XRPCStreamEvent) error {
85 // Wait on all limiters before calling the next handler
86 for _, lim := range rsc.limiters {
87 if err := waitForLimiter(ctx, lim); err != nil {
88 return err
89 }
90 }
91 return rsc.Next(ctx, xev)
92}
93
94type instrumentedReader struct {
95 r io.Reader
96 addr string
97 bytesCounter prometheus.Counter
98}
99
100func (sr *instrumentedReader) Read(p []byte) (int, error) {
101 n, err := sr.r.Read(p)
102 sr.bytesCounter.Add(float64(n))
103 return n, err
104}
105
106// HandleRepoStream
107// con is source of events
108// sched gets AddWork for each event
109// log may be nil for default logger
110func HandleRepoStream(ctx context.Context, con *websocket.Conn, sched Scheduler, log *slog.Logger) error {
111 if log == nil {
112 log = slog.Default().With("system", "events")
113 }
114 ctx, cancel := context.WithCancel(ctx)
115 defer cancel()
116 defer sched.Shutdown()
117
118 remoteAddr := con.RemoteAddr().String()
119
120 go func() {
121 t := time.NewTicker(time.Second * 30)
122 defer t.Stop()
123 failcount := 0
124
125 for {
126
127 select {
128 case <-t.C:
129 if err := con.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(time.Second*10)); err != nil {
130 log.Warn("failed to ping", "err", err)
131 failcount++
132 if failcount >= 4 {
133 log.Error("too many ping fails", "count", failcount)
134 con.Close()
135 return
136 }
137 } else {
138 failcount = 0 // ok ping
139 }
140 case <-ctx.Done():
141 con.Close()
142 return
143 }
144 }
145 }()
146
147 con.SetPingHandler(func(message string) error {
148 err := con.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second*60))
149 if err == websocket.ErrCloseSent {
150 return nil
151 } else if e, ok := err.(net.Error); ok && e.Temporary() {
152 return nil
153 }
154 return err
155 })
156
157 con.SetPongHandler(func(_ string) error {
158 if err := con.SetReadDeadline(time.Now().Add(time.Minute)); err != nil {
159 log.Error("failed to set read deadline", "err", err)
160 }
161
162 return nil
163 })
164
165 lastSeq := int64(-1)
166 for {
167 select {
168 case <-ctx.Done():
169 return ctx.Err()
170 default:
171 }
172
173 mt, rawReader, err := con.NextReader()
174 if err != nil {
175 return fmt.Errorf("con err at read: %w", err)
176 }
177
178 switch mt {
179 default:
180 return fmt.Errorf("expected binary message from subscription endpoint")
181 case websocket.BinaryMessage:
182 // ok
183 }
184
185 r := &instrumentedReader{
186 r: rawReader,
187 addr: remoteAddr,
188 bytesCounter: bytesFromStreamCounter.WithLabelValues(remoteAddr),
189 }
190
191 var header EventHeader
192 if err := header.UnmarshalCBOR(r); err != nil {
193 return fmt.Errorf("reading header: %w", err)
194 }
195
196 eventsFromStreamCounter.WithLabelValues(remoteAddr).Inc()
197
198 switch header.Op {
199 case EvtKindMessage:
200 switch header.MsgType {
201 case "#commit":
202 var evt comatproto.SyncSubscribeRepos_Commit
203 if err := evt.UnmarshalCBOR(r); err != nil {
204 return fmt.Errorf("reading repoCommit event: %w", err)
205 }
206
207 if evt.Seq < lastSeq {
208 log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
209 }
210
211 lastSeq = evt.Seq
212
213 if err := sched.AddWork(ctx, evt.Repo, &XRPCStreamEvent{
214 RepoCommit: &evt,
215 }); err != nil {
216 return err
217 }
218 case "#sync":
219 var evt comatproto.SyncSubscribeRepos_Sync
220 if err := evt.UnmarshalCBOR(r); err != nil {
221 return fmt.Errorf("reading repoSync event: %w", err)
222 }
223
224 if evt.Seq < lastSeq {
225 log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
226 }
227
228 lastSeq = evt.Seq
229
230 if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{
231 RepoSync: &evt,
232 }); err != nil {
233 return err
234 }
235 case "#identity":
236 var evt comatproto.SyncSubscribeRepos_Identity
237 if err := evt.UnmarshalCBOR(r); err != nil {
238 return err
239 }
240
241 if evt.Seq < lastSeq {
242 log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
243 }
244 lastSeq = evt.Seq
245
246 if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{
247 RepoIdentity: &evt,
248 }); err != nil {
249 return err
250 }
251 case "#account":
252 var evt comatproto.SyncSubscribeRepos_Account
253 if err := evt.UnmarshalCBOR(r); err != nil {
254 return err
255 }
256
257 if evt.Seq < lastSeq {
258 log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
259 }
260 lastSeq = evt.Seq
261
262 if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{
263 RepoAccount: &evt,
264 }); err != nil {
265 return err
266 }
267 case "#info":
268 // TODO: this might also be a LabelInfo (as opposed to RepoInfo)
269 var evt comatproto.SyncSubscribeRepos_Info
270 if err := evt.UnmarshalCBOR(r); err != nil {
271 return err
272 }
273
274 if err := sched.AddWork(ctx, "", &XRPCStreamEvent{
275 RepoInfo: &evt,
276 }); err != nil {
277 return err
278 }
279 case "#labels":
280 var evt comatproto.LabelSubscribeLabels_Labels
281 if err := evt.UnmarshalCBOR(r); err != nil {
282 return fmt.Errorf("reading Labels event: %w", err)
283 }
284
285 if evt.Seq < lastSeq {
286 log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
287 }
288
289 lastSeq = evt.Seq
290
291 if err := sched.AddWork(ctx, "", &XRPCStreamEvent{
292 LabelLabels: &evt,
293 }); err != nil {
294 return err
295 }
296 }
297
298 case EvtKindErrorFrame:
299 var errframe ErrorFrame
300 if err := errframe.UnmarshalCBOR(r); err != nil {
301 return err
302 }
303
304 if err := sched.AddWork(ctx, "", &XRPCStreamEvent{
305 Error: &errframe,
306 }); err != nil {
307 return err
308 }
309
310 default:
311 return fmt.Errorf("unrecognized event stream type: %d", header.Op)
312 }
313
314 }
315}