# A small service that creates embeddings of posts, profiles, and avatars and stores them in Qdrant.
import logging
import sys
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from time import perf_counter
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

from qdrant_client import QdrantClient
from qdrant_client.models import (
    BinaryQuantization,
    BinaryQuantizationConfig,
    Distance,
    FieldCondition,
    Filter,
    HnswConfigDiff,
    MatchValue,
    OptimizersConfigDiff,
    Payload,
    PayloadSchemaType,
    PointStruct,
    ScalarQuantization,
    ScalarQuantizationConfig,
    ScalarType,
    VectorParams,
)

from config import CONFIG
from metrics import prom_metrics

logger = logging.getLogger(__name__)


@dataclass
class Result:
    """One point read back from Qdrant, identified by the DID in its payload."""

    # Value of the payload's "did" key (search results read it with .get, so it
    # may be None if a point was stored without one).
    did: str
    # The point's full payload as returned by Qdrant.
    payload: Optional[Payload]
    # Similarity score for search hits; direct DID lookups report 1.0.
    score: Optional[float]


@dataclass
class ResultWithVector(Result):
    """A Result that also carries the stored vector."""

    # The stored vector exactly as Qdrant returns it (`point.vector`); these
    # collections use a single unnamed vector, so this is a list of floats.
    vector: List[float]


def create_now_timestamp() -> str:
    """Return the current UTC time as an ISO-8601 string with a +00:00 offset."""
    return datetime.now(timezone.utc).isoformat()


def _did_filter(did: str) -> Filter:
    """Build a filter that matches points whose payload "did" equals *did*."""
    return Filter(must=[FieldCondition(key="did", match=MatchValue(value=did))])


