A small service that creates embeddings of posts, profiles, and avatars and stores them in Qdrant.
at main 9.9 kB view raw
import logging
import math
import sys
from typing import List, Optional, Tuple
from datetime import datetime, timezone, timedelta
import click
from qdrant_client.models import (
    DatetimeRange,
    Direction,
    FieldCondition,
    Filter,
    MatchValue,
    OrderBy,
)
from rich.console import Console
from rich.table import Table
from rich.panel import Panel
from rich import box
from itertools import combinations
import numpy as np

from config import CONFIG
from database import QDRANT_SERVICE, Result
from embedder import EMBEDDING_SERVICE
from retina import RETINA_CLIENT, binary_to_float_vector, hex_to_binary


logging.basicConfig(
    level=logging.INFO,
    format=logging.BASIC_FORMAT,
)

logger = logging.getLogger(__name__)

console = Console()

# Values accepted by `search --type`; each one searches its own Qdrant
# collection (see the CONFIG.qdrant_*_collection_name uses below).
SEARCH_TYPES = ("profile", "avatar", "post")


def _result_details(type: str, result: Result) -> Optional[str]:
    """Return the text for the "More" column of a single search result.

    profile -> the payload's "description"; avatar -> a CDN thumbnail URL built
    from the result's DID and the payload's "cid"; post -> the payload's "text".
    Returns None when the result has no payload or the field is absent.
    """
    payload = result.payload
    if payload is None:
        return None
    if type == "profile":
        return payload.get("description")
    if type == "avatar":
        cid = payload.get("cid")
        return f"https://cdn.bsky.app/img/feed_thumbnail/plain/{result.did}/{cid}@jpeg"
    if type == "post":
        return payload.get("text")
    return None


def display_results(
    type: str, query: str, results: List[Result], show_more: bool = False
):
    """
    Render search results as a rich table for viewing.

    Args:
        type: "profile", "avatar", or "post"; picks what the "More" column shows.
        query: The query text (or avatar CID) shown in the header panel.
        results: Hits returned by QDRANT_SERVICE.search_similar.
        show_more: Add a "More" column with the description / thumbnail URL /
            post text of each hit.
    """
    if not results:
        console.print(f"[yellow]No similar {type}s found.[/yellow]")
        return

    console.print(Panel(f"[bold blue]Query: {query}[/bold blue]", box=box.ROUNDED))
    console.print()

    table = Table(
        title=f"Found {len(results)} similar results",
        box=box.ROUNDED,
        header_style="bold magenta",
        expand=True,
        show_header=True,
        show_lines=True,
    )

    table.add_column("#", style="dim", width=4)
    table.add_column("DID", style="cyan", width=35)
    table.add_column("Similarity", justify="right", style="green", width=10)

    if show_more:
        table.add_column("More", style="white", overflow="fold")

    for idx, result in enumerate(results, 1):
        # Scores are presumably fractions (the thresholds used here are 0.7 and
        # 0.85), so scale to a percentage. Parenthesize the fallback first:
        # `score or 0.0 * 100` would never multiply a real score by 100.
        score = result.score or 0.0
        row: List[str] = [str(idx), result.did, f"{score * 100:.4f}%"]

        if show_more:
            # Always emit a cell so every row has the same number of columns.
            row.append(_result_details(type, result) or "")

        table.add_row(*row)

    console.print(table)
    console.print()


@click.group()
def main():
    """Search and compare embeddings of posts, profiles, and avatars in Qdrant."""


def _require_did(did: Optional[str]) -> None:
    """Exit with an error when a DID-based lookup is requested without --did."""
    if not did:
        console.print("[red]Provide a QUERY or --did to search from[/red]")
        sys.exit(1)


def _parse_avatar_url(url: str) -> Optional[Tuple[str, str]]:
    """Extract (did, cid) from an avatar CDN URL.

    Expects the shape
    https://cdn.bsky.app/img/feed_thumbnail/plain/<did>/<cid>@jpeg,
    which splits on "/" into exactly 8 parts. Returns None otherwise.
    """
    parts = url.split("/")
    if len(parts) != 8:
        return None
    return parts[6], parts[7].split("@")[0]


