import logging
import sys
import uuid
from dataclasses import dataclass
from datetime import datetime, timezone
from time import time
from typing import List, Optional

from qdrant_client import QdrantClient
from qdrant_client.models import (
    BinaryQuantization,
    BinaryQuantizationConfig,
    Distance,
    FieldCondition,
    Filter,
    HnswConfigDiff,
    MatchValue,
    OptimizersConfigDiff,
    Payload,
    PayloadSchemaType,
    PointStruct,
    ScalarQuantization,
    ScalarQuantizationConfig,
    ScalarType,
    VectorParams,
)

from config import CONFIG
from metrics import prom_metrics

logger = logging.getLogger(__name__)


@dataclass
class Result:
    did: str
    payload: Optional[Payload]
    score: Optional[float]


@dataclass
class ResultWithVector(Result):
    vector: List[float]


class QdrantService:
    def __init__(self) -> None:
        self._client = None

    def initialized(self):
        return self._client is not None

    def get_client(self):
        return self._client

    def initialize(self) -> None:
        logger.info(f"Connecting to Qdrant at {CONFIG.qdrant_url}")

        self._client = QdrantClient(
            url=CONFIG.qdrant_url,
        )

        self.profile_collection_name = CONFIG.qdrant_profile_collection_name
        self.avatar_collection_name = CONFIG.qdrant_avatar_collection_name
        self.post_collection_name = CONFIG.qdrant_post_collection_name
        self._ensure_collections_exist()

    def _ensure_collections_exist(self):
        profile_coll_exists = self._client.collection_exists(
            self.profile_collection_name
        )
        avatar_coll_exists = self._client.collection_exists(self.avatar_collection_name)
        post_coll_exists = self._client.collection_exists(self.post_collection_name)

        if not profile_coll_exists:
            logger.info(f"Creating profile collection: {self.profile_collection_name}")
            try:
                self._client.create_collection(
                    collection_name=self.profile_collection_name,
                    vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
                    hnsw_config=HnswConfigDiff(m=32, ef_construct=200),
                    quantization_config=ScalarQuantization(
                        scalar=ScalarQuantizationConfig(
                            type=ScalarType.INT8, quantile=0.99, always_ram=True
                        )
                    ),
                )
            except Exception as e:
                logger.error(f"Failed to create profile collection: {e}")
                sys.exit(1)

            try:
                self._client.create_payload_index(
                    collection_name=self.profile_collection_name,
                    field_name="did",
                    field_schema=PayloadSchemaType.KEYWORD,
                )
                self._client.create_payload_index(
                    collection_name=self.profile_collection_name,
                    field_name="timestamp",
                    field_schema=PayloadSchemaType.DATETIME,
                )
            except Exception as e:
                logger.error(f"Failed to create profile indexes: {e}")
                sys.exit(1)

            logger.info("Profile collection created successfully")

        if not avatar_coll_exists:
            logger.info(f"Creating avatar collection: {self.avatar_collection_name}")

            try:
                self._client.create_collection(
                    collection_name=self.avatar_collection_name,
                    vectors_config=VectorParams(
                        # PDQ vectors have a size of 256
                        size=256,
                        # Qdrant doesn't support Hamming distance, so we use Euclidean
                        # distance and take the square root of the selected max Hamming
                        # distance for lookups (see the hamming_to_euclid_threshold
                        # helper below)
                        distance=Distance.EUCLID,
                    ),
                    hnsw_config=HnswConfigDiff(
                        m=16,  # lower m for binary-like data
                        ef_construct=100,
                    ),
                    quantization_config=BinaryQuantization(
                        binary=BinaryQuantizationConfig(always_ram=True)
                    ),
                )
            except Exception as e:
                logger.error(f"Failed to create avatar collection: {e}")
                sys.exit(1)

            try:
                self._client.create_payload_index(
                    collection_name=self.avatar_collection_name,
                    field_name="did",
                    field_schema=PayloadSchemaType.KEYWORD,
                )
                self._client.create_payload_index(
                    collection_name=self.avatar_collection_name,
                    field_name="timestamp",
                    field_schema=PayloadSchemaType.DATETIME,
                )
            except Exception as e:
                logger.error(f"Failed to create avatar indexes: {e}")
                sys.exit(1)

        if not post_coll_exists:
            logger.info(f"Creating post collection: {self.post_collection_name}")
            try:
                self._client.create_collection(
                    collection_name=self.post_collection_name,
                    vectors_config=VectorParams(
                        size=CONFIG.embedding_size,
                        distance=Distance.COSINE,
                    ),
                    hnsw_config=HnswConfigDiff(
                        m=48,
                        ef_construct=256,
                    ),
                    quantization_config=ScalarQuantization(
                        scalar=ScalarQuantizationConfig(
                            type=ScalarType.INT8,
                            quantile=0.99,
                            always_ram=True,
                        ),
                    ),
                    optimizers_config=OptimizersConfigDiff(
                        indexing_threshold=50_000,
                    ),
                )
            except Exception as e:
                logger.error(f"Failed to create post collection: {e}")
                sys.exit(1)

            try:
                self._client.create_payload_index(
                    collection_name=self.post_collection_name,
                    field_name="uri",
                    field_schema=PayloadSchemaType.KEYWORD,
                )
                self._client.create_payload_index(
                    collection_name=self.post_collection_name,
                    field_name="timestamp",
                    field_schema=PayloadSchemaType.DATETIME,
                )
            except Exception as e:
                logger.error(f"Failed to create post indexes: {e}")
                sys.exit(1)

            logger.info("Post collection created successfully")

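    # Hedged helper, not part of the original module: assuming the PDQ hashes are
    # stored as 0/1-valued vectors, the squared Euclidean distance between two of
    # them equals their Hamming distance. This converts a maximum Hamming distance
    # into the Euclidean score_threshold to use when querying the avatar collection,
    # e.g. a 31-bit budget -> threshold of roughly 5.57.
    @staticmethod
    def hamming_to_euclid_threshold(max_hamming_distance: int) -> float:
        return float(max_hamming_distance) ** 0.5
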
    def upsert_profile(self, did: str, description: str, vector: List[float]):
        status = "error"
        start_time = time()

        try:
            payload = {
                "did": did,
                "description": description,
                "timestamp": create_now_timestamp(),
            }

            # scroll() returns a (points, next_page_offset) tuple; reuse any existing
            # point id for this DID so the upsert overwrites rather than duplicates
            existing = self._client.scroll(
                collection_name=self.profile_collection_name,
                scroll_filter=Filter(
                    must=[FieldCondition(key="did", match=MatchValue(value=did))]
                ),
            )

            if existing and existing[0] and len(existing[0]) > 0:
                point_id = existing[0][0].id
            else:
                point_id = str(uuid.uuid4())

            point = PointStruct(
                id=point_id,
                vector=vector,
                payload=payload,
            )

            self._client.upsert(
                collection_name=self.profile_collection_name,
                points=[point],
            )

            status = "ok"

            return True
        except Exception as e:
            logger.error(f"Error upserting profile: {e}")
            return False
        finally:
            prom_metrics.upserts.labels(kind="profile", status=status).inc()
            prom_metrics.upsert_duration.labels(kind="profile", status=status).observe(
                time() - start_time
            )

    def upsert_avatar(self, did: str, cid: str, vector: List[float]):
        status = "error"
        start_time = time()

        try:
            payload = {
                "did": did,
                "cid": cid,
                "timestamp": create_now_timestamp(),
            }

            existing = self._client.scroll(
                collection_name=self.avatar_collection_name,
                scroll_filter=Filter(
                    must=[FieldCondition(key="did", match=MatchValue(value=did))]
                ),
            )

            if existing and existing[0] and len(existing[0]) > 0:
                point_id = existing[0][0].id
            else:
                point_id = str(uuid.uuid4())

            point = PointStruct(
                id=point_id,
                vector=vector,
                payload=payload,
            )

            self._client.upsert(
                collection_name=self.avatar_collection_name,
                points=[point],
            )

            status = "ok"

            return True
        except Exception as e:
            logger.error(f"Error upserting avatar: {e}")
            return False
        finally:
            prom_metrics.upserts.labels(kind="avatar", status=status).inc()
            prom_metrics.upsert_duration.labels(kind="avatar", status=status).observe(
                time() - start_time
            )

    def upsert_post(self, did: str, uri: str, text: str, vector: List[float]):
        status = "error"
        start_time = time()

        word_ct = len(text.split())

        try:
            payload = {
                "did": did,
                "uri": uri,
                "text": text,
                "word_count": word_ct,
                "timestamp": create_now_timestamp(),
            }

            # posts don't need to be deduped the way profiles and avatars are,
            # so always insert under a fresh point id
            point_id = str(uuid.uuid4())

            point = PointStruct(
                id=point_id,
                vector=vector,
                payload=payload,
            )

            self._client.upsert(
                collection_name=self.post_collection_name,
                points=[point],
            )

            status = "ok"

            return True
        except Exception as e:
            logger.error(f"Error upserting post: {e}")
            return False
        finally:
            prom_metrics.upserts.labels(kind="post", status=status).inc()
            prom_metrics.upsert_duration.labels(kind="post", status=status).observe(
                time() - start_time
            )

    def search_similar(
        self,
        collection_name: str,
        query_vector: List[float],
        limit: int = 10,
        score_threshold: Optional[float] = None,
        filter_conditions: Optional[Filter] = None,
    ) -> Optional[List[Result]]:
        try:
            results = self._client.query_points(
                collection_name=collection_name,
                query=query_vector,
                query_filter=filter_conditions,
                limit=limit,
                score_threshold=score_threshold,
                with_payload=True,
            ).points

            return [
                Result(
                    did=hit.payload.get("did"),
                    payload=hit.payload,
                    score=hit.score,
                )
                for hit in results
            ]
        except Exception as e:
            logger.error(f"Error searching for similar vectors: {e}")
            return None

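    # Hedged convenience sketch, not part of the original module: query the avatar
    # collection by PDQ hash with a Hamming-distance budget, converting it to the
    # Euclidean score_threshold via hamming_to_euclid_threshold above. The default
    # budget of 31 bits is an assumption (a commonly cited PDQ match threshold),
    # not something this codebase specifies.
    def search_similar_avatars(
        self,
        query_vector: List[float],
        max_hamming_distance: int = 31,
        limit: int = 10,
    ) -> Optional[List[Result]]:
        return self.search_similar(
            collection_name=self.avatar_collection_name,
            query_vector=query_vector,
            limit=limit,
            score_threshold=self.hamming_to_euclid_threshold(max_hamming_distance),
        )
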
    def get_profile_by_did(self, did: str) -> Optional[ResultWithVector]:
        result = self._client.scroll(
            collection_name=self.profile_collection_name,
            scroll_filter=Filter(
                must=[FieldCondition(key="did", match=MatchValue(value=did))]
            ),
            with_vectors=True,
            with_payload=True,
        )

        if result and result[0] and len(result[0]) > 0:
            point = result[0][0]
            return ResultWithVector(
                did=point.payload["did"],
                payload=point.payload,
                vector=point.vector,
                score=1.0,
            )

        return None

    def get_avatar_by_did(self, did: str) -> Optional[ResultWithVector]:
        result = self._client.scroll(
            collection_name=self.avatar_collection_name,
            scroll_filter=Filter(
                must=[FieldCondition(key="did", match=MatchValue(value=did))]
            ),
            with_vectors=True,
            with_payload=True,
        )

        if result and result[0] and len(result[0]) > 0:
            point = result[0][0]
            return ResultWithVector(
                did=point.payload["did"],
                payload=point.payload,
                vector=point.vector,
                score=1.0,
            )

        return None


QDRANT_SERVICE = QdrantService()


def create_now_timestamp() -> str:
    return datetime.now(timezone.utc).isoformat()
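

# Hedged usage sketch, not part of the original module: exercises the service end to
# end against whatever Qdrant instance CONFIG.qdrant_url points at. The DID, URI,
# text, and placeholder vector below are made up for illustration, and
# CONFIG.embedding_size is assumed to match the post collection's vector size.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    QDRANT_SERVICE.initialize()

    # non-zero placeholder vector so cosine similarity is well defined
    placeholder_vector = [0.1] * CONFIG.embedding_size

    QDRANT_SERVICE.upsert_post(
        did="did:plc:example",
        uri="at://did:plc:example/app.bsky.feed.post/example",
        text="hello world",
        vector=placeholder_vector,
    )

    hits = QDRANT_SERVICE.search_similar(
        collection_name=QDRANT_SERVICE.post_collection_name,
        query_vector=placeholder_vector,
        limit=5,
    )
    for hit in hits or []:
        logger.info(f"{hit.did} scored {hit.score}")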