1import logging
2import math
3import sys
4from typing import List
5from datetime import datetime, timezone, timedelta
6import click
7from qdrant_client.models import (
8 DatetimeRange,
9 Direction,
10 FieldCondition,
11 Filter,
12 MatchValue,
13 OrderBy,
14)
15from rich.console import Console
16from rich.table import Table
17from rich.panel import Panel
18from rich import box
19from itertools import combinations
20import numpy as np
21
22from config import CONFIG
23from database import QDRANT_SERVICE, Result
24from embedder import EMBEDDING_SERVICE
25from retina import RETINA_CLIENT, binary_to_float_vector, hex_to_binary
26
27
logging.basicConfig(format=logging.BASIC_FORMAT, level=logging.INFO)

# Shared module-level handles: a named logger for diagnostics and a single
# rich console that every command below prints through.
logger = logging.getLogger(__name__)
console = Console()
36
37
38def display_results(
39 type: str, query: str, results: List[Result], show_more: bool = False
40):
41 """
42 A lil guy that turns the results into a table for viewing
43 """
44
45 if not results:
46 console.print("[yellow]No similar profiles found.[/yellow]")
47 return
48
49 console.print(Panel(f"[bold blue]Query: {query}[/bold blue]", box=box.ROUNDED))
50 console.print()
51
52 table = Table(
53 title=f"Found {len(results)} similar results",
54 box=box.ROUNDED,
55 header_style="bold magenta",
56 expand=True,
57 show_header=True,
58 show_lines=True,
59 )
60
61 table.add_column("#", style="dim", width=4)
62 table.add_column("DID", style="cyan", width=35)
63 table.add_column("Similarity", justify="right", style="green", width=10)
64
65 if show_more:
66 table.add_column("More", style="white", overflow="fold")
67
68 for idx, result in enumerate(results, 1):
69 similarity_percent = f"{result.score or 0.0 * 100:.4f}%"
70
71 row: List[str] = [
72 str(idx),
73 result.did,
74 similarity_percent,
75 ]
76
77 if show_more and result.payload is not None:
78 more = None
79 if type == "profile":
80 more = result.payload.get("description")
81 elif type == "avatar":
82 cid = result.payload.get("cid")
83 more = f"https://cdn.bsky.app/img/feed_thumbnail/plain/{result.did}/{cid}@jpeg"
84 elif type == "post":
85 more = result.payload.get("text")
86
87 if more is not None:
88 row.append(more)
89
90 table.add_row(*row)
91
92 console.print(table)
93 console.print()
94
95
@click.group()
def main():
    # Root click group; the sub-commands below register via @main.command().
    # Kept as a comment rather than a docstring, since click would show a
    # docstring as this group's --help text.
    ...
99
100
def _search_profiles(query, did, limit, threshold, more):
    """Find profiles similar to free text, or to the stored profile of ``did``."""
    if query:
        EMBEDDING_SERVICE.initialize()

        description = query
        query_vector = EMBEDDING_SERVICE.encode(query)
    else:
        console.print("[cyan]Looking up profile...[/cyan]")
        profile = QDRANT_SERVICE.get_profile_by_did(did)

        if not profile:
            console.print(f"[red]Profile not found: {did}[/red]")
            sys.exit(1)

        # Fall back to the DID so the header panel never reads "Query: None".
        description = (profile.payload or {}).get("description") or did
        query_vector = profile.vector

        console.print("[green]Found profile[/green]")

    console.print("[cyan]Looking up similar profiles...[/cyan]")

    results = QDRANT_SERVICE.search_similar(
        collection_name=CONFIG.qdrant_profile_collection_name,
        query_vector=query_vector,
        limit=limit,
        # NOTE(review): only profile searches pass sqrt(threshold); avatar and
        # post searches use the raw value. Presumably intentional; confirm.
        score_threshold=math.sqrt(threshold),
    )

    display_results("profile", description, results, more)