def _search_profiles(
    query: Optional[str], did: Optional[str], limit: int, threshold: float, more: bool
) -> None:
    """Search profiles similar to QUERY text, or to the stored profile of DID."""
    if query:
        EMBEDDING_SERVICE.initialize()
        description = query
        query_vector = EMBEDDING_SERVICE.encode(query)
    else:
        _require_did(did)
        console.print("[cyan]Looking up profile...[/cyan]")
        profile = QDRANT_SERVICE.get_profile_by_did(did)

        if not profile:
            console.print(f"[red]Profile not found: {did}[/red]")
            sys.exit(1)

        description = profile.payload.get("description")
        query_vector = profile.vector

        console.print("[green]Found profile[/green]")

    console.print("[cyan]Looking up similar profiles...[/cyan]")

    results = QDRANT_SERVICE.search_similar(
        collection_name=CONFIG.qdrant_profile_collection_name,
        query_vector=query_vector,
        limit=limit,
        # NOTE(review): only the profile search takes the square root of the
        # threshold; avatar and post searches pass it through unchanged.
        # Confirm this asymmetry is intended.
        score_threshold=math.sqrt(threshold),
    )

    display_results("profile", description, results, more)


def _search_avatars(
    query: Optional[str], did: Optional[str], limit: int, threshold: float, more: bool
) -> None:
    """Search avatars similar to the avatar at URL QUERY, or to DID's stored avatar."""
    if query:
        parsed = _parse_avatar_url(query)
        if parsed is None:
            console.print("[red]Invalid avatar URL provided[/red]")
            sys.exit(1)
        did, cid = parsed

        resp = RETINA_CLIENT.get_image_hash(did, cid)

        if resp.quality_too_low or resp.hash is None:
            console.print("[red]Hash quality too low[/red]")
            sys.exit(1)

        query_vector = binary_to_float_vector(hex_to_binary(resp.hash))
    else:
        _require_did(did)
        console.print("[cyan]Looking up avatar...[/cyan]")
        avatar = QDRANT_SERVICE.get_avatar_by_did(did)

        if not avatar:
            console.print(f"[red]Avatar not found: {did}[/red]")
            sys.exit(1)

        cid = avatar.payload.get("cid")
        query_vector = avatar.vector

    console.print("[cyan]Looking up similar avatars...[/cyan]")

    results = QDRANT_SERVICE.search_similar(
        collection_name=CONFIG.qdrant_avatar_collection_name,
        query_vector=query_vector,
        limit=limit,
        score_threshold=threshold,
    )

    display_results("avatar", cid, results, more)


def _search_posts(
    query: Optional[str], limit: int, threshold: float, more: bool
) -> None:
    """Search posts similar to QUERY text (a query is mandatory for posts)."""
    if not query:
        console.print("[red]Must supply input for post search[/red]")
        sys.exit(1)

    EMBEDDING_SERVICE.initialize()
    query_vector = EMBEDDING_SERVICE.encode(query)

    console.print("[cyan]Looking up similar posts...[/cyan]")

    results = QDRANT_SERVICE.search_similar(
        collection_name=CONFIG.qdrant_post_collection_name,
        query_vector=query_vector,
        limit=limit,
        score_threshold=threshold,
    )

    display_results("post", query, results, more)


@main.command()
@click.argument("query", required=False)
@click.option(
    "--type",
    type=click.Choice(SEARCH_TYPES),
    default="profile",
    show_default=True,
)
@click.option(
    "--did",
)
@click.option(
    "--limit",
    default=10,
    show_default=True,
)
@click.option(
    "--threshold",
    default=0.7,
    type=float,
    show_default=True,
)
@click.option(
    "--more",
    is_flag=True,
    default=False,
    show_default=True,
)
def search(
    query: Optional[str],
    type: str,
    did: Optional[str],
    limit: int,
    threshold: float,
    more: bool,
):
    """Find profiles, avatars, or posts similar to QUERY (or to --did's own)."""
    QDRANT_SERVICE.initialize()

    # sys.exit() raises SystemExit, which `except Exception` does not catch, so
    # the helpers' own error exits pass straight through.
    try:
        if type == "profile":
            _search_profiles(query, did, limit, threshold, more)
        elif type == "avatar":
            _search_avatars(query, did, limit, threshold, more)
        else:  # "post" — click.Choice guarantees one of SEARCH_TYPES
            _search_posts(query, limit, threshold, more)
    except Exception as e:
        console.print(f"[red]Error: {e}[/red]")
        logger.exception("Search error: %s", e)
        sys.exit(1)


