1package splitter
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "log/slog"
9 "math/rand"
10 "net"
11 "net/http"
12 "net/url"
13 "os"
14 "strconv"
15 "strings"
16 "sync"
17 "time"
18
19 "github.com/bluesky-social/indigo/events"
20 "github.com/bluesky-social/indigo/events/pebblepersist"
21 "github.com/bluesky-social/indigo/events/schedulers/sequential"
22 "github.com/bluesky-social/indigo/util"
23 "github.com/bluesky-social/indigo/util/svcutil"
24
25 "github.com/gorilla/websocket"
26 "github.com/labstack/echo/v4"
27 "github.com/labstack/echo/v4/middleware"
28 "github.com/prometheus/client_golang/prometheus/promhttp"
29)
30
31type Splitter struct {
32 erb *EventRingBuffer
33 pp *pebblepersist.PebblePersist
34 events *events.EventManager
35
36 // Management of Socket Consumers
37 consumersLk sync.RWMutex
38 nextConsumerID uint64
39 consumers map[uint64]*SocketConsumer
40
41 conf SplitterConfig
42
43 logger *slog.Logger
44
45 upstreamClient *http.Client
46 peerClient *http.Client
47 nextCrawlers []url.URL
48}
49
50type SplitterConfig struct {
51 UpstreamHost string
52 CollectionDirHost string
53 CursorFile string
54 UserAgent string
55 PebbleOptions *pebblepersist.PebblePersistOptions
56 Logger *slog.Logger
57}
58
59func (sc *SplitterConfig) UpstreamHostWebsocket() string {
60
61 if !strings.Contains(sc.UpstreamHost, "://") {
62 return "wss://" + sc.UpstreamHost
63 }
64 u, err := url.Parse(sc.UpstreamHost)
65 if err != nil {
66 // this will cause an error downstream
67 return ""
68 }
69
70 switch u.Scheme {
71 case "http", "ws":
72 return "ws://" + u.Host
73 case "https", "wss":
74 return "wss://" + u.Host
75 default:
76 return "wss://" + u.Host
77 }
78}
79
80func (sc *SplitterConfig) UpstreamHostHTTP() string {
81
82 if !strings.Contains(sc.UpstreamHost, "://") {
83 return "https://" + sc.UpstreamHost
84 }
85 u, err := url.Parse(sc.UpstreamHost)
86 if err != nil {
87 // this will cause an error downstream
88 return ""
89 }
90
91 switch u.Scheme {
92 case "http", "ws":
93 return "http://" + u.Host
94 case "https", "wss":
95 return "https://" + u.Host
96 default:
97 return "https://" + u.Host
98 }
99}
100
101func NewSplitter(conf SplitterConfig, nextCrawlers []string) (*Splitter, error) {
102
103 logger := conf.Logger
104 if logger == nil {
105 logger = slog.Default().With("system", "splitter")
106 }
107
108 var nextCrawlerURLs []url.URL
109 for _, raw := range nextCrawlers {
110 if raw == "" {
111 continue
112 }
113 u, err := url.Parse(raw)
114 if err != nil {
115 return nil, fmt.Errorf("failed to parse next-crawler url: %w", err)
116 }
117 if u.Host == "" {
118 return nil, fmt.Errorf("empty URL host for next crawler: %s", raw)
119 }
120 nextCrawlerURLs = append(nextCrawlerURLs, *u)
121 }
122 if len(nextCrawlerURLs) > 0 {
123 logger.Info("configured crawler forwarding", "crawlers", nextCrawlerURLs)
124 }
125
126 _, err := url.Parse(conf.UpstreamHostHTTP())
127 if err != nil {
128 return nil, fmt.Errorf("failed to parse upstream url %#v: %w", conf.UpstreamHostHTTP(), err)
129 }
130
131 // generic HTTP client for upstream relay and collectiondr; but disable automatic following of redirects
132 upstreamClient := http.Client{
133 Timeout: 10 * time.Second,
134 CheckRedirect: func(req *http.Request, via []*http.Request) error {
135 return http.ErrUseLastResponse
136 },
137 }
138
139 s := &Splitter{
140 conf: conf,
141 consumers: make(map[uint64]*SocketConsumer),
142 logger: logger,
143 upstreamClient: &upstreamClient,
144 peerClient: util.RobustHTTPClient(),
145 nextCrawlers: nextCrawlerURLs,
146 }
147
148 if conf.PebbleOptions == nil {
149 // mem splitter
150 erb := NewEventRingBuffer(20_000, 10_000)
151 s.erb = erb
152 s.events = events.NewEventManager(erb)
153 } else {
154 pp, err := pebblepersist.NewPebblePersistance(conf.PebbleOptions)
155 if err != nil {
156 return nil, err
157 }
158 go pp.GCThread(context.Background())
159 s.pp = pp
160 s.events = events.NewEventManager(pp)
161 }
162
163 return s, nil
164}
165
166func (s *Splitter) StartAPI(addr string) error {
167 var lc net.ListenConfig
168 ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
169 defer cancel()
170
171 curs, err := s.getLastCursor()
172 if err != nil {
173 return fmt.Errorf("loading cursor failed: %w", err)
174 }
175
176 go s.subscribeWithRedialer(context.Background(), curs)
177
178 li, err := lc.Listen(ctx, "tcp", addr)
179 if err != nil {
180 return err
181 }
182 return s.startWithListener(li)
183}
184
185func (s *Splitter) StartMetrics(listen string) error {
186 http.Handle("/metrics", promhttp.Handler())
187 return http.ListenAndServe(listen, nil)
188}
189
190func (s *Splitter) Shutdown() error {
191 return nil
192}
193
194func (s *Splitter) startWithListener(listen net.Listener) error {
195 e := echo.New()
196 e.HideBanner = true
197
198 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
199 AllowOrigins: []string{"*"},
200 AllowHeaders: []string{echo.HeaderOrigin, echo.HeaderContentType, echo.HeaderAccept, echo.HeaderAuthorization},
201 }))
202
203 e.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
204 return func(c echo.Context) error {
205 c.Response().Header().Set(echo.HeaderServer, s.conf.UserAgent)
206 return next(c)
207 }
208 })
209
210 /*
211 if !s.ssl {
212 e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{
213 Format: "method=${method}, uri=${uri}, status=${status} latency=${latency_human}\n",
214 }))
215 } else {
216 e.Use(middleware.LoggerWithConfig(middleware.DefaultLoggerConfig))
217 }
218 */
219
220 e.Use(svcutil.MetricsMiddleware)
221 e.HTTPErrorHandler = s.errorHandler
222
223 if len(s.nextCrawlers) > 0 {
224 // forwards on to multiple hosts, but strips several headers (like User-Agent)
225 s.logger.Info("using legacy requestCrawl forwarding")
226 e.POST("/xrpc/com.atproto.sync.requestCrawl", s.HandleComAtprotoSyncRequestCrawl)
227 } else {
228 // simply proxies to upstream
229 e.POST("/xrpc/com.atproto.sync.requestCrawl", s.ProxyRequestUpstream)
230 }
231 e.GET("/xrpc/com.atproto.sync.subscribeRepos", s.HandleSubscribeRepos)
232
233 // proxy endpoints to upstream (relay)
234 e.GET("/xrpc/com.atproto.sync.listRepos", s.ProxyRequestUpstream)
235 e.GET("/xrpc/com.atproto.sync.getRepoStatus", s.ProxyRequestUpstream)
236 e.GET("/xrpc/com.atproto.sync.getLatestCommit", s.ProxyRequestUpstream)
237 e.GET("/xrpc/com.atproto.sync.listHosts", s.ProxyRequestUpstream)
238 e.GET("/xrpc/com.atproto.sync.getHostStatus", s.ProxyRequestUpstream)
239 e.GET("/xrpc/com.atproto.sync.getRepo", s.ProxyRequestUpstream)
240
241 // proxy relay admin endpoints for inter-relay synchronization
242 e.GET("/admin/subs/getUpstreamConns", s.ProxyRequestUpstream)
243 e.POST("/admin/subs/killUpstream", s.ProxyRequestUpstream)
244 e.GET("/admin/subs/getEnabled", s.ProxyRequestUpstream)
245 e.POST("/admin/subs/setEnabled", s.ProxyRequestUpstream)
246 e.GET("/admin/subs/perDayLimit", s.ProxyRequestUpstream)
247 e.POST("/admin/subs/setPerDayLimit", s.ProxyRequestUpstream)
248 e.GET("/admin/subs/listDomainBans", s.ProxyRequestUpstream)
249 e.POST("/admin/subs/banDomain", s.ProxyRequestUpstream)
250 e.POST("/admin/subs/unbanDomain", s.ProxyRequestUpstream)
251 e.POST("/admin/repo/takeDown", s.ProxyRequestUpstream)
252 e.POST("/admin/repo/reverseTakedown", s.ProxyRequestUpstream)
253 e.GET("/admin/pds/list", s.ProxyRequestUpstream)
254 e.POST("/admin/pds/requestCrawl", s.ProxyRequestUpstream)
255 e.POST("/admin/pds/changeLimits", s.ProxyRequestUpstream)
256 e.POST("/admin/pds/block", s.ProxyRequestUpstream)
257 e.POST("/admin/pds/unblock", s.ProxyRequestUpstream)
258 e.GET("/admin/consumers/list", s.ProxyRequestUpstream)
259
260 // proxy endpoint to collectiondir
261 e.GET("/xrpc/com.atproto.sync.listReposByCollection", s.ProxyRequestCollectionDir)
262
263 e.GET("/xrpc/_health", s.HandleHealthCheck)
264 e.GET("/_health", s.HandleHealthCheck)
265 e.GET("/", s.HandleHomeMessage)
266
267 // In order to support booting on random ports in tests, we need to tell the
268 // Echo instance it's already got a port, and then use its StartServer
269 // method to re-use that listener.
270 e.Listener = listen
271 srv := &http.Server{}
272 return e.StartServer(srv)
273}
274
275func (s *Splitter) errorHandler(err error, ctx echo.Context) {
276 switch err := err.(type) {
277 case *echo.HTTPError:
278 if err2 := ctx.JSON(err.Code, map[string]any{
279 "error": err.Message,
280 }); err2 != nil {
281 s.logger.Error("Failed to write http error", "err", err2)
282 }
283 default:
284 sendHeader := true
285 if ctx.Path() == "/xrpc/com.atproto.sync.subscribeRepos" {
286 sendHeader = false
287 }
288
289 s.logger.Warn("HANDLER ERROR", "path", ctx.Path(), "err", err)
290
291 if strings.HasPrefix(ctx.Path(), "/admin/") {
292 ctx.JSON(500, map[string]any{
293 "error": err.Error(),
294 })
295 return
296 }
297
298 if sendHeader {
299 ctx.Response().WriteHeader(500)
300 }
301 }
302}
303
304func (s *Splitter) getLastCursor() (int64, error) {
305 if s.pp != nil {
306 seq, millis, _, err := s.pp.GetLast(context.Background())
307 if err == nil {
308 s.logger.Debug("got last cursor from pebble", "seq", seq, "millis", millis)
309 return seq, nil
310 } else if errors.Is(err, pebblepersist.ErrNoLast) {
311 s.logger.Info("pebble no last")
312 } else {
313 s.logger.Error("pebble seq fail", "err", err)
314 }
315 }
316
317 fi, err := os.Open(s.conf.CursorFile)
318 if err != nil {
319 if os.IsNotExist(err) {
320 return -1, nil
321 }
322 return -1, err
323 }
324
325 b, err := io.ReadAll(fi)
326 if err != nil {
327 return -1, err
328 }
329
330 v, err := strconv.ParseInt(string(b), 10, 64)
331 if err != nil {
332 return -1, err
333 }
334
335 return v, nil
336}
337
338func (s *Splitter) writeCursor(curs int64) error {
339 return os.WriteFile(s.conf.CursorFile, []byte(fmt.Sprint(curs)), 0664)
340}
341
342func sleepForBackoff(b int) time.Duration {
343 if b == 0 {
344 return 0
345 }
346
347 if b < 50 {
348 return time.Millisecond * time.Duration(rand.Intn(100)+(5*b))
349 }
350
351 return time.Second * 5
352}
353
354func (s *Splitter) subscribeWithRedialer(ctx context.Context, cursor int64) {
355 d := websocket.Dialer{}
356
357 upstreamUrl, err := url.Parse(s.conf.UpstreamHostWebsocket())
358 if err != nil {
359 panic(err) // this should have been checked in NewSplitter
360 }
361 upstreamUrl = upstreamUrl.JoinPath("/xrpc/com.atproto.sync.subscribeRepos")
362
363 header := http.Header{
364 "User-Agent": []string{s.conf.UserAgent},
365 }
366
367 var backoff int
368 for {
369 select {
370 case <-ctx.Done():
371 return
372 default:
373 }
374
375 var uurl string
376 if cursor < 0 {
377 upstreamUrl.RawQuery = ""
378 uurl = upstreamUrl.String()
379 } else {
380 upstreamUrl.RawQuery = fmt.Sprintf("cursor=%d", cursor)
381 uurl = upstreamUrl.String()
382 }
383 con, res, err := d.DialContext(ctx, uurl, header)
384 if err != nil {
385 s.logger.Warn("dialing failed", "url", uurl, "err", err, "backoff", backoff)
386 time.Sleep(sleepForBackoff(backoff))
387 backoff++
388
389 continue
390 }
391
392 s.logger.Info("event subscription response", "code", res.StatusCode)
393
394 if err := s.handleUpstreamConnection(ctx, con, &cursor); err != nil {
395 s.logger.Warn("upstream connection failed", "url", uurl, "err", err)
396 }
397 }
398}
399
400func (s *Splitter) handleUpstreamConnection(ctx context.Context, con *websocket.Conn, lastCursor *int64) error {
401 ctx, cancel := context.WithCancel(ctx)
402 defer cancel()
403
404 sched := sequential.NewScheduler("splitter", func(ctx context.Context, evt *events.XRPCStreamEvent) error {
405 seq := events.SequenceForEvent(evt)
406 if seq < 0 {
407 // ignore info events and other unsupported types
408 return nil
409 }
410
411 if err := s.events.AddEvent(ctx, evt); err != nil {
412 return err
413 }
414
415 if seq%5000 == 0 {
416 // TODO: don't need this after we move to getting seq from pebble
417 if err := s.writeCursor(seq); err != nil {
418 s.logger.Error("write cursor failed", "err", err)
419 }
420 }
421
422 *lastCursor = seq
423 return nil
424 })
425
426 return events.HandleRepoStream(ctx, con, sched, nil)
427}