fork of indigo with slightly nicer lexgen
1package splitter
2
import (
	"context"
	"fmt"
	"net"
	"net/http"
	"strconv"
	"sync"
	"time"

	"github.com/bluesky-social/indigo/events"

	"github.com/gorilla/websocket"
	"github.com/labstack/echo/v4"
	"github.com/prometheus/client_golang/prometheus"
	dto "github.com/prometheus/client_model/go"
)
18
19func (s *Splitter) HandleSubscribeRepos(c echo.Context) error {
20 var since *int64
21 if sinceVal := c.QueryParam("cursor"); sinceVal != "" {
22 sval, err := strconv.ParseInt(sinceVal, 10, 64)
23 if err != nil {
24 return err
25 }
26 since = &sval
27 }
28
29 // NOTE: the request context outlives the HTTP 101 response; it lives as long as the WebSocket is open, and then get cancelled. That is the behavior we want for this ctx, but should be careful if spawning goroutines which should outlive the WebSocket connection.
30 // https://github.com/bluesky-social/indigo/pull/1023#pullrequestreview-2768335762
31 ctx, cancel := context.WithCancel(c.Request().Context())
32 defer cancel()
33
34 // TODO: authhhh
35 conn, err := websocket.Upgrade(c.Response(), c.Request(), c.Response().Header(), 10<<10, 10<<10)
36 if err != nil {
37 return fmt.Errorf("upgrading websocket: %w", err)
38 }
39
40 defer conn.Close()
41
42 lastWriteLk := sync.Mutex{}
43 lastWrite := time.Now()
44
45 // Start a goroutine to ping the client every 30 seconds to check if it's
46 // still alive. If the client doesn't respond to a ping within 5 seconds,
47 // we'll close the connection and teardown the consumer.
48 go func() {
49 ticker := time.NewTicker(30 * time.Second)
50 defer ticker.Stop()
51
52 for {
53 select {
54 case <-ticker.C:
55 lastWriteLk.Lock()
56 lw := lastWrite
57 lastWriteLk.Unlock()
58
59 if time.Since(lw) < 30*time.Second {
60 continue
61 }
62
63 if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil {
64 s.logger.Error("failed to ping client", "err", err)
65 cancel()
66 return
67 }
68 case <-ctx.Done():
69 return
70 }
71 }
72 }()
73
74 conn.SetPingHandler(func(message string) error {
75 err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second*60))
76 if err == websocket.ErrCloseSent {
77 return nil
78 } else if e, ok := err.(net.Error); ok && e.Temporary() {
79 return nil
80 }
81 return err
82 })
83
84 // Start a goroutine to read messages from the client and discard them.
85 go func() {
86 for {
87 _, _, err := conn.ReadMessage()
88 if err != nil {
89 s.logger.Error("failed to read message from client", "err", err)
90 cancel()
91 return
92 }
93 }
94 }()
95
96 ident := c.RealIP() + "-" + c.Request().UserAgent()
97
98 evts, cleanup, err := s.events.Subscribe(ctx, ident, func(evt *events.XRPCStreamEvent) bool { return true }, since)
99 if err != nil {
100 return err
101 }
102 defer cleanup()
103
104 // Keep track of the consumer for metrics and admin endpoints
105 consumer := SocketConsumer{
106 RemoteAddr: c.RealIP(),
107 UserAgent: c.Request().UserAgent(),
108 ConnectedAt: time.Now(),
109 }
110 sentCounter := eventsSentCounter.WithLabelValues(consumer.RemoteAddr, consumer.UserAgent)
111 consumer.EventsSent = sentCounter
112
113 consumerID := s.registerConsumer(&consumer)
114 defer s.cleanupConsumer(consumerID)
115
116 s.logger.Info("new consumer",
117 "remote_addr", consumer.RemoteAddr,
118 "user_agent", consumer.UserAgent,
119 "cursor", since,
120 "consumer_id", consumerID,
121 )
122 activeClientGauge.Inc()
123 defer activeClientGauge.Dec()
124
125 for {
126 select {
127 case evt, ok := <-evts:
128 if !ok {
129 s.logger.Error("event stream closed unexpectedly")
130 return nil
131 }
132
133 wc, err := conn.NextWriter(websocket.BinaryMessage)
134 if err != nil {
135 s.logger.Error("failed to get next writer", "err", err)
136 return err
137 }
138
139 if evt.Preserialized != nil {
140 _, err = wc.Write(evt.Preserialized)
141 } else {
142 err = evt.Serialize(wc)
143 }
144 if err != nil {
145 return fmt.Errorf("failed to write event: %w", err)
146 }
147
148 if err := wc.Close(); err != nil {
149 s.logger.Warn("failed to flush-close our event write", "err", err)
150 return nil
151 }
152
153 lastWriteLk.Lock()
154 lastWrite = time.Now()
155 lastWriteLk.Unlock()
156 sentCounter.Inc()
157 case <-ctx.Done():
158 return nil
159 }
160 }
161}
162
163type SocketConsumer struct {
164 UserAgent string
165 RemoteAddr string
166 ConnectedAt time.Time
167 EventsSent prometheus.Counter
168}
169
170func (s *Splitter) registerConsumer(c *SocketConsumer) uint64 {
171 s.consumersLk.Lock()
172 defer s.consumersLk.Unlock()
173
174 id := s.nextConsumerID
175 s.nextConsumerID++
176
177 s.consumers[id] = c
178
179 return id
180}
181
182func (s *Splitter) cleanupConsumer(id uint64) {
183 s.consumersLk.Lock()
184 defer s.consumersLk.Unlock()
185
186 c := s.consumers[id]
187
188 var m = &dto.Metric{}
189 if err := c.EventsSent.Write(m); err != nil {
190 s.logger.Error("failed to get sent counter", "err", err)
191 }
192
193 s.logger.Info("consumer disconnected",
194 "consumer_id", id,
195 "remote_addr", c.RemoteAddr,
196 "user_agent", c.UserAgent,
197 "events_sent", m.Counter.GetValue())
198
199 delete(s.consumers, id)
200}