"""Qdrant-backed storage and similarity search for profile, avatar and post vectors."""

import logging
import sys
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from time import time
from typing import Any, Dict, List, Optional

from qdrant_client import QdrantClient
from qdrant_client.grpc import OptimizersConfigDiff
from qdrant_client.http.models import BinaryQuantizationConfig
from qdrant_client.models import (
    BinaryQuantization,
    Distance,
    FieldCondition,
    Filter,
    HnswConfigDiff,
    MatchValue,
    Payload,
    PayloadSchemaType,
    PointStruct,
    ScalarQuantization,
    ScalarQuantizationConfig,
    ScalarType,
    VectorParams,
)

from config import CONFIG
from metrics import prom_metrics

logger = logging.getLogger(__name__)


def create_now_timestamp() -> str:
    """Return the current UTC time as an ISO-8601 string."""
    return datetime.now(timezone.utc).isoformat()


@dataclass
class Result:
    """A single point returned from a lookup or similarity search."""

    did: str
    payload: Optional[Payload]
    score: Optional[float]


@dataclass
class ResultWithVector(Result):
    """A :class:`Result` that also carries the stored vector of the point."""

    # The getters below store ``point.vector`` here, i.e. the raw vector and
    # not a ``PointStruct``.
    # NOTE(review): assumes the collections hold a single unnamed dense
    # vector, as created by ``_ensure_collections_exist`` -- confirm for
    # collections that already existed.
    vector: Optional[List[float]]


