1package server
2
3import (
4 "context"
5 "encoding/hex"
6 "fmt"
7 "sync"
8 "time"
9
10 "github.com/bluesky-social/indigo/api/bsky"
11 "github.com/google/uuid"
12 "github.com/qdrant/go-client/qdrant"
13)
14
15func (s *Server) handleProfile(ctx context.Context, did string, profile *bsky.ActorProfile) {
16 logger := s.logger.With("name", "handleProfile", "did", did)
17 searchCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
18 defer cancel()
19
20 var avatarCid string
21 if profile.Avatar != nil {
22 avatarCid = profile.Avatar.Ref.String()
23 }
24
25 var bannerCid string
26 if profile.Banner != nil {
27 bannerCid = profile.Banner.Ref.String()
28 }
29
30 if avatarCid == "" && bannerCid == "" {
31 return
32 }
33
34 var avatarHash string
35 var bannerHash string
36 getHashResult := func(kind, cid string) {
37 logger := logger.With("kind", kind, "cid", cid)
38 hashResult, err := s.getImageHash(searchCtx, did, cid)
39 if err != nil {
40 logger.Error("failed to get hash result", "err", err)
41 return
42 }
43 if hashResult.QualityTooLow {
44 logger.Info("quality too low")
45 return
46 }
47 if hashResult.Hash == nil {
48 logger.Warn("nil hash")
49 return
50 }
51
52 switch kind {
53 case "avatar":
54 avatarHash = *hashResult.Hash
55 case "banner":
56 bannerHash = *hashResult.Hash
57 }
58 }
59
60 var wg sync.WaitGroup
61 if avatarCid != "" {
62 wg.Go(func() {
63 getHashResult("avatar", avatarCid)
64 })
65 }
66 if bannerCid != "" {
67 wg.Go(func() {
68 getHashResult("banner", bannerCid)
69 })
70 }
71 wg.Wait()
72
73 var avatarVector []float32
74 var bannerVector []float32
75 if avatarHash != "" {
76 vec, err := convertHash(avatarHash)
77 if err != nil {
78 logger.Error("error getting vector", "err", err)
79 } else {
80 avatarVector = vec
81 }
82 }
83 if bannerHash != "" {
84 vec, err := convertHash(bannerHash)
85 if err != nil {
86 logger.Error("error getting vector", "err", err)
87 } else {
88 bannerVector = vec
89 }
90 }
91
92 now := time.Now().UTC().UnixMilli()
93 nowF64 := float64(now)
94 startRange := float64(now - s.maxSearchTime.Milliseconds())
95
96 var similarAvatarDids []string
97 var similarBannerDids []string
98
99 if avatarVector != nil {
100 wg.Go(func() {
101 query := &qdrant.QueryPoints{
102 CollectionName: s.qdrantCollection,
103 Query: qdrant.NewQueryDense(avatarVector),
104 Filter: &qdrant.Filter{
105 Must: []*qdrant.Condition{
106 qdrant.NewRange("timestamp", &qdrant.Range{
107 Lte: &nowF64,
108 Gte: &startRange,
109 }),
110 qdrant.NewMatch("kind", "avatar"),
111 },
112 },
113 WithPayload: qdrant.NewWithPayload(true),
114 Limit: qdrant.PtrOf(uint64(s.maxLimit)),
115 ScoreThreshold: qdrant.PtrOf(s.minEuclidianDistance),
116 }
117
118 results, err := s.dbClient.Query(searchCtx, query)
119 if err != nil {
120 logger.Error("failed to search for vectors", "err", err)
121 return
122 }
123
124 seenDids := make(map[string]bool)
125 for _, r := range results {
126 rDidVal, ok := r.Payload["did"]
127 if !ok {
128 continue
129 }
130 rDid := rDidVal.GetStringValue()
131 if !seenDids[rDid] {
132 seenDids[rDid] = true
133 similarAvatarDids = append(similarAvatarDids, rDid)
134 }
135 }
136 })
137 }
138
139 if bannerVector != nil {
140 wg.Go(func() {
141 query := &qdrant.QueryPoints{
142 CollectionName: s.qdrantCollection,
143 Query: qdrant.NewQueryDense(bannerVector),
144 Filter: &qdrant.Filter{
145 Must: []*qdrant.Condition{
146 qdrant.NewRange("timestamp", &qdrant.Range{
147 Lte: &nowF64,
148 Gte: &startRange,
149 }),
150 qdrant.NewMatch("kind", "banner"),
151 },
152 },
153 WithPayload: qdrant.NewWithPayload(true),
154 Limit: qdrant.PtrOf(uint64(s.maxLimit)),
155 ScoreThreshold: qdrant.PtrOf(s.minEuclidianDistance),
156 }
157
158 results, err := s.dbClient.Query(searchCtx, query)
159 if err != nil {
160 logger.Error("failed to search for vectors", "err", err)
161 return
162 }
163
164 seenDids := make(map[string]bool)
165 for _, r := range results {
166 rDidVal, ok := r.Payload["did"]
167 if !ok {
168 continue
169 }
170 rDid := rDidVal.GetStringValue()
171 if !seenDids[rDid] {
172 seenDids[rDid] = true
173 similarBannerDids = append(similarBannerDids, rDid)
174 }
175 }
176 })
177 }
178
179 wg.Wait()
180
181 // inserts can happen in a different goroutine so we don't block returning these results
182 go func() {
183 insertCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
184 defer cancel()
185
186 points := make([]*qdrant.PointStruct, 0, 2)
187 if avatarVector != nil {
188 point := &qdrant.PointStruct{
189 Id: qdrant.NewIDUUID(uuid.NewString()),
190 Vectors: qdrant.NewVectors(avatarVector...),
191 Payload: qdrant.NewValueMap(map[string]any{
192 "did": did,
193 "cid": avatarCid,
194 "kind": "avatar",
195 "timestamp": now,
196 }),
197 }
198 points = append(points, point)
199 }
200 if bannerVector != nil {
201 point := &qdrant.PointStruct{
202 Id: qdrant.NewIDUUID(uuid.NewString()),
203 Vectors: qdrant.NewVectors(bannerVector...),
204 Payload: qdrant.NewValueMap(map[string]any{
205 "did": did,
206 "cid": bannerCid,
207 "kind": "banner",
208 "timestamp": now,
209 }),
210 }
211 points = append(points, point)
212 }
213
214 if len(points) > 0 {
215 if _, err := s.dbClient.Upsert(insertCtx, &qdrant.UpsertPoints{
216 CollectionName: s.qdrantCollection,
217 Points: points,
218 }); err != nil {
219 logger.Error("failed to insert hashes", "err", err)
220 }
221 }
222 }()
223
224 if len(similarAvatarDids) >= s.seenThreshold {
225 logger = logger.With("kind", "avatar", "did", did, "cid", avatarCid, "hash", avatarHash)
226 logger.Info("found users with similar avatars within search duration", "dids", similarAvatarDids)
227 }
228 if len(similarBannerDids) >= s.seenThreshold {
229 logger = logger.With("kind", "banner", "did", did, "cid", avatarCid, "hash", bannerHash)
230 logger.Info("found users with similar banners within search duration", "dids", similarBannerDids)
231 }
232}
233
234func convertHash(hash string) ([]float32, error) {
235 hashBin, err := HexToBinary(hash)
236 if err != nil {
237 return nil, fmt.Errorf("failed to convert to binary: %w", err)
238 }
239 return BinaryToFloatVector(hashBin), nil
240}
241
// HexToBinary decodes the hexadecimal string input into its raw bytes.
// It returns a nil slice together with the decoder's error when input is not
// valid hex (for example an odd length or a non-hex character).
func HexToBinary(input string) ([]byte, error) {
	decoded := make([]byte, hex.DecodedLen(len(input)))
	if _, err := hex.Decode(decoded, []byte(input)); err != nil {
		return nil, err
	}
	return decoded, nil
}
249
// BinaryToFloatVector expands every byte of bin into eight float32 values,
// most-significant bit first: 1.0 for a set bit and 0.0 for a clear bit.
// The returned slice always has len(bin)*8 elements.
func BinaryToFloatVector(bin []byte) []float32 {
	out := make([]float32, 0, len(bin)*8)
	for _, octet := range bin {
		// Walk a single-bit mask from 0x80 down to 0x01.
		for mask := byte(0x80); mask != 0; mask >>= 1 {
			var bit float32
			if octet&mask != 0 {
				bit = 1.0
			}
			out = append(out, bit)
		}
	}
	return out
}