fork of indigo with slightly nicer lexgen
at main 8.0 kB view raw
1package events 2 3import ( 4 "context" 5 "fmt" 6 "io" 7 "log/slog" 8 "net" 9 "time" 10 11 "github.com/RussellLuo/slidingwindow" 12 comatproto "github.com/bluesky-social/indigo/api/atproto" 13 "github.com/prometheus/client_golang/prometheus" 14 15 "github.com/gorilla/websocket" 16) 17 18type RepoStreamCallbacks struct { 19 RepoCommit func(evt *comatproto.SyncSubscribeRepos_Commit) error 20 RepoSync func(evt *comatproto.SyncSubscribeRepos_Sync) error 21 RepoIdentity func(evt *comatproto.SyncSubscribeRepos_Identity) error 22 RepoAccount func(evt *comatproto.SyncSubscribeRepos_Account) error 23 RepoInfo func(evt *comatproto.SyncSubscribeRepos_Info) error 24 LabelLabels func(evt *comatproto.LabelSubscribeLabels_Labels) error 25 LabelInfo func(evt *comatproto.LabelSubscribeLabels_Info) error 26 Error func(evt *ErrorFrame) error 27} 28 29func (rsc *RepoStreamCallbacks) EventHandler(ctx context.Context, xev *XRPCStreamEvent) error { 30 switch { 31 case xev.RepoCommit != nil && rsc.RepoCommit != nil: 32 return rsc.RepoCommit(xev.RepoCommit) 33 case xev.RepoSync != nil && rsc.RepoSync != nil: 34 return rsc.RepoSync(xev.RepoSync) 35 case xev.RepoInfo != nil && rsc.RepoInfo != nil: 36 return rsc.RepoInfo(xev.RepoInfo) 37 case xev.RepoIdentity != nil && rsc.RepoIdentity != nil: 38 return rsc.RepoIdentity(xev.RepoIdentity) 39 case xev.RepoAccount != nil && rsc.RepoAccount != nil: 40 return rsc.RepoAccount(xev.RepoAccount) 41 case xev.LabelLabels != nil && rsc.LabelLabels != nil: 42 return rsc.LabelLabels(xev.LabelLabels) 43 case xev.LabelInfo != nil && rsc.LabelInfo != nil: 44 return rsc.LabelInfo(xev.LabelInfo) 45 case xev.Error != nil && rsc.Error != nil: 46 return rsc.Error(xev.Error) 47 default: 48 return nil 49 } 50} 51 52type InstrumentedRepoStreamCallbacks struct { 53 limiters []*slidingwindow.Limiter 54 Next func(ctx context.Context, xev *XRPCStreamEvent) error 55} 56 57func NewInstrumentedRepoStreamCallbacks(limiters []*slidingwindow.Limiter, next func(ctx context.Context, xev *XRPCStreamEvent) error) *InstrumentedRepoStreamCallbacks { 58 return &InstrumentedRepoStreamCallbacks{ 59 limiters: limiters, 60 Next: next, 61 } 62} 63 64func waitForLimiter(ctx context.Context, lim *slidingwindow.Limiter) error { 65 if lim.Allow() { 66 return nil 67 } 68 69 // wait until the limiter is ready (check every 100ms) 70 t := time.NewTicker(100 * time.Millisecond) 71 defer t.Stop() 72 73 for !lim.Allow() { 74 select { 75 case <-ctx.Done(): 76 return ctx.Err() 77 case <-t.C: 78 } 79 } 80 81 return nil 82} 83 84func (rsc *InstrumentedRepoStreamCallbacks) EventHandler(ctx context.Context, xev *XRPCStreamEvent) error { 85 // Wait on all limiters before calling the next handler 86 for _, lim := range rsc.limiters { 87 if err := waitForLimiter(ctx, lim); err != nil { 88 return err 89 } 90 } 91 return rsc.Next(ctx, xev) 92} 93 94type instrumentedReader struct { 95 r io.Reader 96 addr string 97 bytesCounter prometheus.Counter 98} 99 100func (sr *instrumentedReader) Read(p []byte) (int, error) { 101 n, err := sr.r.Read(p) 102 sr.bytesCounter.Add(float64(n)) 103 return n, err 104} 105 106// HandleRepoStream 107// con is source of events 108// sched gets AddWork for each event 109// log may be nil for default logger 110func HandleRepoStream(ctx context.Context, con *websocket.Conn, sched Scheduler, log *slog.Logger) error { 111 if log == nil { 112 log = slog.Default().With("system", "events") 113 } 114 ctx, cancel := context.WithCancel(ctx) 115 defer cancel() 116 defer sched.Shutdown() 117 118 remoteAddr := con.RemoteAddr().String() 119 120 go func() { 121 t := time.NewTicker(time.Second * 30) 122 defer t.Stop() 123 failcount := 0 124 125 for { 126 127 select { 128 case <-t.C: 129 if err := con.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(time.Second*10)); err != nil { 130 log.Warn("failed to ping", "err", err) 131 failcount++ 132 if failcount >= 4 { 133 log.Error("too many ping fails", "count", failcount) 134 con.Close() 135 return 136 } 137 } else { 138 failcount = 0 // ok ping 139 } 140 case <-ctx.Done(): 141 con.Close() 142 return 143 } 144 } 145 }() 146 147 con.SetPingHandler(func(message string) error { 148 err := con.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second*60)) 149 if err == websocket.ErrCloseSent { 150 return nil 151 } else if e, ok := err.(net.Error); ok && e.Temporary() { 152 return nil 153 } 154 return err 155 }) 156 157 con.SetPongHandler(func(_ string) error { 158 if err := con.SetReadDeadline(time.Now().Add(time.Minute)); err != nil { 159 log.Error("failed to set read deadline", "err", err) 160 } 161 162 return nil 163 }) 164 165 lastSeq := int64(-1) 166 for { 167 select { 168 case <-ctx.Done(): 169 return ctx.Err() 170 default: 171 } 172 173 mt, rawReader, err := con.NextReader() 174 if err != nil { 175 return fmt.Errorf("con err at read: %w", err) 176 } 177 178 switch mt { 179 default: 180 return fmt.Errorf("expected binary message from subscription endpoint") 181 case websocket.BinaryMessage: 182 // ok 183 } 184 185 r := &instrumentedReader{ 186 r: rawReader, 187 addr: remoteAddr, 188 bytesCounter: bytesFromStreamCounter.WithLabelValues(remoteAddr), 189 } 190 191 var header EventHeader 192 if err := header.UnmarshalCBOR(r); err != nil { 193 return fmt.Errorf("reading header: %w", err) 194 } 195 196 eventsFromStreamCounter.WithLabelValues(remoteAddr).Inc() 197 198 switch header.Op { 199 case EvtKindMessage: 200 switch header.MsgType { 201 case "#commit": 202 var evt comatproto.SyncSubscribeRepos_Commit 203 if err := evt.UnmarshalCBOR(r); err != nil { 204 return fmt.Errorf("reading repoCommit event: %w", err) 205 } 206 207 if evt.Seq < lastSeq { 208 log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq) 209 } 210 211 lastSeq = evt.Seq 212 213 if err := sched.AddWork(ctx, evt.Repo, &XRPCStreamEvent{ 214 RepoCommit: &evt, 215 }); err != nil { 216 return err 217 } 218 case "#sync": 219 var evt comatproto.SyncSubscribeRepos_Sync 220 if err := evt.UnmarshalCBOR(r); err != nil { 221 return fmt.Errorf("reading repoSync event: %w", err) 222 } 223 224 if evt.Seq < lastSeq { 225 log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq) 226 } 227 228 lastSeq = evt.Seq 229 230 if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{ 231 RepoSync: &evt, 232 }); err != nil { 233 return err 234 } 235 case "#identity": 236 var evt comatproto.SyncSubscribeRepos_Identity 237 if err := evt.UnmarshalCBOR(r); err != nil { 238 return err 239 } 240 241 if evt.Seq < lastSeq { 242 log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq) 243 } 244 lastSeq = evt.Seq 245 246 if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{ 247 RepoIdentity: &evt, 248 }); err != nil { 249 return err 250 } 251 case "#account": 252 var evt comatproto.SyncSubscribeRepos_Account 253 if err := evt.UnmarshalCBOR(r); err != nil { 254 return err 255 } 256 257 if evt.Seq < lastSeq { 258 log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq) 259 } 260 lastSeq = evt.Seq 261 262 if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{ 263 RepoAccount: &evt, 264 }); err != nil { 265 return err 266 } 267 case "#info": 268 // TODO: this might also be a LabelInfo (as opposed to RepoInfo) 269 var evt comatproto.SyncSubscribeRepos_Info 270 if err := evt.UnmarshalCBOR(r); err != nil { 271 return err 272 } 273 274 if err := sched.AddWork(ctx, "", &XRPCStreamEvent{ 275 RepoInfo: &evt, 276 }); err != nil { 277 return err 278 } 279 case "#labels": 280 var evt comatproto.LabelSubscribeLabels_Labels 281 if err := evt.UnmarshalCBOR(r); err != nil { 282 return fmt.Errorf("reading Labels event: %w", err) 283 } 284 285 if evt.Seq < lastSeq { 286 log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq) 287 } 288 289 lastSeq = evt.Seq 290 291 if err := sched.AddWork(ctx, "", &XRPCStreamEvent{ 292 LabelLabels: &evt, 293 }); err != nil { 294 return err 295 } 296 } 297 298 case EvtKindErrorFrame: 299 var errframe ErrorFrame 300 if err := errframe.UnmarshalCBOR(r); err != nil { 301 return err 302 } 303 304 if err := sched.AddWork(ctx, "", &XRPCStreamEvent{ 305 Error: &errframe, 306 }); err != nil { 307 return err 308 } 309 310 default: 311 return fmt.Errorf("unrecognized event stream type: %d", header.Op) 312 } 313 314 } 315}