def _search_avatars(query, did, limit, threshold, more):
    """Find avatars similar to a CDN avatar URL, or to the stored avatar of ``did``."""
    if query:
        # Expected shape, e.g.
        # https://cdn.bsky.app/img/feed_thumbnail/plain/<did>/<cid>@jpeg,
        # which splits on "/" into exactly 8 parts.
        pts = query.split("/")

        if len(pts) != 8:
            console.print("[red]Invalid avatar URL provided[/red]")
            sys.exit(1)

        did = pts[6]
        cid = pts[7].split("@")[0]

        resp = RETINA_CLIENT.get_image_hash(did, cid)

        if resp.quality_too_low or resp.hash is None:
            console.print("[red]Hash quality too low[/red]")
            sys.exit(1)

        query_vector = binary_to_float_vector(hex_to_binary(resp.hash))
    else:
        console.print("[cyan]Looking up avatar...[/cyan]")
        avatar = QDRANT_SERVICE.get_avatar_by_did(did)

        if not avatar:
            console.print(f"[red]Avatar not found: {did}[/red]")
            sys.exit(1)

        cid = (avatar.payload or {}).get("cid")
        query_vector = avatar.vector

    console.print("[cyan]Looking up similar avatars...[/cyan]")

    results = QDRANT_SERVICE.search_similar(
        collection_name=CONFIG.qdrant_avatar_collection_name,
        query_vector=query_vector,
        limit=limit,
        score_threshold=threshold,
    )

    display_results("avatar", cid, results, more)


def _search_posts(query, limit, threshold, more):
    """Find posts similar to free text (a query is mandatory)."""
    if not query:
        console.print("[red]Must supply input for post search[/red]")
        sys.exit(1)

    EMBEDDING_SERVICE.initialize()

    query_vector = EMBEDDING_SERVICE.encode(query)

    console.print("[cyan]Looking up similar posts...[/cyan]")

    results = QDRANT_SERVICE.search_similar(
        collection_name=CONFIG.qdrant_post_collection_name,
        query_vector=query_vector,
        limit=limit,
        score_threshold=threshold,
    )

    display_results("post", query, results, more)


@main.command()
@click.argument("query", required=False)
@click.option(
    "--type",
    # Choice makes click reject unknown values with a usage error up front.
    type=click.Choice(["profile", "avatar", "post"]),
    default="profile",
    show_default=True,
)
@click.option(
    "--did",
)
@click.option(
    "--limit",
    default=10,
    show_default=True,
)
@click.option(
    "--threshold",
    default=0.7,
    type=float,
    show_default=True,
)
@click.option(
    "--more",
    is_flag=True,
    default=False,
    show_default=True,
)
def search(
    query: str,
    type: str,
    did: str,
    limit: int,
    threshold: float,
    more: bool,
):
    """Search for profiles, avatars or posts similar to QUERY.

    QUERY is free text for profile and post searches, or a Bluesky CDN avatar
    URL for avatar searches. For profile and avatar searches QUERY may be
    omitted, in which case the stored vector for --did is used instead.
    """
    # TODO: would be nice if these were flags instead
    # Defensive guard for direct (non-CLI) calls; click.Choice already
    # validates the option on the command line.
    if type not in ("profile", "avatar", "post"):
        raise click.BadParameter("invalid type", param_hint="'--type'")

    # Without either input the lookup below would be done for did=None.
    if not query and not did and type != "post":
        console.print("[red]Must supply a query or --did[/red]")
        sys.exit(1)

    QDRANT_SERVICE.initialize()

    # sys.exit raises SystemExit, which is not an Exception subclass, so the
    # helpers' deliberate exits pass through this handler untouched.
    try:
        if type == "profile":
            _search_profiles(query, did, limit, threshold, more)
        elif type == "avatar":
            _search_avatars(query, did, limit, threshold, more)
        else:
            _search_posts(query, limit, threshold, more)
    except Exception as e:
        console.print(f"[red]Error: {e}[/red]")
        logger.exception("Search error: %s", e)
        sys.exit(1)
