package atkafka

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"log/slog"
	"net/http"
	"net/url"
	"os"
	"os/signal"
	"strings"
	"sync"
	"syscall"
	"time"

	"github.com/bluesky-social/indigo/api/bsky"
	"github.com/bluesky-social/indigo/atproto/atdata"
	"github.com/bluesky-social/indigo/atproto/identity"
	"github.com/bluesky-social/indigo/events"
	"github.com/bluesky-social/indigo/events/schedulers/parallel"
	"github.com/bluesky-social/indigo/repo"
	"github.com/bluesky-social/indigo/repomgr"
	"github.com/gorilla/websocket"
	"github.com/twmb/franz-go/pkg/kgo"
)

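// Server consumes the atproto firehose from a relay and republishes events to
// a Kafka topic, optionally filtering by originating PDS and by collection.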
type Server struct {
	relayHost        string
	tapHost          string
	tapWorkers       int
	disableAcks      bool
	bootstrapServers []string
	outputTopic      string
	ospreyCompat     bool

	watchedServices []string
	ignoredServices []string

	watchedCollections []string
	ignoredCollections []string

	producer  *Producer
	plcClient *PlcClient
	apiClient *ApiClient
	logger    *slog.Logger
	ws        *websocket.Conn
	ackQueue  chan uint
}

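// ServerArgs holds the configuration for NewServer.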
type ServerArgs struct {
	// network params
	RelayHost   string
	TapHost     string
	TapWorkers  int
	DisableAcks bool
	PlcHost     string
	ApiHost     string

	// for watched and ignored services or collections, only one list may be supplied.
	// wildcards are accepted for both services and collections. for example:
	// app.bsky.* will watch/ignore any collection that falls under the app.bsky namespace.
	// *.bsky.network will watch/ignore any event originating from a PDS under the bsky.network domain.
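	//
	// concretely, "*.bsky.network" is normalized to "bsky.network" and matches
	// "bsky.network" itself as well as any subdomain of it, and "app.bsky.*" is
	// normalized to "app.bsky" and matches "app.bsky" itself as well as any
	// collection NSID beneath it (e.g. "app.bsky.feed.post").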

	// list of services that events will be emitted for
	WatchedServices []string
	// list of services that events are ignored for
	IgnoredServices []string

	// list of collections that events are emitted for
	WatchedCollections []string
	// list of collections that events are ignored for
	IgnoredCollections []string

	// kafka params
	BootstrapServers []string
	OutputTopic      string

	// osprey-specific params
	OspreyCompat bool

	// other
	Logger *slog.Logger
}

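// NewServer validates args and constructs a Server. Watched/ignored lists are
// normalized here: a leading "*." or "." is stripped from service entries and
// a trailing ".*" or "." from collection entries, so that event-time matching
// can use plain suffix/prefix checks.
//
// A minimal usage sketch (hosts and topic names are illustrative):
//
//	srv, err := NewServer(&ServerArgs{
//		RelayHost:        "wss://relay.example.com",
//		BootstrapServers: []string{"localhost:9092"},
//		OutputTopic:      "atproto.events",
//	})
//	if err != nil {
//		// handle error
//	}
//	err = srv.Run(context.Background())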
func NewServer(args *ServerArgs) (*Server, error) {
	if args.Logger == nil {
		args.Logger = slog.Default()
	}

	if len(args.WatchedServices) > 0 && len(args.IgnoredServices) > 0 {
		return nil, fmt.Errorf("you may only specify a list of watched services _or_ ignored services, not both")
	}

	if (len(args.WatchedServices) > 0 || len(args.IgnoredServices) > 0) && args.PlcHost == "" {
		return nil, fmt.Errorf("unable to support watched/ignored services without specifying a PLC host")
	}

	if len(args.WatchedCollections) > 0 && len(args.IgnoredCollections) > 0 {
		return nil, fmt.Errorf("you may only specify a list of watched collections _or_ ignored collections, not both")
	}

	var plcClient *PlcClient
	if args.PlcHost != "" {
		plcClient = NewPlcClient(&PlcClientArgs{
			PlcHost: args.PlcHost,
		})
	}

	var apiClient *ApiClient
	if args.ApiHost != "" {
		var err error
		apiClient, err = NewApiClient(&ApiClientArgs{
			ApiHost: args.ApiHost,
		})
		if err != nil {
			return nil, fmt.Errorf("failed to create new api client: %w", err)
		}
	}

	s := &Server{
		relayHost:        args.RelayHost,
		tapHost:          args.TapHost,
		tapWorkers:       args.TapWorkers,
		disableAcks:      args.DisableAcks,
		plcClient:        plcClient,
		apiClient:        apiClient,
		bootstrapServers: args.BootstrapServers,
		outputTopic:      args.OutputTopic,
		ospreyCompat:     args.OspreyCompat,
		logger:           args.Logger,
	}

	if len(args.WatchedServices) > 0 {
		watchedServices := make([]string, 0, len(args.WatchedServices))
		for _, service := range args.WatchedServices {
			watchedServices = append(watchedServices, strings.TrimPrefix(strings.TrimPrefix(service, "*."), "."))
		}
		s.watchedServices = watchedServices
	} else if len(args.IgnoredServices) > 0 {
		ignoredServices := make([]string, 0, len(args.IgnoredServices))
		for _, service := range args.IgnoredServices {
			ignoredServices = append(ignoredServices, strings.TrimPrefix(strings.TrimPrefix(service, "*."), "."))
		}
		s.ignoredServices = ignoredServices
	}

	if len(args.WatchedCollections) > 0 {
		watchedCollections := make([]string, 0, len(args.WatchedCollections))
		for _, collection := range args.WatchedCollections {
			watchedCollections = append(watchedCollections, strings.TrimSuffix(strings.TrimSuffix(collection, ".*"), "."))
		}
		s.watchedCollections = watchedCollections
	} else if len(args.IgnoredCollections) > 0 {
		ignoredCollections := make([]string, 0, len(args.IgnoredCollections))
		for _, collection := range args.IgnoredCollections {
			ignoredCollections = append(ignoredCollections, strings.TrimSuffix(strings.TrimSuffix(collection, ".*"), "."))
		}
		s.ignoredCollections = ignoredCollections
	}

	return s, nil
}

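// Run creates the Kafka producer, subscribes to the relay's
// com.atproto.sync.subscribeRepos firehose, and blocks until the process
// receives SIGINT/SIGTERM or the websocket consumer fails.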
func (s *Server) Run(ctx context.Context) error {
	s.logger.Info("starting consumer", "relay-host", s.relayHost, "bootstrap-servers", s.bootstrapServers, "output-topic", s.outputTopic)

	createCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	producerLogger := s.logger.With("component", "producer")
	kafProducer, err := NewProducer(createCtx, producerLogger, s.bootstrapServers, s.outputTopic,
		WithEnsureTopic(true),
		WithTopicPartitions(200),
	)
	if err != nil {
		return fmt.Errorf("failed to create producer: %w", err)
	}
	defer kafProducer.Close()
	s.producer = kafProducer
	s.logger.Info("created producer")

	wsDialer := websocket.DefaultDialer
	u, err := url.Parse(s.relayHost)
	if err != nil {
		return fmt.Errorf("invalid relayHost: %w", err)
	}
	u.Path = "/xrpc/com.atproto.sync.subscribeRepos"
	s.logger.Info("created dialer")

	wsErr := make(chan error, 1)
	shutdownWs := make(chan struct{}, 1)
	go func() {
		logger := s.logger.With("component", "websocket")

		logger.Info("subscribing to repo event stream", "upstream", s.relayHost)

		conn, _, err := wsDialer.Dial(u.String(), http.Header{
			"User-Agent": []string{"at-kafka/0.0.0"},
		})
		if err != nil {
			wsErr <- err
			return
		}

		parallelism := 400
		scheduler := parallel.NewScheduler(parallelism, 1000, s.relayHost, s.handleEvent)
		defer scheduler.Shutdown()

		logger.Info("firehose scheduler configured", "parallelism", parallelism)

		go func() {
			if err := events.HandleRepoStream(ctx, conn, scheduler, logger); err != nil {
				wsErr <- err
				return
			}
		}()

		<-shutdownWs

		wsErr <- nil
	}()
	s.logger.Info("created relay consumer")

	signals := make(chan os.Signal, 1)
	signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)

	select {
	case sig := <-signals:
		s.logger.Info("shutting down on signal", "signal", sig)
	case err := <-wsErr:
		if err != nil {
			s.logger.Error("websocket error", "err", err)
		} else {
			s.logger.Info("websocket shut down unexpectedly")
		}
	}

	close(shutdownWs)

	return nil
}

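// FetchEventMetadata concurrently resolves identity, DID audit-log, and
// profile information for did. Each lookup is best-effort: failures are
// logged and the corresponding fields are left at their zero values, so the
// returned error is currently always nil.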
func (s *Server) FetchEventMetadata(ctx context.Context, did string) (*EventMetadata, *identity.Identity, error) {
	var ident *identity.Identity
	var didDocument identity.DIDDocument
	var pdsHost string
	var handle string
	var didCreatedAt string
	accountAge := int64(-1)
	var profile *bsky.ActorDefs_ProfileViewDetailed

	var wg sync.WaitGroup

	if s.plcClient != nil {
		wg.Go(func() {
			logger := s.logger.With("component", "didDoc")
			var err error
			ident, err = s.plcClient.GetIdentity(ctx, did)
			if err != nil {
				logger.Error("error fetching did doc", "did", did, "err", err)
				return
			}
			didDocument = ident.DIDDocument()
			pdsHost = ident.PDSEndpoint()
			handle = ident.Handle.String()
		})

		wg.Go(func() {
			logger := s.logger.With("component", "auditLog")
			auditLog, err := s.plcClient.GetDIDAuditLog(ctx, did)
			if err != nil {
				logger.Error("error fetching did audit log", "did", did, "err", err)
				return
			}

			didCreatedAt = auditLog.CreatedAt

			createdAt, err := time.Parse(time.RFC3339Nano, auditLog.CreatedAt)
			if err != nil {
				logger.Error("error parsing timestamp in audit log", "did", did, "timestamp", auditLog.CreatedAt, "err", err)
				return
			}

			accountAge = int64(time.Since(createdAt).Seconds())
		})
	}

	if s.apiClient != nil {
		wg.Go(func() {
			logger := s.logger.With("component", "profile")
			var err error
			profile, err = s.apiClient.GetProfile(ctx, did)
			if err != nil {
				logger.Error("error getting actor profile", "did", did, "err", err)
				return
			}
		})
	}

	wg.Wait()

	return &EventMetadata{
		DidDocument:  didDocument,
		PdsHost:      pdsHost,
		Handle:       handle,
		DidCreatedAt: didCreatedAt,
		AccountAge:   accountAge,
		Profile:      profile,
	}, ident, nil
}

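// shouldSkipService reports whether events from the PDS at pdsEndpoint should
// be dropped, applying the watched/ignored service lists. It consolidates the
// identical checks used below for both commit and non-commit events: an entry
// matches when it equals the PDS hostname or when the hostname is a subdomain
// of it.
func (s *Server) shouldSkipService(pdsEndpoint string) (bool, string, error) {
	u, err := url.Parse(pdsEndpoint)
	if err != nil {
		return false, "", fmt.Errorf("failed to parse pds host: %w", err)
	}
	pdsHost := u.Hostname()
	if pdsHost == "" {
		return false, pdsHost, nil
	}

	if len(s.watchedServices) > 0 {
		for _, watchedService := range s.watchedServices {
			if watchedService == pdsHost || strings.HasSuffix(pdsHost, "."+watchedService) {
				return false, pdsHost, nil
			}
		}
		return true, pdsHost, nil
	}

	for _, ignoredService := range s.ignoredServices {
		if ignoredService == pdsHost || strings.HasSuffix(pdsHost, "."+ignoredService) {
			return true, pdsHost, nil
		}
	}
	return false, pdsHost, nil
}

// handleEvent converts a single firehose event into zero or more Kafka
// messages, applying the service and collection filters, and enqueues the
// results on the producer.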
func (s *Server) handleEvent(ctx context.Context, evt *events.XRPCStreamEvent) error {
	dispatchCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
	defer cancel()

	logger := s.logger.With("component", "handleEvent")
	logger.Debug("event", "seq", evt.Sequence())

	var collection string
	var actionName string

	var evtKey string
	var evtsToProduce [][]byte

	if evt.RepoCommit != nil {
		// key events by DID
		evtKey = evt.RepoCommit.Repo

		// read the repo
		rr, err := repo.ReadRepoFromCar(ctx, bytes.NewReader(evt.RepoCommit.Blocks))
		if err != nil {
			logger.Error("failed to read repo from car", "error", err)
			return nil
		}

		eventMetadata, ident, err := s.FetchEventMetadata(dispatchCtx, evt.RepoCommit.Repo)
		if err != nil {
			logger.Error("error fetching event metadata", "err", err)
		} else if ident != nil {
			skip, pdsHost, err := s.shouldSkipService(ident.PDSEndpoint())
			if err != nil {
				return err
			}
			if skip {
				logger.Debug("skipping event based on pds host", "pdsHost", pdsHost)
				return nil
			}
		}

		for _, op := range evt.RepoCommit.Ops {
			kind := repomgr.EventKind(op.Action)
			pathParts := strings.SplitN(op.Path, "/", 2)
			if len(pathParts) != 2 {
				logger.Error("unexpected op path", "path", op.Path)
				continue
			}
			collection = pathParts[0]
			rkey := pathParts[1]
			did := evt.RepoCommit.Repo
			atUri := fmt.Sprintf("at://%s/%s/%s", did, collection, rkey)

			// bust the profile cache whenever it gets updated
			// we won't worry about the counts of e.g. followers, since we have carveouts for some of these already anyway
			if collection == "app.bsky.actor.profile" && s.apiClient != nil {
				s.apiClient.BustProfileCache(did)
			}

			skip := false
			if len(s.watchedCollections) > 0 {
				skip = true
				for _, watchedCollection := range s.watchedCollections {
					if watchedCollection == collection || strings.HasPrefix(collection, watchedCollection+".") {
						skip = false
						break
					}
				}
			} else if len(s.ignoredCollections) > 0 {
				for _, ignoredCollection := range s.ignoredCollections {
					if ignoredCollection == collection || strings.HasPrefix(collection, ignoredCollection+".") {
						skip = true
						break
					}
				}
			}

			if skip {
				logger.Debug("skipping event based on collection", "collection", collection)
				continue
			}

			kindStr := "create"
			switch kind {
			case repomgr.EvtKindUpdateRecord:
				kindStr = "update"
			case repomgr.EvtKindDeleteRecord:
				kindStr = "delete"
			}
			actionName = "operation#" + kindStr

			handledEvents.WithLabelValues(actionName, collection).Inc()

			var rec map[string]any
			var recCid string
			if kind == repomgr.EvtKindCreateRecord || kind == repomgr.EvtKindUpdateRecord {
				rcid, recB, err := rr.GetRecordBytes(ctx, op.Path)
				if err != nil {
					logger.Error("failed to get record bytes", "error", err)
					continue
				}

				// verify the cids match
				recCid = rcid.String()
				if recCid != op.Cid.String() {
					logger.Error("record cid mismatch", "expected", *op.Cid, "actual", rcid)
					continue
				}

				// unmarshal the cbor into a map[string]any
				maybeRec, err := atdata.UnmarshalCBOR(*recB)
				if err != nil {
					logger.Error("failed to unmarshal record", "error", err)
					continue
				}
				rec = maybeRec
			}

			// create the formatted operation
			atkOp := AtKafkaOp{
				Action:     op.Action,
				Collection: collection,
				Rkey:       rkey,
				Uri:        atUri,
				Cid:        recCid,
				Path:       op.Path,
				Record:     rec,
			}

			// create the event to put on kafka, regardless of whether we are using osprey or not
			kafkaEvt := AtKafkaEvent{
				Did:       did,
				Timestamp: evt.RepoCommit.Time,
				Operation: &atkOp,
			}

			if eventMetadata != nil {
				kafkaEvt.Metadata = eventMetadata
			}

			evtBytes, err := s.marshalKafkaEvent(actionName, evt.RepoCommit.Time, kafkaEvt)
			if err != nil {
				return fmt.Errorf("failed to marshal kafka event: %w", err)
			}

			evtsToProduce = append(evtsToProduce, evtBytes)
		}
	} else {
		defer func() {
			handledEvents.WithLabelValues(actionName, "").Inc()
		}()

		// start with a kafka event and an action name
		var kafkaEvt AtKafkaEvent
		var timestamp string
		var did string

		if evt.RepoAccount != nil {
			actionName = "account"
			timestamp = evt.RepoAccount.Time
			did = evt.RepoAccount.Did

			kafkaEvt = AtKafkaEvent{
				Did:       evt.RepoAccount.Did,
				Timestamp: evt.RepoAccount.Time,
				Account: &AtKafkaAccount{
					Active: evt.RepoAccount.Active,
					Seq:    evt.RepoAccount.Seq,
					Status: evt.RepoAccount.Status,
				},
			}
		} else if evt.RepoIdentity != nil {
			actionName = "identity"
			timestamp = evt.RepoIdentity.Time
			did = evt.RepoIdentity.Did

			var handle string
			if evt.RepoIdentity.Handle != nil {
				handle = *evt.RepoIdentity.Handle
			}

			kafkaEvt = AtKafkaEvent{
				Did:       evt.RepoIdentity.Did,
				Timestamp: evt.RepoIdentity.Time,
				Identity: &AtKafkaIdentity{
					Seq:    evt.RepoIdentity.Seq,
					Handle: handle,
				},
			}
		} else if evt.RepoInfo != nil {
			actionName = "info"
			timestamp = time.Now().Format(time.RFC3339Nano)

			kafkaEvt = AtKafkaEvent{
				Info: &AtKafkaInfo{
					Name:    evt.RepoInfo.Name,
					Message: evt.RepoInfo.Message,
				},
			}
		} else {
			return fmt.Errorf("unhandled event received")
		}

		if did != "" {
			// key events by DID
			evtKey = did
			eventMetadata, ident, err := s.FetchEventMetadata(dispatchCtx, did)
			if err != nil {
				logger.Error("error fetching event metadata", "err", err)
			} else if ident != nil {
				skip, pdsHost, err := s.shouldSkipService(ident.PDSEndpoint())
				if err != nil {
					return err
				}
				if skip {
					logger.Debug("skipping event based on pds host", "pdsHost", pdsHost)
					return nil
				}

				kafkaEvt.Metadata = eventMetadata
			}
		} else {
			// key events without a DID by "unknown"
			evtKey = "<unknown>"
		}

		// create the kafka event bytes
		evtBytes, err := s.marshalKafkaEvent(actionName, timestamp, kafkaEvt)
		if err != nil {
			return fmt.Errorf("failed to marshal kafka event: %w", err)
		}

		evtsToProduce = append(evtsToProduce, evtBytes)
	}

	for _, evtBytes := range evtsToProduce {
		if err := s.produceAsync(ctx, evtKey, evtBytes); err != nil {
			return err
		}
	}

	return nil
}

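// produceAsync enqueues msg on the Kafka producer without blocking on
// delivery; the callback records a success or error metric once the write is
// acknowledged or fails.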
func (s *Server) produceAsync(ctx context.Context, key string, msg []byte) error {
	callback := func(_ *kgo.Record, err error) {
		status := "ok"
		if err != nil {
			status = "error"
			s.logger.Error("error producing message", "err", err)
		}
		producedEvents.WithLabelValues(status).Inc()
	}

	if err := s.producer.ProduceAsync(ctx, key, msg, callback); err != nil {
		return fmt.Errorf("failed to produce message: %w", err)
	}

	return nil
}
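
// marshalKafkaEvent serializes kafkaEvt for the output topic. When osprey
// compatibility is enabled the event is wrapped in the osprey envelope; this
// consolidates the identical marshalling previously inlined for commit and
// non-commit events in handleEvent.
func (s *Server) marshalKafkaEvent(actionName, timestamp string, kafkaEvt AtKafkaEvent) ([]byte, error) {
	if !s.ospreyCompat {
		return json.Marshal(&kafkaEvt)
	}

	// wrap the event in an osprey event
	ospreyKafkaEvent := OspreyAtKafkaEvent{
		Data: OspreyEventData{
			ActionName: actionName,
			ActionId:   time.Now().UnixNano(), // TODO: this should be a snowflake
			Data:       kafkaEvt,
			Timestamp:  timestamp,
			SecretData: map[string]string{},
			Encoding:   "UTF8",
		},
		SendTime: time.Now().Format(time.RFC3339),
	}
	return json.Marshal(&ospreyKafkaEvent)
}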