class QdrantService:
    """Wrapper around QdrantClient for the profile, avatar, and post collections.

    Call ``initialize()`` once before any other method: it connects to
    ``CONFIG.qdrant_url`` and creates whichever collections are missing.
    """

    def __init__(self) -> None:
        self._client: Optional[QdrantClient] = None
        # Populated from CONFIG by initialize().
        self.profile_collection_name: Optional[str] = None
        self.avatar_collection_name: Optional[str] = None
        self.post_collection_name: Optional[str] = None

    def initialized(self) -> bool:
        """Return True once initialize() has created the client."""
        return self._client is not None

    def get_client(self) -> Optional[QdrantClient]:
        """Return the underlying QdrantClient (None before initialize())."""
        return self._client

    def initialize(self) -> None:
        """Connect to Qdrant and ensure all three collections exist.

        Exits the process with status 1 if a collection or one of its payload
        indexes cannot be created.
        """
        logger.info(f"Connecting to Qdrant at {CONFIG.qdrant_url}")

        self._client = QdrantClient(url=CONFIG.qdrant_url)

        self.profile_collection_name = CONFIG.qdrant_profile_collection_name
        self.avatar_collection_name = CONFIG.qdrant_avatar_collection_name
        self.post_collection_name = CONFIG.qdrant_post_collection_name
        self._ensure_collections_exist()

    # ------------------------------------------------------------------ setup

    def _ensure_collections_exist(self) -> None:
        """Create the profile, avatar, and post collections if they are missing."""
        if not self._client.collection_exists(self.profile_collection_name):
            self._create_collection(
                kind="profile",
                collection_name=self.profile_collection_name,
                indexed_fields=[
                    ("did", PayloadSchemaType.KEYWORD),
                    ("timestamp", PayloadSchemaType.DATETIME),
                ],
                vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
                hnsw_config=HnswConfigDiff(m=32, ef_construct=200),
                quantization_config=ScalarQuantization(
                    scalar=ScalarQuantizationConfig(
                        type=ScalarType.INT8, quantile=0.99, always_ram=True
                    )
                ),
            )

        if not self._client.collection_exists(self.avatar_collection_name):
            self._create_collection(
                kind="avatar",
                collection_name=self.avatar_collection_name,
                indexed_fields=[
                    ("did", PayloadSchemaType.KEYWORD),
                    ("timestamp", PayloadSchemaType.DATETIME),
                ],
                vectors_config=VectorParams(
                    # PDQ vectors have a size of 256
                    size=256,
                    # Qdrant doesn't support hamming distance, so we use euclidean
                    # distance and look up with the square root of the selected
                    # max distance.
                    distance=Distance.EUCLID,
                ),
                hnsw_config=HnswConfigDiff(
                    m=16,  # lower m for binary-like data
                    ef_construct=100,
                ),
                quantization_config=BinaryQuantization(
                    binary=BinaryQuantizationConfig(always_ram=True)
                ),
            )

        if not self._client.collection_exists(self.post_collection_name):
            self._create_collection(
                kind="post",
                collection_name=self.post_collection_name,
                indexed_fields=[
                    ("uri", PayloadSchemaType.KEYWORD),
                    ("timestamp", PayloadSchemaType.DATETIME),
                ],
                vectors_config=VectorParams(
                    size=CONFIG.embedding_size,
                    distance=Distance.COSINE,
                ),
                hnsw_config=HnswConfigDiff(m=48, ef_construct=256),
                quantization_config=ScalarQuantization(
                    scalar=ScalarQuantizationConfig(
                        type=ScalarType.INT8,
                        quantile=0.99,
                        always_ram=True,
                    ),
                ),
                # REST model, consistent with every other config object here.
                optimizers_config=OptimizersConfigDiff(indexing_threshold=50_000),
            )

    def _create_collection(
        self,
        kind: str,
        collection_name: str,
        indexed_fields: Sequence[Tuple[str, PayloadSchemaType]],
        **collection_kwargs: Any,
    ) -> None:
        """Create one collection and its payload indexes; exit(1) on any failure.

        Args:
            kind: Human-readable label used only in log messages.
            collection_name: Name of the collection to create.
            indexed_fields: (payload field, schema type) pairs to index.
            **collection_kwargs: Forwarded verbatim to create_collection.
        """
        logger.info(f"Creating {kind} collection: {collection_name}")
        try:
            self._client.create_collection(
                collection_name=collection_name, **collection_kwargs
            )
        except Exception as e:
            logger.error(f"Failed to create {kind} collection: {e}")
            sys.exit(1)

        try:
            for field_name, field_schema in indexed_fields:
                self._client.create_payload_index(
                    collection_name=collection_name,
                    field_name=field_name,
                    field_schema=field_schema,
                )
        except Exception as e:
            logger.error(f"Failed to create {kind} indexes: {e}")
            sys.exit(1)

        logger.info(f"Collection {collection_name} created successfully")

    # ---------------------------------------------------------------- writes

    def _find_point_id_by_did(
        self, collection_name: str, did: str
    ) -> Optional[Union[int, str]]:
        """Return the id of a point whose payload "did" matches, or None."""
        # Only an id is needed: fetch a single point and skip payload/vector.
        points, _next_offset = self._client.scroll(
            collection_name=collection_name,
            scroll_filter=_did_filter(did),
            limit=1,
            with_payload=False,
            with_vectors=False,
        )
        return points[0].id if points else None

    def _upsert_point(
        self,
        kind: str,
        collection_name: Optional[str],
        vector: List[float],
        payload: Dict[str, Any],
        dedupe_did: Optional[str] = None,
    ) -> bool:
        """Upsert one point and record the upsert count/duration metrics.

        When *dedupe_did* is given, an existing point with that DID keeps its
        id and is overwritten; otherwise a fresh random UUID is used.

        Returns:
            True on success, False (after logging) on any client error.
        """
        status = "error"
        start_time = perf_counter()

        try:
            point_id = None
            if dedupe_did is not None:
                point_id = self._find_point_id_by_did(collection_name, dedupe_did)
            if point_id is None:
                point_id = str(uuid.uuid4())

            self._client.upsert(
                collection_name=collection_name,
                points=[PointStruct(id=point_id, vector=vector, payload=payload)],
            )

            status = "ok"
            return True
        except Exception as e:
            logger.error(f"Error upserting {kind}: {e}")
            return False
        finally:
            prom_metrics.upserts.labels(kind=kind, status=status).inc()
            prom_metrics.upsert_duration.labels(kind=kind, status=status).observe(
                perf_counter() - start_time
            )

    def upsert_profile(self, did: str, description: str, vector: List[float]) -> bool:
        """Store (or overwrite) the profile-description embedding for *did*.

        Returns True on success, False if the write failed.
        """
        payload = {
            "did": did,
            "description": description,
            "timestamp": create_now_timestamp(),
        }
        return self._upsert_point(
            "profile", self.profile_collection_name, vector, payload, dedupe_did=did
        )

    def upsert_avatar(self, did: str, cid: str, vector: List[float]) -> bool:
        """Store (or overwrite) the avatar hash vector for *did*.

        Returns True on success, False if the write failed.
        """
        payload = {
            "did": did,
            "cid": cid,
            "timestamp": create_now_timestamp(),
        }
        return self._upsert_point(
            "avatar", self.avatar_collection_name, vector, payload, dedupe_did=did
        )

    def upsert_post(self, did: str, uri: str, text: str, vector: List[float]) -> bool:
        """Store a post embedding along with its text and whitespace word count.

        Posts are never deduplicated: every call inserts a new point.
        Returns True on success, False if the write failed.
        """
        payload = {
            "did": did,
            "uri": uri,
            "text": text,
            "word_count": len(text.split()),
            "timestamp": create_now_timestamp(),
        }
        return self._upsert_point("post", self.post_collection_name, vector, payload)

    # ----------------------------------------------------------------- reads

    def search_similar(
        self,
        collection_name: str,
        query_vector: List[float],
        limit: int = 10,
        score_threshold: Optional[float] = None,
        filter_conditions: Optional[Filter] = None,
    ) -> Optional[List[Result]]:
        """Return up to *limit* nearest neighbours of *query_vector*.

        Returns:
            A list of Result (possibly empty), or None if the query failed.
        """
        try:
            hits = self._client.query_points(
                collection_name=collection_name,
                query=query_vector,
                query_filter=filter_conditions,
                limit=limit,
                score_threshold=score_threshold,
                with_payload=True,
            ).points
        except Exception as e:
            logger.error(f"Error searching for similar vectors: {e}")
            return None

        return [
            Result(
                # Guard against points stored without any payload.
                did=(hit.payload or {}).get("did"),
                payload=hit.payload,
                score=hit.score,
            )
            for hit in hits
        ]

    def _get_by_did(
        self, collection_name: Optional[str], did: str
    ) -> Optional[ResultWithVector]:
        """Fetch the point (payload and vector) stored for *did*, or None."""
        points, _next_offset = self._client.scroll(
            collection_name=collection_name,
            scroll_filter=_did_filter(did),
            limit=1,
            with_vectors=True,
            with_payload=True,
        )
        if not points:
            return None

        point = points[0]
        return ResultWithVector(
            did=point.payload["did"],
            payload=point.payload,
            vector=point.vector,
            score=1.0,  # exact lookup, not a similarity search
        )

    def get_profile_by_did(self, did: str) -> Optional[ResultWithVector]:
        """Return the stored profile embedding for *did*, or None if absent."""
        return self._get_by_did(self.profile_collection_name, did)

    def get_avatar_by_did(self, did: str) -> Optional[ResultWithVector]:
        """Return the stored avatar vector for *did*, or None if absent."""
        return self._get_by_did(self.avatar_collection_name, did)


QDRANT_SERVICE = QdrantService()