
Compare changes


.gitignore (+2)

```diff
···
 # Virtual environments
 .venv
+
+*.pkl
```
build_graph.py (+84)

```python
import logging
import pickle
from typing import Dict, Optional, Set

import click

from config import CONFIG
from indexer import FollowIndexer


logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)

logger = logging.getLogger(__name__)


@click.command()
@click.option("--ch-host")
@click.option("--ch-port", type=int)
@click.option("--ch-user")
@click.option("--ch-pass")
def main(
    ch_host: Optional[str],
    ch_port: Optional[int],
    ch_user: Optional[str],
    ch_pass: Optional[str],
):
    logger.info("Building follow graph...")

    indexer = FollowIndexer(
        clickhouse_host=ch_host or CONFIG.clickhouse_host,
        clickhouse_port=ch_port or CONFIG.clickhouse_port,
        clickhouse_user=ch_user or CONFIG.clickhouse_user,
        clickhouse_pass=ch_pass or CONFIG.clickhouse_pass,
        batch_size=1000,
    )

    # did -> set of DIDs that this account follows
    graph: Dict[str, Set[str]] = {}

    def build_graph(did: str, subject: str):
        if did not in graph:
            graph[did] = set()

        graph[did].add(subject)

    indexer.stream_follows(build_graph)

    prox_map = {}

    for did in graph:
        first = graph.get(did, set())

        second: Set[str] = set()
        for subject in first:
            second.update(graph.get(subject, set()))

        prox_map[did] = {
            "hop1": first,
            "hop2": second - first - {did},
        }

    with open("prox_map.pkl", "wb") as f:
        pickle.dump(prox_map, f)

    logger.info(
        f"Finished building proximity map, saved to prox_map.pkl. {len(prox_map):,} users in map."
    )


if __name__ == "__main__":
    main()
```
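For reference, a minimal sketch of how the resulting pickle might be consumed, assuming the structure written above (a dict mapping each DID to its `hop1`/`hop2` sets); the DID used here is a placeholder:

```python
import pickle

# Load the proximity map produced by build_graph.py.
with open("prox_map.pkl", "rb") as f:
    prox_map = pickle.load(f)

did = "did:plc:example"  # hypothetical DID
entry = prox_map.get(did, {"hop1": set(), "hop2": set()})
print(f"{did}: {len(entry['hop1'])} direct follows, {len(entry['hop2'])} second-hop accounts")
```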
config.py (+1 -1)

```diff
···
     kafka_bootstrap_servers: List[str] = ["localhost:9092"]
     kafka_input_topic: str = "tap-events"
     kafka_group_id: str = "followgrap-indexer"
-    kafka_auto_offset_reset: str = "latest"
+    kafka_auto_offset_reset: str = "earliest"
 
     metrics_port: int = 8500
     metrics_host: str = "0.0.0.0"
```
indexer.py (+114 -55)

```diff
···
 from datetime import datetime
 from threading import Lock
 from time import time
-from typing import Any, List, Optional
+from typing import Any, Callable, List, Optional
 
 import click
 from aiokafka import AIOKafkaConsumer, ConsumerRecord
···
 
 from config import CONFIG
 from metrics import prom_metrics
-from models import AtKafkaEvent, Follow, FollowRecord, Unfollow
+from models import Follow, FollowRecord, TapEvent, Unfollow
 
 logging.basicConfig(
-    level=logging.INFO,
-    format=logging.BASIC_FORMAT,
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 )
 
 logger = logging.getLogger(__name__)
···
         prom_metrics.insert_duration.labels(kind="follow", status=status).observe(
             time() - start_time
         )
+        prom_metrics.inserted.labels(kind="follow", status=status).inc(len(follows))
 
     def insert_unfollow(self, unfollow: Unfollow):
         to_insert: Optional[List[Unfollow]] = None
···
         prom_metrics.insert_duration.labels(kind="unfollow", status=status).observe(
             time() - start_time
         )
+        prom_metrics.inserted.labels(kind="unfollow", status=status).inc(
+            len(unfollows)
+        )
 
     def flush_all(self):
         with self._follow_lock:
···
             self._unfollow_batch = []
             self._flush_unfollows(batch_to_flush)
 
+    def stream_follows(self, cb: Callable[[str, str], None], batch_size: int = 100_000):
+        query = """
+        SELECT f.did, f.subject
+        FROM follows f
+        LEFT ANTI JOIN unfollows u ON f.uri = u.uri
+        """
+
+        try:
+            with self.client.query_row_block_stream(
+                query, settings={"max_block_size": batch_size}
+            ) as stream:
+                total_handled = 0
+                for block in stream:
+                    for row in block:
+                        cb(row[0], row[1])
+                        total_handled += 1
+
+                        if total_handled % 1_000_000 == 0:
+                            logger.info(f"Handled {total_handled:,} follows so far")
+                logger.info(f"Finished streaming {total_handled:,} follows")
+        except Exception as e:
+            logger.error(f"Error streaming follows: {e}")
+
 
 class Consumer:
     def __init__(
···
         bootstrap_servers: List[str],
         input_topic: str,
         group_id: str,
+        max_concurrent_tasks: int = 100,
     ):
         self.indexer = indexer
         self.bootstrap_servers = bootstrap_servers
         self.input_topic = input_topic
         self.group_id = group_id
+        self.max_concurrent_tasks = max_concurrent_tasks
         self.consumer: Optional[AIOKafkaConsumer] = None
         self._flush_task: Optional[asyncio.Task[Any]] = None
+        self._semaphore: Optional[asyncio.Semaphore] = None
+        self._shutdown_event: Optional[asyncio.Event] = None
 
     async def stop(self):
+        if self._shutdown_event:
+            self._shutdown_event.set()
+
         if self._flush_task:
             self._flush_task.cancel()
             try:
···
         kind = "unk"
 
         try:
-            evt = AtKafkaEvent.model_validate(message.value)
+            evt = TapEvent.model_validate(message.value)
 
-            if not evt.operation or evt.operation.collection != "app.bsky.graph.follow":
+            if not evt.record or evt.record.collection != "app.bsky.graph.follow":
                 kind = "skipped"
                 status = "ok"
                 return
 
-            op = evt.operation
+            op = evt.record
+            uri = f"at://{op.did}/{op.collection}/{op.rkey}"
 
             if op.action == "update":
                 kind = "update"
···
                 rec = FollowRecord.model_validate(op.record)
                 created_at = isoparse(rec.created_at)
                 follow = Follow(
-                    uri=op.uri, did=evt.did, subject=rec.subject, created_at=created_at
+                    uri=uri, did=op.did, subject=rec.subject, created_at=created_at
                 )
                 self.indexer.insert_follow(follow)
             else:
                 kind = "delete"
-                unfollow = Unfollow(uri=op.uri, created_at=datetime.now())
+
+                unfollow = Unfollow(uri=uri, created_at=datetime.now())
+
                 self.indexer.insert_unfollow(unfollow)
 
             status = "ok"
         except Exception as e:
             logger.error(f"Failed to handle event: {e}")
         finally:
-            prom_metrics.events_handled.labels(kind=kind, status=status)
+            prom_metrics.events_handled.labels(kind=kind, status=status).inc()
+
+    async def _handle_event_with_semaphore(self, message: ConsumerRecord[Any, Any]):
+        assert self._semaphore is not None
+        async with self._semaphore:
+            await self._handle_event(message)
 
     async def run(self):
+        self._semaphore = asyncio.Semaphore(self.max_concurrent_tasks)
+        self._shutdown_event = asyncio.Event()
+
         self.consumer = AIOKafkaConsumer(
             self.input_topic,
             bootstrap_servers=",".join(self.bootstrap_servers),
···
         )
         await self.consumer.start()
         logger.info(
-            f"Started Kafak consumer for topic: {self.bootstrap_servers}, {self.input_topic}"
+            f"Started Kafka consumer for topic: {self.bootstrap_servers}, {self.input_topic}"
         )
 
-        if not self.consumer:
-            raise RuntimeError("Consumer not started, call start() first.")
+        self._flush_task = asyncio.create_task(self._periodic_flush())
+
+        pending_tasks: set[asyncio.Task[Any]] = set()
 
         try:
             async for message in self.consumer:
-                asyncio.ensure_future(self._handle_event(message))
                 prom_metrics.events_received.inc()
+
+                task = asyncio.create_task(self._handle_event_with_semaphore(message))
+                pending_tasks.add(task)
+                task.add_done_callback(pending_tasks.discard)
+
+                if len(pending_tasks) >= self.max_concurrent_tasks * 2:
+                    done, pending_tasks_set = await asyncio.wait(
+                        pending_tasks, timeout=0, return_when=asyncio.FIRST_COMPLETED
+                    )
+                    pending_tasks = pending_tasks_set
+                    for t in done:
+                        if t.exception():
+                            logger.error(f"Task failed with exception: {t.exception()}")
+
         except Exception as e:
             logger.error(f"Error consuming messages: {e}")
             raise
         finally:
+            if pending_tasks:
+                logger.info(
+                    f"Waiting for {len(pending_tasks)} pending tasks to complete..."
+                )
+                await asyncio.gather(*pending_tasks, return_exceptions=True)
             self.indexer.flush_all()
 
 
 @click.command()
-@click.option(
-    "--ch-host",
-)
-@click.option(
-    "--ch-port",
-    type=int,
-)
-@click.option(
-    "--ch-user",
-)
-@click.option(
-    "--ch-pass",
-)
-@click.option(
-    "--batch-size",
-    type=int,
-)
-@click.option(
-    "--bootstrap-servers",
-    type=List[str],
-)
+@click.option("--ch-host")
+@click.option("--ch-port", type=int)
+@click.option("--ch-user")
+@click.option("--ch-pass")
+@click.option("--batch-size", type=int)
 @click.option(
-    "--input-topic",
-)
-@click.option(
-    "--group-id",
-)
-@click.option(
-    "--metrics-host",
-)
-@click.option(
-    "--metrics-port",
-    type=int,
+    "--bootstrap-servers", help="Comma-separated list of Kafka bootstrap servers"
 )
+@click.option("--input-topic")
+@click.option("--group-id")
+@click.option("--metrics-host")
+@click.option("--metrics-port", type=int)
 def main(
     ch_host: Optional[str],
     ch_port: Optional[int],
     ch_user: Optional[str],
     ch_pass: Optional[str],
     batch_size: Optional[int],
-    bootstrap_servers: Optional[List[str]],
+    bootstrap_servers: Optional[str],
     input_topic: Optional[str],
     group_id: Optional[str],
     metrics_host: Optional[str],
···
     )
     indexer.init_schema()
 
+    kafka_servers = (
+        bootstrap_servers.split(",")
+        if bootstrap_servers
+        else CONFIG.kafka_bootstrap_servers
+    )
+
     consumer = Consumer(
         indexer=indexer,
-        bootstrap_servers=bootstrap_servers or CONFIG.kafka_bootstrap_servers,
+        bootstrap_servers=kafka_servers,
         input_topic=input_topic or CONFIG.kafka_input_topic,
         group_id=group_id or CONFIG.kafka_group_id,
     )
 
-    try:
-        asyncio.run(consumer.run())
-    except KeyboardInterrupt:
-        logger.info("Shutting down...")
-    finally:
-        asyncio.run(consumer.stop())
+    async def run_with_shutdown():
+        loop = asyncio.get_event_loop()
+
+        import signal
+
+        def handle_signal():
+            logger.info("Received shutdown signal...")
+            asyncio.create_task(consumer.stop())
+
+        for sig in (signal.SIGTERM, signal.SIGINT):
+            loop.add_signal_handler(sig, handle_signal)
 
-    pass
+        try:
+            await consumer.run()
+        except asyncio.CancelledError:
+            pass
+        finally:
+            await consumer.stop()
+
+    asyncio.run(run_with_shutdown())
 
 
 if __name__ == "__main__":
```
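As a standalone illustration of the bounded-concurrency pattern `Consumer.run` now uses (semaphore-gated tasks tracked in a set and drained before shutdown), here is a minimal sketch; the message source and per-message work are placeholders:

```python
import asyncio

async def demo(messages, max_concurrent_tasks: int = 100):
    # Cap in-flight handlers with a semaphore, mirroring _handle_event_with_semaphore.
    semaphore = asyncio.Semaphore(max_concurrent_tasks)
    pending: set[asyncio.Task] = set()

    async def handle(msg):
        async with semaphore:
            await asyncio.sleep(0)  # stand-in for real per-message work

    for msg in messages:
        task = asyncio.create_task(handle(msg))
        pending.add(task)
        # Completed tasks remove themselves from the tracking set.
        task.add_done_callback(pending.discard)

    # Drain remaining work, mirroring the finally block in Consumer.run.
    if pending:
        await asyncio.gather(*pending, return_exceptions=True)

asyncio.run(demo(range(10)))
```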
metrics.py (+8 -1)

```diff
···
         )
 
         self.insert_duration = Histogram(
-            name="embedding_duration_seconds",
+            name="insert_duration_seconds",
             namespace=NAMESPACE,
             buckets=(
                 0.001,
···
             ),
             labelnames=["kind", "status"],
             documentation="Time taken to insert a batch",
+        )
+
+        self.inserted = Counter(
+            name="inserted",
+            namespace=NAMESPACE,
+            documentation="Number of items inserted",
+            labelnames=["kind", "status"],
         )
 
         self._initialized = True
```
models.py (+122 -69)

```diff
···
+import base64
+
 from datetime import datetime
-from typing import Any, Dict, Optional, List
+import logging
+from typing import Any, Dict, Optional
 
-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator
 
+logger = logging.getLogger(__name__)
 
-class AtKafkaOp(BaseModel):
-    action: str
+
+class RecordEvent(BaseModel):
+    """
+    A model for record events that come from Tap, in Kafka mode
+    """
+
+    live: Optional[bool] = False
+    did: str
+    rev: str
     collection: str
     rkey: str
-    uri: str
-    cid: str
-    path: str
-    record: Optional[Dict[str, Any]]
+    action: str
+    record: Optional[Dict[str, Any]] = None
+    cid: Optional[str] = None
 
 
-class AtKafkaIdentity(BaseModel):
-    seq: int
-    handle: str
+class IdentityEvent(BaseModel):
+    """
+    A model for identity events that come from Tap, in Kafka mode
+    """
 
+    live: Optional[bool] = False
+    handle: Optional[str]
+    is_active: bool
+    status: str
 
-class AtKafkaInfo(BaseModel):
-    name: str
-    message: Optional[str] = None
 
+class TapEvent(BaseModel):
+    """
+    The base model for events that come from Tap, in Kafka mode
+    """
 
-class AtKafkaAccount(BaseModel):
-    active: bool
-    seq: int
-    status: Optional[str] = None
+    id: int
+    type: str
+    record: Optional[RecordEvent] = None
+    identity: Optional[IdentityEvent] = None
 
+    @field_validator("record", "identity", mode="before")
+    @classmethod
+    def decode_base64(cls, v: Any):
+        if v is not None and isinstance(v, str):
+            try:
+                return base64.b64decode(v).decode("utf-8")
+            except Exception as e:
+                logger.error(f"Error decoding event base64: {e}")
+                return v
+        return v
 
-class DIDDocument(BaseModel):
-    context: Optional[List[Any]] = Field(None, alias="@context")
-    id: Optional[str] = None
-    also_known_as: Optional[List[Any]] = Field(None, alias="alsoKnownAs")
-    verification_method: Optional[List[Any]] = Field(None, alias="verificationMethod")
-    service: Optional[List[Any]] = None
-
-    class Config:
-        populate_by_name = True
-
-
-class ProfileViewDetailed(BaseModel):
-    did: str
-    handle: str
-    display_name: Optional[str] = Field(None, alias="displayName")
-    description: Optional[str] = None
-    avatar: Optional[str] = None
-    banner: Optional[str] = None
-    followers_count: Optional[int] = Field(None, alias="followersCount")
-    follows_count: Optional[int] = Field(None, alias="followsCount")
-    posts_count: Optional[int] = Field(None, alias="postsCount")
-    indexed_at: Optional[str] = Field(None, alias="indexedAt")
-    viewer: Optional[Dict[str, Any]] = None
-    labels: Optional[List[Any]] = None
-
-    class Config:
-        populate_by_name = True
-
-
-class EventMetadata(BaseModel):
-    did_document: Optional[DIDDocument] = Field(None, alias="didDocument")
-    pds_host: Optional[str] = Field(None, alias="pdsHost")
-    handle: Optional[str] = None
-    did_created_at: Optional[str] = Field(None, alias="didCreatedAt")
-    account_age: Optional[int] = Field(None, alias="accountAge")
-    profile: Optional[ProfileViewDetailed] = None
-
-    class Config:
-        populate_by_name = True
-
-
-class AtKafkaEvent(BaseModel):
-    did: str
-    timestamp: str
-    metadata: Optional[EventMetadata] = Field(None, alias="eventMetadata")
-    operation: Optional[AtKafkaOp] = None
-    account: Optional[AtKafkaAccount] = None
-    identity: Optional[AtKafkaIdentity] = None
-    info: Optional[AtKafkaInfo] = None
-
-    class Config:
-        populate_by_name = True
 
 
 class FollowRecord(BaseModel):
```
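A minimal sketch of parsing one Tap record event with the new models; the payload values (DIDs, rkey, the `type` string, and the inner record keys) are illustrative assumptions, not taken from the Tap wire format:

```python
from models import TapEvent

# Hypothetical payload shaped after the TapEvent/RecordEvent fields above;
# all concrete values here are made up for illustration.
sample = {
    "id": 1,
    "type": "record",  # assumed type value
    "record": {
        "did": "did:plc:alice",  # placeholder DIDs
        "rev": "3kexamplerev",
        "collection": "app.bsky.graph.follow",
        "rkey": "3kexamplerkey",
        "action": "create",
        "record": {"subject": "did:plc:bob", "createdAt": "2024-01-01T00:00:00Z"},
    },
}

evt = TapEvent.model_validate(sample)
op = evt.record
assert op is not None and op.collection == "app.bsky.graph.follow"
print(f"at://{op.did}/{op.collection}/{op.rkey}")
```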