236
237
@main.command()
@click.argument("text", required=True)
@click.option("--did")
@click.option(
    "--more",
    is_flag=True,
    default=False,
)
@click.option("--limit", default=30, show_default=True)
@click.option("--threshold", default=0.85, type=float, show_default=True)
def did_similar_posts(
    text: str,
    did: str,
    more: bool,
    limit: int = 30,
    threshold: float = 0.85,
):
    """Count --did's posts that are similar to TEXT and report their average score."""
    # A None DID would be sent to Qdrant as MatchValue(value=None).
    if not did:
        console.print("[red]Must supply --did[/red]")
        sys.exit(1)

    QDRANT_SERVICE.initialize()
    EMBEDDING_SERVICE.initialize()

    vector = EMBEDDING_SERVICE.encode(text)

    client = QDRANT_SERVICE.get_client()

    console.print(f"[cyan]Searching for [bold]{did}[/bold]'s posts...[/cyan]")

    results = client.query_points(
        collection_name=CONFIG.qdrant_post_collection_name,
        query=vector,
        query_filter=Filter(
            must=[FieldCondition(key="did", match=MatchValue(value=did))]
        ),
        limit=limit,
        score_threshold=threshold,
        with_payload=True,
    ).points

    # Avoid dividing by zero when nothing clears the threshold.
    if not results:
        console.print(
            f"[yellow]Found no similar posts from [bold]{did}[/bold].[/yellow]"
        )
        return

    avg = sum(hit.score for hit in results) / len(results)

    console.print(
        f"[green]Found [bold]{len(results)}[/bold] similar posts from [bold]{did}[/bold]. Average similarity was [bold]{avg:.4f}[/bold].[/green]"
    )

    if more:
        for hit in results:
            # Distinct name so the `text` argument is not overwritten.
            post_text = (hit.payload or {}).get("text", "")
            console.print(post_text)
            console.print()
281
282
def _pairwise_cosine_similarities(vectors):
    """Return the cosine similarity of every unordered pair of ``vectors``.

    Pairs are ordered like ``itertools.combinations(range(len(vectors)), 2)``:
    (0, 1), (0, 2), ..., (1, 2), ... A zero-length vector yields NaN entries,
    as the element-wise formula would.
    """
    matrix = np.asarray(vectors, dtype=float)
    unit = matrix / np.linalg.norm(matrix, axis=1, keepdims=True)
    gram = unit @ unit.T
    # Strict upper triangle (k=1): each pair once, no self-similarities.
    rows, cols = np.triu_indices(len(matrix), k=1)
    return gram[rows, cols]


@main.command()
@click.option("--did")
@click.option(
    "--more",
    is_flag=True,
    default=False,
)
@click.option("--limit", default=30, show_default=True)
def did_similar_recent(did: str, more: bool, limit: int = 30):
    """Measure how similar --did's posts from the last 24 hours are to each other."""
    # A None DID would be sent to Qdrant as MatchValue(value=None).
    if not did:
        console.print("[red]Must supply --did[/red]")
        sys.exit(1)

    QDRANT_SERVICE.initialize()

    client = QDRANT_SERVICE.get_client()

    day_ago = datetime.now(timezone.utc) - timedelta(days=1)

    # scroll() returns (points, next_page_offset); only the first page is used.
    results = client.scroll(
        collection_name=CONFIG.qdrant_post_collection_name,
        scroll_filter=Filter(
            must=[
                FieldCondition(key="did", match=MatchValue(value=did)),
                FieldCondition(
                    key="timestamp",
                    range=DatetimeRange(gte=day_ago),
                ),
            ]
        ),
        order_by=OrderBy(
            key="timestamp",
            direction=Direction.DESC,
        ),
        limit=limit,
        with_payload=True,
        with_vectors=True,
    )[0]

    if len(results) < 2:
        console.print(
            f"[yellow]Found only {len(results)} post(s). Need at least 2 to compare.[/yellow]"
        )
        return

    similarities = _pairwise_cosine_similarities([point.vector for point in results])

    avg_similarity = np.mean(similarities)
    min_similarity = np.min(similarities)
    max_similarity = np.max(similarities)

    console.print(
        f"[green]Found [bold]{len(results)}[/bold] posts from [bold]{did}[/bold] in the last 24h.[/green]\n"
        f"[cyan]Average pairwise similarity: [bold]{avg_similarity:.4f}[/bold][/cyan]\n"
        f"[cyan]Min: {min_similarity:.4f}, Max: {max_similarity:.4f}[/cyan]"
    )

    if more:
        console.print("\n[bold]Posts:[/bold]\n")
        for i, point in enumerate(results, 1):
            payload = point.payload or {}
            console.print(f"[dim]{i}. {payload.get('timestamp', '')}[/dim]")
            console.print(payload.get("text", ""))
            console.print()
351
352
if __name__ == "__main__":
    # Script entry point: hand the command line to the click group, which
    # dispatches to the chosen sub-command.
    main()