this repo has no description
at main 6.4 kB view raw
1package server 2 3import ( 4 "context" 5 "encoding/hex" 6 "fmt" 7 "sync" 8 "time" 9 10 "github.com/bluesky-social/indigo/api/bsky" 11 "github.com/google/uuid" 12 "github.com/qdrant/go-client/qdrant" 13) 14 15func (s *Server) handleProfile(ctx context.Context, did string, profile *bsky.ActorProfile) { 16 logger := s.logger.With("name", "handleProfile", "did", did) 17 searchCtx, cancel := context.WithTimeout(ctx, 10*time.Second) 18 defer cancel() 19 20 var avatarCid string 21 if profile.Avatar != nil { 22 avatarCid = profile.Avatar.Ref.String() 23 } 24 25 var bannerCid string 26 if profile.Banner != nil { 27 bannerCid = profile.Banner.Ref.String() 28 } 29 30 if avatarCid == "" && bannerCid == "" { 31 return 32 } 33 34 var avatarHash string 35 var bannerHash string 36 getHashResult := func(kind, cid string) { 37 logger := logger.With("kind", kind, "cid", cid) 38 hashResult, err := s.getImageHash(searchCtx, did, cid) 39 if err != nil { 40 logger.Error("failed to get hash result", "err", err) 41 return 42 } 43 if hashResult.QualityTooLow { 44 logger.Info("quality too low") 45 return 46 } 47 if hashResult.Hash == nil { 48 logger.Warn("nil hash") 49 return 50 } 51 52 switch kind { 53 case "avatar": 54 avatarHash = *hashResult.Hash 55 case "banner": 56 bannerHash = *hashResult.Hash 57 } 58 } 59 60 var wg sync.WaitGroup 61 if avatarCid != "" { 62 wg.Go(func() { 63 getHashResult("avatar", avatarCid) 64 }) 65 } 66 if bannerCid != "" { 67 wg.Go(func() { 68 getHashResult("banner", bannerCid) 69 }) 70 } 71 wg.Wait() 72 73 var avatarVector []float32 74 var bannerVector []float32 75 if avatarHash != "" { 76 vec, err := convertHash(avatarHash) 77 if err != nil { 78 logger.Error("error getting vector", "err", err) 79 } else { 80 avatarVector = vec 81 } 82 } 83 if bannerHash != "" { 84 vec, err := convertHash(bannerHash) 85 if err != nil { 86 logger.Error("error getting vector", "err", err) 87 } else { 88 bannerVector = vec 89 } 90 } 91 92 now := time.Now().UTC().UnixMilli() 93 nowF64 := float64(now) 94 startRange := float64(now - s.maxSearchTime.Milliseconds()) 95 96 var similarAvatarDids []string 97 var similarBannerDids []string 98 99 if avatarVector != nil { 100 wg.Go(func() { 101 query := &qdrant.QueryPoints{ 102 CollectionName: s.qdrantCollection, 103 Query: qdrant.NewQueryDense(avatarVector), 104 Filter: &qdrant.Filter{ 105 Must: []*qdrant.Condition{ 106 qdrant.NewRange("timestamp", &qdrant.Range{ 107 Lte: &nowF64, 108 Gte: &startRange, 109 }), 110 qdrant.NewMatch("kind", "avatar"), 111 }, 112 }, 113 WithPayload: qdrant.NewWithPayload(true), 114 Limit: qdrant.PtrOf(uint64(s.maxLimit)), 115 ScoreThreshold: qdrant.PtrOf(s.minEuclidianDistance), 116 } 117 118 results, err := s.dbClient.Query(searchCtx, query) 119 if err != nil { 120 logger.Error("failed to search for vectors", "err", err) 121 return 122 } 123 124 seenDids := make(map[string]bool) 125 for _, r := range results { 126 rDidVal, ok := r.Payload["did"] 127 if !ok { 128 continue 129 } 130 rDid := rDidVal.GetStringValue() 131 if !seenDids[rDid] { 132 seenDids[rDid] = true 133 similarAvatarDids = append(similarAvatarDids, rDid) 134 } 135 } 136 }) 137 } 138 139 if bannerVector != nil { 140 wg.Go(func() { 141 query := &qdrant.QueryPoints{ 142 CollectionName: s.qdrantCollection, 143 Query: qdrant.NewQueryDense(bannerVector), 144 Filter: &qdrant.Filter{ 145 Must: []*qdrant.Condition{ 146 qdrant.NewRange("timestamp", &qdrant.Range{ 147 Lte: &nowF64, 148 Gte: &startRange, 149 }), 150 qdrant.NewMatch("kind", "banner"), 151 }, 152 }, 153 WithPayload: qdrant.NewWithPayload(true), 154 Limit: qdrant.PtrOf(uint64(s.maxLimit)), 155 ScoreThreshold: qdrant.PtrOf(s.minEuclidianDistance), 156 } 157 158 results, err := s.dbClient.Query(searchCtx, query) 159 if err != nil { 160 logger.Error("failed to search for vectors", "err", err) 161 return 162 } 163 164 seenDids := make(map[string]bool) 165 for _, r := range results { 166 rDidVal, ok := r.Payload["did"] 167 if !ok { 168 continue 169 } 170 rDid := rDidVal.GetStringValue() 171 if !seenDids[rDid] { 172 seenDids[rDid] = true 173 similarBannerDids = append(similarBannerDids, rDid) 174 } 175 } 176 }) 177 } 178 179 wg.Wait() 180 181 // inserts can happen in a different goroutine so we don't block returning these results 182 go func() { 183 insertCtx, cancel := context.WithTimeout(ctx, 10*time.Second) 184 defer cancel() 185 186 points := make([]*qdrant.PointStruct, 0, 2) 187 if avatarVector != nil { 188 point := &qdrant.PointStruct{ 189 Id: qdrant.NewIDUUID(uuid.NewString()), 190 Vectors: qdrant.NewVectors(avatarVector...), 191 Payload: qdrant.NewValueMap(map[string]any{ 192 "did": did, 193 "cid": avatarCid, 194 "kind": "avatar", 195 "timestamp": now, 196 }), 197 } 198 points = append(points, point) 199 } 200 if bannerVector != nil { 201 point := &qdrant.PointStruct{ 202 Id: qdrant.NewIDUUID(uuid.NewString()), 203 Vectors: qdrant.NewVectors(bannerVector...), 204 Payload: qdrant.NewValueMap(map[string]any{ 205 "did": did, 206 "cid": bannerCid, 207 "kind": "banner", 208 "timestamp": now, 209 }), 210 } 211 points = append(points, point) 212 } 213 214 if len(points) > 0 { 215 if _, err := s.dbClient.Upsert(insertCtx, &qdrant.UpsertPoints{ 216 CollectionName: s.qdrantCollection, 217 Points: points, 218 }); err != nil { 219 logger.Error("failed to insert hashes", "err", err) 220 } 221 } 222 }() 223 224 if len(similarAvatarDids) >= s.seenThreshold { 225 logger = logger.With("kind", "avatar", "did", did, "cid", avatarCid, "hash", avatarHash) 226 logger.Info("found users with similar avatars within search duration", "dids", similarAvatarDids) 227 } 228 if len(similarBannerDids) >= s.seenThreshold { 229 logger = logger.With("kind", "banner", "did", did, "cid", avatarCid, "hash", bannerHash) 230 logger.Info("found users with similar banners within search duration", "dids", similarBannerDids) 231 } 232} 233 234func convertHash(hash string) ([]float32, error) { 235 hashBin, err := HexToBinary(hash) 236 if err != nil { 237 return nil, fmt.Errorf("failed to convert to binary: %w", err) 238 } 239 return BinaryToFloatVector(hashBin), nil 240} 241 242func HexToBinary(input string) ([]byte, error) { 243 hashb, err := hex.DecodeString(input) 244 if err != nil { 245 return nil, err 246 } 247 return hashb, nil 248} 249 250func BinaryToFloatVector(bin []byte) []float32 { 251 vectorData := make([]float32, len(bin)*8) 252 for i, b := range bin { 253 for j := range 8 { 254 if (b>>(7-j))&1 == 1 { 255 vectorData[i*8+j] = 1.0 256 } else { 257 vectorData[i*8+j] = 0.0 258 } 259 } 260 } 261 return vectorData 262}