class QdrantService:
    """Thin wrapper around ``QdrantClient`` for the profile/avatar/post collections.

    The instance is inert until :meth:`initialize` is called; every other
    method that talks to Qdrant expects that call to have happened.
    """

    def __init__(self) -> None:
        self._client: Optional[QdrantClient] = None
        # Populated from CONFIG by ``initialize``.
        self.profile_collection_name: Optional[str] = None
        self.avatar_collection_name: Optional[str] = None
        self.post_collection_name: Optional[str] = None

    def initialized(self) -> bool:
        """Return ``True`` once :meth:`initialize` has created the client."""
        return self._client is not None

    def get_client(self) -> Optional[QdrantClient]:
        """Return the underlying client, or ``None`` before initialization."""
        return self._client

    def initialize(self) -> None:
        """Connect to Qdrant and make sure all three collections exist.

        The process exits with status 1 if a missing collection or one of
        its payload indexes cannot be created.
        """
        logger.info("Connecting to Qdrant at %s", CONFIG.qdrant_url)

        self._client = QdrantClient(
            url=CONFIG.qdrant_url,
        )

        self.profile_collection_name = CONFIG.qdrant_profile_collection_name
        self.avatar_collection_name = CONFIG.qdrant_avatar_collection_name
        self.post_collection_name = CONFIG.qdrant_post_collection_name

        self._ensure_collections_exist()

    def _create_payload_indexes(self, collection_name: str, keyword_field: str) -> None:
        """Create the keyword index on ``keyword_field`` and the datetime index on ``timestamp``.

        Exceptions from the client propagate so the caller can report them
        with its own collection-specific message.
        """
        self._client.create_payload_index(
            collection_name=collection_name,
            field_name=keyword_field,
            field_schema=PayloadSchemaType.KEYWORD,
        )
        self._client.create_payload_index(
            collection_name=collection_name,
            field_name="timestamp",
            field_schema=PayloadSchemaType.DATETIME,
        )

    def _ensure_collections_exist(self) -> None:
        """Create any of the three collections (and their indexes) that are missing.

        Any failure is logged and terminates the process via ``sys.exit(1)``.
        """
        profile_coll_exists = self._client.collection_exists(
            self.profile_collection_name
        )
        avatar_coll_exists = self._client.collection_exists(self.avatar_collection_name)
        post_coll_exists = self._client.collection_exists(self.post_collection_name)

        if not profile_coll_exists:
            logger.info(
                "Creating profile collection: %s", self.profile_collection_name
            )
            try:
                self._client.create_collection(
                    collection_name=self.profile_collection_name,
                    vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
                    hnsw_config=HnswConfigDiff(m=32, ef_construct=200),
                    quantization_config=ScalarQuantization(
                        scalar=ScalarQuantizationConfig(
                            type=ScalarType.INT8, quantile=0.99, always_ram=True
                        )
                    ),
                )
            except Exception as e:
                logger.error("Failed to create profiles collection: %s", e)
                sys.exit(1)

            try:
                self._create_payload_indexes(
                    self.profile_collection_name, keyword_field="did"
                )
            except Exception as e:
                logger.error("Failed to create profiles indexes: %s", e)
                sys.exit(1)

            logger.info("Collection created successfully")

        if not avatar_coll_exists:
            logger.info("Creating avatar collection: %s", self.avatar_collection_name)
            try:
                self._client.create_collection(
                    collection_name=self.avatar_collection_name,
                    vectors_config=VectorParams(
                        # PDQ vectors have a size of 256
                        size=256,
                        # Qdrant doesn't support Hamming distance, so we use
                        # Euclidean distance and use the square root of the
                        # selected max distance for lookups
                        distance=Distance.EUCLID,
                    ),
                    hnsw_config=HnswConfigDiff(
                        m=16,  # lower m for binary-like data
                        ef_construct=100,
                    ),
                    quantization_config=BinaryQuantization(
                        binary=BinaryQuantizationConfig(always_ram=True)
                    ),
                )
            except Exception as e:
                logger.error("Failed to create avatar collection: %s", e)
                sys.exit(1)

            try:
                self._create_payload_indexes(
                    self.avatar_collection_name, keyword_field="did"
                )
            except Exception as e:
                logger.error("Failed to create avatar indexes: %s", e)
                sys.exit(1)

            # Logged here as well so all three branches report success alike.
            logger.info("Collection created successfully")

        if not post_coll_exists:
            logger.info("Creating post collection: %s", self.post_collection_name)
            try:
                self._client.create_collection(
                    collection_name=self.post_collection_name,
                    vectors_config=VectorParams(
                        size=CONFIG.embedding_size,
                        distance=Distance.COSINE,
                    ),
                    hnsw_config=HnswConfigDiff(
                        m=48,
                        ef_construct=256,
                    ),
                    quantization_config=ScalarQuantization(
                        scalar=ScalarQuantizationConfig(
                            type=ScalarType.INT8,
                            quantile=0.99,
                            always_ram=True,
                        ),
                    ),
                    optimizers_config=OptimizersConfigDiff(
                        indexing_threshold=50_000,
                    ),
                )
            except Exception as e:
                logger.error("Failed to create posts collection: %s", e)
                sys.exit(1)

            try:
                self._create_payload_indexes(
                    self.post_collection_name, keyword_field="uri"
                )
            except Exception as e:
                logger.error("Failed to create post indexes: %s", e)
                sys.exit(1)

            logger.info("Collection created successfully")

    def _did_filter(self, did: str) -> Filter:
        """Build a filter matching points whose ``did`` payload field equals ``did``."""
        return Filter(must=[FieldCondition(key="did", match=MatchValue(value=did))])

    def _find_point_id_by_did(self, collection_name: str, did: str) -> Optional[Any]:
        """Return the id of the first point stored for ``did``, or ``None``.

        Only the id is needed, so a single record is requested and the
        payload is not fetched.
        """
        points, _next_offset = self._client.scroll(
            collection_name=collection_name,
            scroll_filter=self._did_filter(did),
            limit=1,
            with_payload=False,
        )
        return points[0].id if points else None

    def _upsert_point(
        self,
        kind: str,
        collection_name: str,
        payload: Dict[str, Any],
        vector: List[float],
        lookup_did: Optional[str] = None,
    ) -> bool:
        """Upsert one point and record the upsert count and duration metrics.

        Args:
            kind: Label used for the metrics and the error log message.
            collection_name: Target collection.
            payload: Payload stored with the point.
            vector: Vector stored with the point.
            lookup_did: When given, an existing point with this ``did`` is
                overwritten (its id is reused); otherwise a fresh UUID is
                always generated.

        Returns:
            ``True`` on success, ``False`` if any step raised (the error is
            logged, not re-raised).
        """
        status = "error"
        start_time = time()
        try:
            point_id = None
            if lookup_did is not None:
                point_id = self._find_point_id_by_did(collection_name, lookup_did)
            if point_id is None:
                point_id = str(uuid.uuid4())

            self._client.upsert(
                collection_name=collection_name,
                points=[PointStruct(id=point_id, vector=vector, payload=payload)],
            )

            status = "ok"
            return True
        except Exception as e:
            logger.error("Error upserting %s: %s", kind, e)
            return False
        finally:
            # Runs on both paths so failures are counted and timed as well.
            prom_metrics.upserts.labels(kind=kind, status=status).inc()
            prom_metrics.upsert_duration.labels(kind=kind, status=status).observe(
                time() - start_time
            )

    def upsert_profile(self, did: str, description: str, vector: List[float]) -> bool:
        """Store the profile vector for ``did``, replacing an existing point if any.

        Returns ``True`` on success and ``False`` on any error.
        """
        payload = {
            "did": did,
            "description": description,
            "timestamp": create_now_timestamp(),
        }
        return self._upsert_point(
            "profile",
            self.profile_collection_name,
            payload,
            vector,
            lookup_did=did,
        )

    def upsert_avatar(self, did: str, cid: str, vector: List[float]) -> bool:
        """Store the avatar vector for ``did``, replacing an existing point if any.

        Returns ``True`` on success and ``False`` on any error.
        """
        payload = {
            "did": did,
            "cid": cid,
            "timestamp": create_now_timestamp(),
        }
        return self._upsert_point(
            "avatar",
            self.avatar_collection_name,
            payload,
            vector,
            lookup_did=did,
        )

    def upsert_post(self, did: str, uri: str, text: str, vector: List[float]) -> bool:
        """Store a post vector under a freshly generated point id.

        Returns ``True`` on success and ``False`` on any error.
        """
        word_ct = len(text.split())
        payload = {
            "did": did,
            "uri": uri,
            "text": text,
            "word_count": word_ct,
            "timestamp": create_now_timestamp(),
        }
        # we don't care about upserting these, so no lookup of an existing point
        return self._upsert_point("post", self.post_collection_name, payload, vector)

    def search_similar(
        self,
        collection_name: str,
        query_vector: List[float],
        limit: int = 10,
        score_threshold: Optional[float] = None,
        filter_conditions: Optional[Filter] = None,
    ) -> Optional[List[Result]]:
        """Return up to ``limit`` points closest to ``query_vector``.

        Args:
            collection_name: Collection to search.
            query_vector: Vector to compare against.
            limit: Maximum number of hits.
            score_threshold: Optional score cut-off passed to Qdrant.
            filter_conditions: Optional payload filter passed to Qdrant.

        Returns:
            The hits as :class:`Result` objects, or ``None`` if the query
            (or converting its hits) raised; the error is logged.
        """
        try:
            results = self._client.query_points(
                collection_name=collection_name,
                query=query_vector,
                query_filter=filter_conditions,
                limit=limit,
                score_threshold=score_threshold,
                with_payload=True,
            ).points

            return [
                Result(
                    did=hit.payload.get("did"),
                    payload=hit.payload,
                    score=hit.score,
                )
                for hit in results
            ]
        except Exception as e:
            logger.error("Error searching for similar vectors: %s", e)
            return None

    def _get_by_did(self, collection_name: str, did: str) -> Optional[ResultWithVector]:
        """Fetch the first point stored for ``did`` together with its vector.

        Returns ``None`` when no point matches. Client errors propagate.
        """
        points, _next_offset = self._client.scroll(
            collection_name=collection_name,
            scroll_filter=self._did_filter(did),
            limit=1,  # only the first match is ever used
            with_vectors=True,
            with_payload=True,
        )
        if not points:
            return None

        point = points[0]
        return ResultWithVector(
            did=point.payload["did"],
            payload=point.payload,
            vector=point.vector,
            # Fixed score of 1.0 for a direct lookup.
            score=1.0,
        )

    def get_profile_by_did(self, did: str) -> Optional[ResultWithVector]:
        """Return the stored profile point for ``did``, or ``None`` if absent."""
        return self._get_by_did(self.profile_collection_name, did)

    def get_avatar_by_did(self, did: str) -> Optional[ResultWithVector]:
        """Return the stored avatar point for ``did``, or ``None`` if absent."""
        return self._get_by_did(self.avatar_collection_name, did)


QDRANT_SERVICE = QdrantService()