@main.command()
@click.argument("text", required=True)
@click.option("--did", required=True)
@click.option(
    "--more",
    is_flag=True,
    default=False,
)
@click.option("--limit", default=30, show_default=True)
@click.option("--threshold", default=0.85, type=float, show_default=True)
def did_similar_posts(
    text: str, did: str, more: bool, limit: int = 30, threshold: float = 0.85
):
    """Count DID's posts that are similar to TEXT and report their average score."""
    QDRANT_SERVICE.initialize()
    EMBEDDING_SERVICE.initialize()

    vector = EMBEDDING_SERVICE.encode(text)

    client = QDRANT_SERVICE.get_client()

    console.print(f"[cyan]Searching for [bold]{did}[/bold]'s posts...[/cyan]")

    results = client.query_points(
        collection_name=CONFIG.qdrant_post_collection_name,
        query=vector,
        query_filter=Filter(
            must=[FieldCondition(key="did", match=MatchValue(value=did))]
        ),
        limit=limit,
        score_threshold=threshold,
        with_payload=True,
    ).points

    # Guard the average below against an empty result set.
    if not results:
        console.print(
            f"[yellow]Found no similar posts from [bold]{did}[/bold].[/yellow]"
        )
        return

    avg = sum(hit.score for hit in results) / len(results)

    console.print(
        f"[green]Found [bold]{len(results)}[/bold] similar posts from [bold]{did}[/bold]. Average similarity was [bold]{avg:.4f}[/bold].[/green]"
    )

    if more:
        for hit in results:
            console.print(hit.payload.get("text"))
            console.print()


@main.command()
@click.option("--did", required=True)
@click.option(
    "--more",
    is_flag=True,
    default=False,
)
@click.option("--limit", default=30, show_default=True)
@click.option("--hours", default=24, show_default=True)
def did_similar_recent(did: str, more: bool, limit: int = 30, hours: int = 24):
    """Report pairwise cosine similarity between DID's recent posts."""
    QDRANT_SERVICE.initialize()

    client = QDRANT_SERVICE.get_client()

    window_start = datetime.now(timezone.utc) - timedelta(hours=hours)

    # scroll() returns (points, next_page_offset); only the first page is used.
    results = client.scroll(
        collection_name=CONFIG.qdrant_post_collection_name,
        scroll_filter=Filter(
            must=[
                FieldCondition(key="did", match=MatchValue(value=did)),
                FieldCondition(
                    key="timestamp",
                    range=DatetimeRange(gte=window_start),
                ),
            ]
        ),
        order_by=OrderBy(
            key="timestamp",
            direction=Direction.DESC,
        ),
        limit=limit,
        with_payload=True,
        with_vectors=True,
    )[0]

    if len(results) < 2:
        console.print(
            f"[yellow]Found only {len(results)} post(s). Need at least 2 to compare.[/yellow]"
        )
        return

    # Normalize each embedding once so every pairwise cosine similarity is a
    # plain dot product, instead of recomputing both norms for every pair.
    vectors = [np.asarray(point.vector, dtype=float) for point in results]
    unit_vectors = [vec / np.linalg.norm(vec) for vec in vectors]
    similarities = [float(np.dot(a, b)) for a, b in combinations(unit_vectors, 2)]

    avg_similarity = np.mean(similarities)
    min_similarity = np.min(similarities)
    max_similarity = np.max(similarities)

    console.print(
        f"[green]Found [bold]{len(results)}[/bold] posts from [bold]{did}[/bold] in the last {hours}h.[/green]\n"
        f"[cyan]Average pairwise similarity: [bold]{avg_similarity:.4f}[/bold][/cyan]\n"
        f"[cyan]Min: {min_similarity:.4f}, Max: {max_similarity:.4f}[/cyan]"
    )

    if more:
        console.print("\n[bold]Posts:[/bold]\n")
        for i, point in enumerate(results, 1):
            post_text = point.payload.get("text", "")
            timestamp = point.payload.get("timestamp", "")
            console.print(f"[dim]{i}. {timestamp}[/dim]")
            console.print(post_text)
            console.print()


if __name__ == "__main__":
    main()