this repo has no description
at main 6.6 kB view raw
1package server 2 3import ( 4 "bytes" 5 "context" 6 "fmt" 7 "log/slog" 8 "math" 9 "net/http" 10 "net/url" 11 "os" 12 "os/signal" 13 "strings" 14 "syscall" 15 "time" 16 17 "github.com/bluesky-social/indigo/api/bsky" 18 "github.com/bluesky-social/indigo/events" 19 "github.com/bluesky-social/indigo/events/schedulers/parallel" 20 "github.com/bluesky-social/indigo/repo" 21 "github.com/gorilla/websocket" 22 "github.com/qdrant/go-client/qdrant" 23) 24 25const ( 26 pdqHashSize = 256 27) 28 29type Server struct { 30 logger *slog.Logger 31 32 dbClient *qdrant.Client 33 httpClient *http.Client 34 35 retinaHost string 36 websocketHost string 37 38 qdrantCollection string 39 40 minEuclidianDistance float32 41 42 maxSearchTime time.Duration 43 maxLimit int 44 seenThreshold int 45} 46 47type Args struct { 48 Logger *slog.Logger 49 RetinaHost string 50 WebsocketHost string 51 MaxSearchTime time.Duration 52 MaxLimit int 53 SeenThreshold int 54 MaxHammingDistance float64 55 56 QdrantHost string 57 QdrantPort int 58 QdrantColletion string 59} 60 61func New(ctx context.Context, args *Args) (*Server, error) { 62 ctx, cancel := context.WithTimeout(ctx, 10*time.Second) 63 defer cancel() 64 65 if args.Logger != nil { 66 args.Logger = slog.Default() 67 } 68 69 qdrantConfig := qdrant.Config{ 70 Host: args.QdrantHost, 71 Port: args.QdrantPort, 72 } 73 dbClient, err := qdrant.NewClient(&qdrantConfig) 74 if err != nil { 75 return nil, fmt.Errorf("failed to create qdrant client: %w", err) 76 } 77 78 exists, err := dbClient.CollectionExists(ctx, args.QdrantColletion) 79 if err != nil { 80 return nil, fmt.Errorf("failed to check if collection exists: %w", err) 81 } 82 83 if !exists { 84 args.Logger.Info("collection does not exist, creating", "collection", args.QdrantColletion) 85 86 if err := dbClient.CreateCollection(ctx, &qdrant.CreateCollection{ 87 CollectionName: args.QdrantColletion, 88 VectorsConfig: qdrant.NewVectorsConfig(&qdrant.VectorParams{ 89 Size: pdqHashSize, 90 Distance: qdrant.Distance_Euclid, 91 }), 92 }); err != nil { 93 return nil, fmt.Errorf("failed to create flagged image collection: %w", err) 94 } 95 96 if _, err := dbClient.CreateFieldIndex(ctx, &qdrant.CreateFieldIndexCollection{ 97 CollectionName: args.QdrantColletion, 98 FieldName: "timestamp", 99 FieldType: qdrant.PtrOf(qdrant.FieldType_FieldTypeInteger), 100 }); err != nil { 101 return nil, fmt.Errorf("failed to create flagged image collection timestamp index: %w", err) 102 } 103 104 args.Logger.Info("Successfully created flagged image vector collection in Qdrant") 105 106 } else { 107 args.Logger.Info("collection already exists", "collection", args.QdrantColletion) 108 } 109 110 server := Server{ 111 retinaHost: args.RetinaHost, 112 websocketHost: args.WebsocketHost, 113 qdrantCollection: args.QdrantColletion, 114 maxSearchTime: args.MaxSearchTime, 115 maxLimit: args.MaxLimit, 116 seenThreshold: args.SeenThreshold, 117 118 // PDQ distance should be measured with hamming distance. Because Qdrant doesn't support this (?) 119 // we'll use the sqrt of the given hamming distance as the euclidean distance we search for 120 minEuclidianDistance: float32(math.Sqrt(args.MaxHammingDistance)), 121 122 logger: args.Logger, 123 124 dbClient: dbClient, 125 httpClient: &http.Client{ 126 Timeout: 3 * time.Second, 127 }, 128 } 129 130 return &server, nil 131} 132 133func (s *Server) Run(ctx context.Context) error { 134 logger := s.logger.With("name", "Run") 135 136 wsDialer := websocket.DefaultDialer 137 u, err := url.Parse(s.websocketHost + "/xrpc/com.atproto.sync.subscribeRepos") 138 if err != nil { 139 return fmt.Errorf("failed to parse websocket host: %w", err) 140 } 141 142 // run the consumer in a goroutine and wait for close 143 shutdownConsumer := make(chan struct{}, 1) 144 consumerShutdown := make(chan struct{}, 1) 145 consumerErr := make(chan error, 1) 146 go func() { 147 logger := s.logger.With("component", "consumer") 148 149 logger.Info("subscribing to repo event stream", "url", u.String()) 150 151 // dial the websocket 152 conn, _, err := wsDialer.Dial(u.String(), http.Header{ 153 "User-Agent": []string{"bloblens/0.0.0"}, 154 }) 155 if err != nil { 156 logger.Error("error dialing websocket", "err", err) 157 close(shutdownConsumer) 158 return 159 } 160 161 // setup a new event scheduler 162 parallelism := 400 163 164 scheduler := parallel.NewScheduler(parallelism, 1000, s.websocketHost, s.handleEvent) 165 166 // run the consumer and wait for it to be shut down 167 go func() { 168 if err := events.HandleRepoStream(ctx, conn, scheduler, logger); err != nil { 169 logger.Error("error handling repo stream", "err", err) 170 consumerErr <- err 171 return 172 } 173 174 consumerErr <- nil 175 }() 176 177 select { 178 case <-shutdownConsumer: 179 case err := <-consumerErr: 180 if err != nil { 181 logger.Error("consumer error", "err", err) 182 } 183 } 184 185 if err := conn.Close(); err != nil { 186 logger.Error("error closing websocket", "err", err) 187 } else { 188 logger.Info("websocket closed") 189 } 190 191 close(consumerShutdown) 192 }() 193 194 signals := make(chan os.Signal, 1) 195 signal.Notify(signals, syscall.SIGTERM, syscall.SIGINT) 196 197 // wait for any of the following to arise 198 select { 199 case sig := <-signals: 200 logger.Info("received exit signal", "signal", sig) 201 close(shutdownConsumer) 202 case <-ctx.Done(): 203 logger.Info("main context cancelled") 204 close(shutdownConsumer) 205 case <-consumerShutdown: 206 logger.Warn("consumer shutdown unexpectedly, forcing exit") 207 } 208 209 select { 210 case <-consumerShutdown: 211 case <-time.After(5 * time.Second): 212 logger.Warn("websocket did not shut down within five seconds, forcefully shutting down") 213 } 214 215 logger.Info("shutdown successfully") 216 217 return nil 218} 219 220func (s *Server) handleEvent(ctx context.Context, evt *events.XRPCStreamEvent) error { 221 logger := s.logger.With("name", "handleEvent") 222 223 if evt.RepoCommit == nil { 224 return nil 225 } 226 227 logger = logger.With("did", evt.RepoCommit.Repo) 228 229 rr, err := repo.ReadRepoFromCar(ctx, bytes.NewReader(evt.RepoCommit.Blocks)) 230 if err != nil { 231 logger.Error("failed to read repo from car", "did", evt.RepoCommit.Repo, "err", err) 232 return nil 233 } 234 235 for _, op := range evt.RepoCommit.Ops { 236 if op.Action != "create" && op.Action != "update" { 237 continue 238 } 239 240 pts := strings.Split(op.Path, "/") 241 if len(pts) != 2 { 242 continue 243 } 244 245 collection := pts[0] 246 247 if collection != "app.bsky.actor.profile" { 248 continue 249 } 250 251 rcid, recB, err := rr.GetRecordBytes(ctx, op.Path) 252 if err != nil { 253 logger.Error("failed to read record bytes", "err", err) 254 continue 255 } 256 257 recCid := rcid.String() 258 if recCid != op.Cid.String() { 259 logger.Error("record cid mismatch", "expected", *op.Cid, "actual", recCid) 260 continue 261 } 262 263 var profile bsky.ActorProfile 264 profile.UnmarshalCBOR(bytes.NewReader(*recB)) 265 266 go func() { 267 s.handleProfile(ctx, evt.RepoCommit.Repo, &profile) 268 }() 269 } 270 271 return nil 272}