build_graph.py (new file)
···
+import logging
+import pickle
+from typing import Dict, Optional, Set
+
+import click
+
+from config import CONFIG
+from indexer import FollowIndexer
+
+
+logging.basicConfig(
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+
+logger = logging.getLogger(__name__)
+
+
+@click.command()
+@click.option("--ch-host")
+@click.option("--ch-port", type=int)
+@click.option("--ch-user")
+@click.option("--ch-pass")
+def main(
+    ch_host: Optional[str],
+    ch_port: Optional[int],
+    ch_user: Optional[str],
+    ch_pass: Optional[str],
+):
+    logger.info("Building follow graph...")
+
+    indexer = FollowIndexer(
+        clickhouse_host=ch_host or CONFIG.clickhouse_host,
+        clickhouse_port=ch_port or CONFIG.clickhouse_port,
+        clickhouse_user=ch_user or CONFIG.clickhouse_user,
+        clickhouse_pass=ch_pass or CONFIG.clickhouse_pass,
+        batch_size=1000,
+    )
+
+    # did -> set of DIDs it follows (first-hop adjacency).
+    graph: Dict[str, Set[str]] = {}
+
+    def build_graph(did: str, subject: str):
+        if did not in graph:
+            graph[did] = set()
+
+        graph[did].add(subject)
+
+    indexer.stream_follows(build_graph)
+
+    prox_map = {}
+
+    for did in graph:
+        first = graph.get(did, set())
+
+        # Second hop: accounts followed by this DID's follows, minus the
+        # first hop and the DID itself.
+        second: Set[str] = set()
+        for subject in first:
+            second.update(graph.get(subject, set()))
+
+        prox_map[did] = {
+            "hop1": first,
+            "hop2": second - first - {did},
+        }
+
+    with open("prox_map.pkl", "wb") as f:
+        pickle.dump(prox_map, f)
+
+    logger.info(
+        f"Finished building proximity map, saved to prox_map.pkl. {len(prox_map):,} users in map."
+    )
+
+
+if __name__ == "__main__":
+    main()
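
For reference, the two-hop expansion in `build_graph.py` amounts to one set-union pass over each DID's first-hop follows. A minimal sketch with a hypothetical three-user graph (the DIDs are invented for illustration):

```python
from typing import Dict, Set

# Hypothetical follow graph: did -> set of followed DIDs.
graph: Dict[str, Set[str]] = {
    "did:alice": {"did:bob"},
    "did:bob": {"did:carol", "did:alice"},
    "did:carol": set(),
}

prox_map = {}
for did, first in graph.items():
    second: Set[str] = set()
    for subject in first:
        second.update(graph.get(subject, set()))
    prox_map[did] = {"hop1": first, "hop2": second - first - {did}}

# did:alice -> hop1 {did:bob}, hop2 {did:carol}; alice herself is excluded.
print(prox_map["did:alice"])
```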
config.py (+1 −1)
···
     kafka_bootstrap_servers: List[str] = ["localhost:9092"]
     kafka_input_topic: str = "tap-events"
     kafka_group_id: str = "followgrap-indexer"
-    kafka_auto_offset_reset: str = "latest"
+    kafka_auto_offset_reset: str = "earliest"

     metrics_port: int = 8500
     metrics_host: str = "0.0.0.0"
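
Switching the default `kafka_auto_offset_reset` to `"earliest"` only matters for a consumer group with no committed offsets: a fresh group now replays the retained topic history instead of starting at the tail. The consumer construction is not visible in these hunks, so the wiring below is an assumption about how the setting would reach aiokafka:

```python
from aiokafka import AIOKafkaConsumer

from config import CONFIG

# Sketch only: assumes the Consumer passes the config value straight through.
consumer = AIOKafkaConsumer(
    CONFIG.kafka_input_topic,
    bootstrap_servers=",".join(CONFIG.kafka_bootstrap_servers),
    group_id=CONFIG.kafka_group_id,
    auto_offset_reset=CONFIG.kafka_auto_offset_reset,  # "earliest"
)
```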
indexer.py (+114 −55)
···
 from datetime import datetime
 from threading import Lock
 from time import time
-from typing import Any, List, Optional
+from typing import Any, Callable, List, Optional

 import click
 from aiokafka import AIOKafkaConsumer, ConsumerRecord
···
 from config import CONFIG
 from metrics import prom_metrics
-from models import AtKafkaEvent, Follow, FollowRecord, Unfollow
+from models import Follow, FollowRecord, TapEvent, Unfollow

 logging.basicConfig(
-    level=logging.INFO,
-    format=logging.BASIC_FORMAT,
+    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
 )

 logger = logging.getLogger(__name__)
···
         prom_metrics.insert_duration.labels(kind="follow", status=status).observe(
             time() - start_time
         )
+        prom_metrics.inserted.labels(kind="follow", status=status).inc(len(follows))

     def insert_unfollow(self, unfollow: Unfollow):
         to_insert: Optional[List[Unfollow]] = None
···
         prom_metrics.insert_duration.labels(kind="unfollow", status=status).observe(
             time() - start_time
         )
+        prom_metrics.inserted.labels(kind="unfollow", status=status).inc(
+            len(unfollows)
+        )

     def flush_all(self):
         with self._follow_lock:
···
             self._unfollow_batch = []
         self._flush_unfollows(batch_to_flush)

+    def stream_follows(self, cb: Callable[[str, str], None], batch_size: int = 100_000):
+        query = """
+        SELECT f.did, f.subject
+        FROM follows f
+        LEFT ANTI JOIN unfollows u ON f.uri = u.uri
+        """
+
+        try:
+            with self.client.query_row_block_stream(
+                query, settings={"max_block_size": batch_size}
+            ) as stream:
+                total_handled = 0
+                for block in stream:
+                    for row in block:
+                        cb(row[0], row[1])
+                        total_handled += 1
+
+                        if total_handled % 1_000_000 == 0:
+                            logger.info(f"Handled {total_handled:,} follows so far")
+                logger.info(f"Finished streaming {total_handled:,} follows")
+        except Exception as e:
+            logger.error(f"Error streaming follows: {e}")
+

 class Consumer:
     def __init__(
···
         bootstrap_servers: List[str],
         input_topic: str,
         group_id: str,
+        max_concurrent_tasks: int = 100,
     ):
         self.indexer = indexer
         self.bootstrap_servers = bootstrap_servers
         self.input_topic = input_topic
         self.group_id = group_id
+        self.max_concurrent_tasks = max_concurrent_tasks
         self.consumer: Optional[AIOKafkaConsumer] = None
         self._flush_task: Optional[asyncio.Task[Any]] = None
+        self._semaphore: Optional[asyncio.Semaphore] = None
+        self._shutdown_event: Optional[asyncio.Event] = None

     async def stop(self):
+        if self._shutdown_event:
+            self._shutdown_event.set()
+
         if self._flush_task:
             self._flush_task.cancel()
             try:
···
         kind = "unk"

         try:
-            evt = AtKafkaEvent.model_validate(message.value)
+            evt = TapEvent.model_validate(message.value)

-            if not evt.operation or evt.operation.collection != "app.bsky.graph.follow":
+            if not evt.record or evt.record.collection != "app.bsky.graph.follow":
                 kind = "skipped"
                 status = "ok"
                 return

-            op = evt.operation
+            op = evt.record
+            uri = f"at://{op.did}/{op.collection}/{op.rkey}"

             if op.action == "update":
                 kind = "update"
···
                 rec = FollowRecord.model_validate(op.record)
                 created_at = isoparse(rec.created_at)
                 follow = Follow(
-                    uri=op.uri, did=evt.did, subject=rec.subject, created_at=created_at
+                    uri=uri, did=op.did, subject=rec.subject, created_at=created_at
                 )
                 self.indexer.insert_follow(follow)
             else:
                 kind = "delete"
-                unfollow = Unfollow(uri=op.uri, created_at=datetime.now())
+
+                unfollow = Unfollow(uri=uri, created_at=datetime.now())
+
                 self.indexer.insert_unfollow(unfollow)

             status = "ok"
         except Exception as e:
             logger.error(f"Failed to handle event: {e}")
         finally:
-            prom_metrics.events_handled.labels(kind=kind, status=status)
+            prom_metrics.events_handled.labels(kind=kind, status=status).inc()
+
+    async def _handle_event_with_semaphore(self, message: ConsumerRecord[Any, Any]):
+        assert self._semaphore is not None
+        async with self._semaphore:
+            await self._handle_event(message)

     async def run(self):
+        self._semaphore = asyncio.Semaphore(self.max_concurrent_tasks)
+        self._shutdown_event = asyncio.Event()
+
         self.consumer = AIOKafkaConsumer(
             self.input_topic,
             bootstrap_servers=",".join(self.bootstrap_servers),
···
         )
         await self.consumer.start()
         logger.info(
-            f"Started Kafak consumer for topic: {self.bootstrap_servers}, {self.input_topic}"
+            f"Started Kafka consumer for topic: {self.bootstrap_servers}, {self.input_topic}"
         )

-        if not self.consumer:
-            raise RuntimeError("Consumer not started, call start() first.")
+        self._flush_task = asyncio.create_task(self._periodic_flush())
+
+        pending_tasks: set[asyncio.Task[Any]] = set()

         try:
             async for message in self.consumer:
-                asyncio.ensure_future(self._handle_event(message))
                 prom_metrics.events_received.inc()
+
+                task = asyncio.create_task(self._handle_event_with_semaphore(message))
+                pending_tasks.add(task)
+                task.add_done_callback(pending_tasks.discard)
+
+                if len(pending_tasks) >= self.max_concurrent_tasks * 2:
+                    done, pending_tasks_set = await asyncio.wait(
+                        pending_tasks, timeout=0, return_when=asyncio.FIRST_COMPLETED
+                    )
+                    pending_tasks = pending_tasks_set
+                    for t in done:
+                        if t.exception():
+                            logger.error(f"Task failed with exception: {t.exception()}")
+
         except Exception as e:
             logger.error(f"Error consuming messages: {e}")
             raise
         finally:
+            if pending_tasks:
+                logger.info(
+                    f"Waiting for {len(pending_tasks)} pending tasks to complete..."
+                )
+                await asyncio.gather(*pending_tasks, return_exceptions=True)
             self.indexer.flush_all()


 @click.command()
-@click.option(
-    "--ch-host",
-)
-@click.option(
-    "--ch-port",
-    type=int,
-)
-@click.option(
-    "--ch-user",
-)
-@click.option(
-    "--ch-pass",
-)
-@click.option(
-    "--batch-size",
-    type=int,
-)
-@click.option(
-    "--bootstrap-servers",
-    type=List[str],
-)
+@click.option("--ch-host")
+@click.option("--ch-port", type=int)
+@click.option("--ch-user")
+@click.option("--ch-pass")
+@click.option("--batch-size", type=int)
 @click.option(
-    "--input-topic",
-)
-@click.option(
-    "--group-id",
-)
-@click.option(
-    "--metrics-host",
-)
-@click.option(
-    "--metrics-port",
-    type=int,
+    "--bootstrap-servers", help="Comma-separated list of Kafka bootstrap servers"
 )
+@click.option("--input-topic")
+@click.option("--group-id")
+@click.option("--metrics-host")
+@click.option("--metrics-port", type=int)
 def main(
     ch_host: Optional[str],
     ch_port: Optional[int],
     ch_user: Optional[str],
     ch_pass: Optional[str],
     batch_size: Optional[int],
-    bootstrap_servers: Optional[List[str]],
+    bootstrap_servers: Optional[str],
     input_topic: Optional[str],
     group_id: Optional[str],
     metrics_host: Optional[str],
···
     )
     indexer.init_schema()

+    kafka_servers = (
+        bootstrap_servers.split(",")
+        if bootstrap_servers
+        else CONFIG.kafka_bootstrap_servers
+    )
+
     consumer = Consumer(
         indexer=indexer,
-        bootstrap_servers=bootstrap_servers or CONFIG.kafka_bootstrap_servers,
+        bootstrap_servers=kafka_servers,
         input_topic=input_topic or CONFIG.kafka_input_topic,
         group_id=group_id or CONFIG.kafka_group_id,
     )

-    try:
-        asyncio.run(consumer.run())
-    except KeyboardInterrupt:
-        logger.info("Shutting down...")
-    finally:
-        asyncio.run(consumer.stop())
+    async def run_with_shutdown():
+        loop = asyncio.get_event_loop()
+
+        import signal
+
+        def handle_signal():
+            logger.info("Received shutdown signal...")
+            asyncio.create_task(consumer.stop())
+
+        for sig in (signal.SIGTERM, signal.SIGINT):
+            loop.add_signal_handler(sig, handle_signal)

-    pass
+        try:
+            await consumer.run()
+        except asyncio.CancelledError:
+            pass
+        finally:
+            await consumer.stop()
+
+    asyncio.run(run_with_shutdown())


 if __name__ == "__main__":
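
The consumer changes replace the fire-and-forget `asyncio.ensure_future` calls with a bounded pattern: every handler runs inside a shared `asyncio.Semaphore`, tasks are tracked in a set so they can be awaited on shutdown, and a done-callback drops them once finished. A stripped-down sketch of that pattern outside the Kafka context (the `work` coroutine and loop bounds are made up for illustration):

```python
import asyncio


async def main(limit: int = 100) -> None:
    semaphore = asyncio.Semaphore(limit)
    pending: set[asyncio.Task[None]] = set()

    async def work(item: int) -> None:
        # At most `limit` bodies run concurrently.
        async with semaphore:
            await asyncio.sleep(0.01)

    for item in range(1_000):  # stand-in for `async for message in self.consumer`
        task = asyncio.create_task(work(item))
        pending.add(task)
        task.add_done_callback(pending.discard)

    # On shutdown, wait for whatever is still in flight.
    await asyncio.gather(*pending, return_exceptions=True)


asyncio.run(main())
```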
metrics.py (+8 −1)
···
         )

         self.insert_duration = Histogram(
-            name="embedding_duration_seconds",
+            name="insert_duration_seconds",
             namespace=NAMESPACE,
             buckets=(
                 0.001,
···
             ),
             labelnames=["kind", "status"],
             documentation="Time taken to insert a batch",
+        )
+
+        self.inserted = Counter(
+            name="inserted",
+            namespace=NAMESPACE,
+            documentation="Number of items inserted",
+            labelnames=["kind", "status"],
         )

         self._initialized = True
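
The new `inserted` counter adds per-batch insert volume alongside the existing latency histogram. With `prometheus_client`, a labelled Counter under a namespace is exposed with a `_total` suffix; a small sketch (the `followgraph` namespace is a stand-in for the real `NAMESPACE` constant):

```python
from prometheus_client import Counter

inserted = Counter(
    name="inserted",
    namespace="followgraph",  # stand-in; metrics.py uses NAMESPACE
    documentation="Number of items inserted",
    labelnames=["kind", "status"],
)

inserted.labels(kind="follow", status="ok").inc(500)
# Scrape output: followgraph_inserted_total{kind="follow",status="ok"} 500.0
```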
models.py (+122 −69)
···
+import base64
+
 from datetime import datetime
-from typing import Any, Dict, Optional, List
+import logging
+from typing import Any, Dict, Optional

-from pydantic import BaseModel, Field
+from pydantic import BaseModel, Field, field_validator

+logger = logging.getLogger(__name__)

-class AtKafkaOp(BaseModel):
-    action: str
+
+class RecordEvent(BaseModel):
+    """
+    A model for record events that come from Tap, in Kafka mode
+    """
+
+    live: Optional[bool] = False
+    did: str
+    rev: str
     collection: str
     rkey: str
-    uri: str
-    cid: str
-    path: str
-    record: Optional[Dict[str, Any]]
+    action: str
+    record: Optional[Dict[str, Any]] = None
+    cid: Optional[str] = None


-class AtKafkaIdentity(BaseModel):
-    seq: int
-    handle: str
+class IdentityEvent(BaseModel):
+    """
+    A model for identity events that come from Tap, in Kafka mode
+    """

+    live: Optional[bool] = False
+    handle: Optional[str]
+    is_active: bool
+    status: str

-class AtKafkaInfo(BaseModel):
-    name: str
-    message: Optional[str] = None

+class TapEvent(BaseModel):
+    """
+    The base model for events that come from Tap, in Kafka mode
+    """

-class AtKafkaAccount(BaseModel):
-    active: bool
-    seq: int
-    status: Optional[str] = None
+    id: int
+    type: str
+    record: Optional[RecordEvent] = None
+    identity: Optional[IdentityEvent] = None

+    @field_validator("record", "identity", mode="before")
+    @classmethod
+    def decode_base64(cls, v: Any):
+        if v is not None and isinstance(v, str):
+            try:
+                return base64.b64decode(v).decode("utf-8")
+            except Exception as e:
+                logger.error(f"Error decoding event base64: {e}")
+                return v
+        return v

-class DIDDocument(BaseModel):
-    context: Optional[List[Any]] = Field(None, alias="@context")
-    id: Optional[str] = None
-    also_known_as: Optional[List[Any]] = Field(None, alias="alsoKnownAs")
-    verification_method: Optional[List[Any]] = Field(None, alias="verificationMethod")
-    service: Optional[List[Any]] = None

-    class Config:
-        populate_by_name = True
-
-
-class ProfileViewDetailed(BaseModel):
-    did: str
-    handle: str
-    display_name: Optional[str] = Field(None, alias="displayName")
-    description: Optional[str] = None
-    avatar: Optional[str] = None
-    banner: Optional[str] = None
-    followers_count: Optional[int] = Field(None, alias="followersCount")
-    follows_count: Optional[int] = Field(None, alias="followsCount")
-    posts_count: Optional[int] = Field(None, alias="postsCount")
-    indexed_at: Optional[str] = Field(None, alias="indexedAt")
-    viewer: Optional[Dict[str, Any]] = None
-    labels: Optional[List[Any]] = None
-
-    class Config:
-        populate_by_name = True
-
-
-class EventMetadata(BaseModel):
-    did_document: Optional[DIDDocument] = Field(None, alias="didDocument")
-    pds_host: Optional[str] = Field(None, alias="pdsHost")
-    handle: Optional[str] = None
-    did_created_at: Optional[str] = Field(None, alias="didCreatedAt")
-    account_age: Optional[int] = Field(None, alias="accountAge")
-    profile: Optional[ProfileViewDetailed] = None
-
-    class Config:
-        populate_by_name = True
-
-
-class AtKafkaEvent(BaseModel):
-    did: str
-    timestamp: str
-    metadata: Optional[EventMetadata] = Field(None, alias="eventMetadata")
-    operation: Optional[AtKafkaOp] = None
-    account: Optional[AtKafkaAccount] = None
-    identity: Optional[AtKafkaIdentity] = None
-    info: Optional[AtKafkaInfo] = None
-
-    class Config:
-        populate_by_name = True
+# class AtKafkaOp(BaseModel):
+#     action: str
+#     collection: str
+#     rkey: str
+#     uri: str
+#     cid: str
+#     path: str
+#     record: Optional[Dict[str, Any]]
+#
+#
+# class AtKafkaIdentity(BaseModel):
+#     seq: int
+#     handle: str
+#
+#
+# class AtKafkaInfo(BaseModel):
+#     name: str
+#     message: Optional[str] = None
+#
+#
+# class AtKafkaAccount(BaseModel):
+#     active: bool
+#     seq: int
+#     status: Optional[str] = None
+#
+#
+# class DIDDocument(BaseModel):
+#     context: Optional[List[Any]] = Field(None, alias="@context")
+#     id: Optional[str] = None
+#     also_known_as: Optional[List[Any]] = Field(None, alias="alsoKnownAs")
+#     verification_method: Optional[List[Any]] = Field(None, alias="verificationMethod")
+#     service: Optional[List[Any]] = None
+#
+#     class Config:
+#         populate_by_name = True
+#
+#
+# class ProfileViewDetailed(BaseModel):
+#     did: str
+#     handle: str
+#     display_name: Optional[str] = Field(None, alias="displayName")
+#     description: Optional[str] = None
+#     avatar: Optional[str] = None
+#     banner: Optional[str] = None
+#     followers_count: Optional[int] = Field(None, alias="followersCount")
+#     follows_count: Optional[int] = Field(None, alias="followsCount")
+#     posts_count: Optional[int] = Field(None, alias="postsCount")
+#     indexed_at: Optional[str] = Field(None, alias="indexedAt")
+#     viewer: Optional[Dict[str, Any]] = None
+#     labels: Optional[List[Any]] = None
+#
+#     class Config:
+#         populate_by_name = True
+#
+#
+# class EventMetadata(BaseModel):
+#     did_document: Optional[DIDDocument] = Field(None, alias="didDocument")
+#     pds_host: Optional[str] = Field(None, alias="pdsHost")
+#     handle: Optional[str] = None
+#     did_created_at: Optional[str] = Field(None, alias="didCreatedAt")
+#     account_age: Optional[int] = Field(None, alias="accountAge")
+#     profile: Optional[ProfileViewDetailed] = None
+#
+#     class Config:
+#         populate_by_name = True
+#
+#
+# class AtKafkaEvent(BaseModel):
+#     did: str
+#     timestamp: str
+#     metadata: Optional[EventMetadata] = Field(None, alias="eventMetadata")
+#     operation: Optional[AtKafkaOp] = None
+#     account: Optional[AtKafkaAccount] = None
+#     identity: Optional[AtKafkaIdentity] = None
+#     info: Optional[AtKafkaInfo] = None
+#
+#     class Config:
+#         populate_by_name = True


 class FollowRecord(BaseModel):
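
`TapEvent.decode_base64` runs with `mode="before"`, i.e. on the raw payload value before pydantic tries to coerce it into `RecordEvent` or `IdentityEvent`. A self-contained illustration of that mechanism follows; unlike the validator in the diff, this toy also runs `json.loads` so the decoded string becomes a dict pydantic can validate, which is an addition for the example rather than part of the change:

```python
import base64
import json
from typing import Any, Optional

from pydantic import BaseModel, field_validator


class Inner(BaseModel):
    collection: str
    rkey: str


class Outer(BaseModel):
    record: Optional[Inner] = None

    @field_validator("record", mode="before")
    @classmethod
    def decode(cls, v: Any):
        # "before" validators see the raw input, here a base64-encoded JSON string.
        if isinstance(v, str):
            return json.loads(base64.b64decode(v))
        return v


payload = base64.b64encode(
    json.dumps({"collection": "app.bsky.graph.follow", "rkey": "abc"}).encode()
).decode()

print(Outer.model_validate({"record": payload}).record)
# collection='app.bsky.graph.follow' rkey='abc'
```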