1package server
2
3import (
4 "bytes"
5 "context"
6 "fmt"
7 "log/slog"
8 "math"
9 "net/http"
10 "net/url"
11 "os"
12 "os/signal"
13 "strings"
14 "syscall"
15 "time"
16
17 "github.com/bluesky-social/indigo/api/bsky"
18 "github.com/bluesky-social/indigo/events"
19 "github.com/bluesky-social/indigo/events/schedulers/parallel"
20 "github.com/bluesky-social/indigo/repo"
21 "github.com/gorilla/websocket"
22 "github.com/qdrant/go-client/qdrant"
23)
24
// pdqHashSize is the vector size used when New creates the Qdrant collection.
const pdqHashSize = 256
28
29type Server struct {
30 logger *slog.Logger
31
32 dbClient *qdrant.Client
33 httpClient *http.Client
34
35 retinaHost string
36 websocketHost string
37
38 qdrantCollection string
39
40 minEuclidianDistance float32
41
42 maxSearchTime time.Duration
43 maxLimit int
44 seenThreshold int
45}
46
// Args carries the configuration consumed by New.
type Args struct {
	// Logger is the structured logger handed to the server.
	Logger *slog.Logger

	// RetinaHost and WebsocketHost are stored on the Server as-is; Run appends
	// the subscribeRepos XRPC path to WebsocketHost.
	RetinaHost, WebsocketHost string

	// MaxSearchTime, MaxLimit and SeenThreshold are copied onto the Server
	// unchanged.
	MaxSearchTime           time.Duration
	MaxLimit, SeenThreshold int

	// MaxHammingDistance is converted by New into a Euclidean distance by
	// taking its square root.
	MaxHammingDistance float64

	// QdrantHost and QdrantPort locate the Qdrant instance; QdrantColletion is
	// the name of the collection to use (created by New when absent).
	QdrantHost      string
	QdrantPort      int
	QdrantColletion string
}
60
61func New(ctx context.Context, args *Args) (*Server, error) {
62 ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
63 defer cancel()
64
65 if args.Logger != nil {
66 args.Logger = slog.Default()
67 }
68
69 qdrantConfig := qdrant.Config{
70 Host: args.QdrantHost,
71 Port: args.QdrantPort,
72 }
73 dbClient, err := qdrant.NewClient(&qdrantConfig)
74 if err != nil {
75 return nil, fmt.Errorf("failed to create qdrant client: %w", err)
76 }
77
78 exists, err := dbClient.CollectionExists(ctx, args.QdrantColletion)
79 if err != nil {
80 return nil, fmt.Errorf("failed to check if collection exists: %w", err)
81 }
82
83 if !exists {
84 args.Logger.Info("collection does not exist, creating", "collection", args.QdrantColletion)
85
86 if err := dbClient.CreateCollection(ctx, &qdrant.CreateCollection{
87 CollectionName: args.QdrantColletion,
88 VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{
89 Size: pdqHashSize,
90 Distance: qdrant.Distance_Euclid,
91 }),
92 }); err != nil {
93 return nil, fmt.Errorf("failed to create flagged image collection: %w", err)
94 }
95
96 if _, err := dbClient.CreateFieldIndex(ctx, &qdrant.CreateFieldIndexCollection{
97 CollectionName: args.QdrantColletion,
98 FieldName: "timestamp",
99 FieldType: qdrant.PtrOf(qdrant.FieldType_FieldTypeInteger),
100 }); err != nil {
101 return nil, fmt.Errorf("failed to create flagged image collection timestamp index: %w", err)
102 }
103
104 args.Logger.Info("Successfully created flagged image vector collection in Qdrant")
105
106 } else {
107 args.Logger.Info("collection already exists", "collection", args.QdrantColletion)
108 }
109
110 server := Server{
111 retinaHost: args.RetinaHost,
112 websocketHost: args.WebsocketHost,
113 qdrantCollection: args.QdrantColletion,
114 maxSearchTime: args.MaxSearchTime,
115 maxLimit: args.MaxLimit,
116 seenThreshold: args.SeenThreshold,
117
118 // PDQ distance should be measured with hamming distance. Because Qdrant doesn't support this (?)
119 // we'll use the sqrt of the given hamming distance as the euclidean distance we search for
120 minEuclidianDistance: float32(math.Sqrt(args.MaxHammingDistance)),
121
122 logger: args.Logger,
123
124 dbClient: dbClient,
125 httpClient: &http.Client{
126 Timeout: 3 * time.Second,
127 },
128 }
129
130 return &server, nil
131}
132
133func (s *Server) Run(ctx context.Context) error {
134 logger := s.logger.With("name", "Run")
135
136 wsDialer := websocket.DefaultDialer
137 u, err := url.Parse(s.websocketHost + "/xrpc/com.atproto.sync.subscribeRepos")
138 if err != nil {
139 return fmt.Errorf("failed to parse websocket host: %w", err)
140 }
141
142 // run the consumer in a goroutine and wait for close
143 shutdownConsumer := make(chan struct{}, 1)
144 consumerShutdown := make(chan struct{}, 1)
145 consumerErr := make(chan error, 1)
146 go func() {
147 logger := s.logger.With("component", "consumer")
148
149 logger.Info("subscribing to repo event stream", "url", u.String())
150
151 // dial the websocket
152 conn, _, err := wsDialer.Dial(u.String(), http.Header{
153 "User-Agent": []string{"bloblens/0.0.0"},
154 })
155 if err != nil {
156 logger.Error("error dialing websocket", "err", err)
157 close(shutdownConsumer)
158 return
159 }
160
161 // setup a new event scheduler
162 parallelism := 400
163
164 scheduler := parallel.NewScheduler(parallelism, 1000, s.websocketHost, s.handleEvent)
165
166 // run the consumer and wait for it to be shut down
167 go func() {
168 if err := events.HandleRepoStream(ctx, conn, scheduler, logger); err != nil {
169 logger.Error("error handling repo stream", "err", err)
170 consumerErr <- err
171 return
172 }
173
174 consumerErr <- nil
175 }()
176
177 select {
178 case <-shutdownConsumer:
179 case err := <-consumerErr:
180 if err != nil {
181 logger.Error("consumer error", "err", err)
182 }
183 }
184
185 if err := conn.Close(); err != nil {
186 logger.Error("error closing websocket", "err", err)
187 } else {
188 logger.Info("websocket closed")
189 }
190
191 close(consumerShutdown)
192 }()
193
194 signals := make(chan os.Signal, 1)
195 signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT)
196
197 // wait for any of the following to arise
198 select {
199 case sig := <-signals:
200 logger.Info("received exit signal", "signal", sig)
201 close(shutdownConsumer)
202 case <-ctx.Done():
203 logger.Info("main context cancelled")
204 close(shutdownConsumer)
205 case <-consumerShutdown:
206 logger.Warn("consumer shutdown unexpectedly, forcing exit")
207 }
208
209 select {
210 case <-consumerShutdown:
211 case <-time.After(5 * time.Second):
212 logger.Warn("websocket did not shut down within five seconds, forcefully shutting down")
213 }
214
215 logger.Info("shutdown successfully")
216
217 return nil
218}
219
220func (s *Server) handleEvent(ctx context.Context, evt *events.XRPCStreamEvent) error {
221 logger := s.logger.With("name", "handleEvent")
222
223 if evt.RepoCommit == nil {
224 return nil
225 }
226
227 logger = logger.With("did", evt.RepoCommit.Repo)
228
229 rr, err := repo.ReadRepoFromCar(ctx, bytes.NewReader(evt.RepoCommit.Blocks))
230 if err != nil {
231 logger.Error("failed to read repo from car", "did", evt.RepoCommit.Repo, "err", err)
232 return nil
233 }
234
235 for _, op := range evt.RepoCommit.Ops {
236 if op.Action != "create" && op.Action != "update" {
237 continue
238 }
239
240 pts := strings.Split(op.Path, "/")
241 if len(pts) != 2 {
242 continue
243 }
244
245 collection := pts[0]
246
247 if collection != "app.bsky.actor.profile" {
248 continue
249 }
250
251 rcid, recB, err := rr.GetRecordBytes(ctx, op.Path)
252 if err != nil {
253 logger.Error("failed to read record bytes", "err", err)
254 continue
255 }
256
257 recCid := rcid.String()
258 if recCid != op.Cid.String() {
259 logger.Error("record cid mismatch", "expected", *op.Cid, "actual", recCid)
260 continue
261 }
262
263 var profile bsky.ActorProfile
264 profile.UnmarshalCBOR(bytes.NewReader(*recB))
265
266 go func() {
267 s.handleProfile(ctx, evt.RepoCommit.Repo, &profile)
268 }()
269 }
270
271 return nil
272}