"""ML-based recommendation feed for Bluesky posts."""
1import asyncio
2import logging
3import os
4import sys
5
6import pandas as pd
7
8
# Module-level logger: INFO level, no handler attached here (output goes
# wherever the root/ambient logging config sends it; get_logger() below
# builds a handler-configured variant).
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# NOTE(review): presumably the Bluesky API request budget per second, fed
# to RateLimit(per_second=...) — confirm against current API docs.
BSKY_API_LIMIT = 10
13
14
15class RateLimit:
16 def __init__(self, per_second: int):
17 self.per_second = per_second
18 self.cur_count = 0
19 self.refresh_event = asyncio.Event()
20 self.refresh_running = False
21
22 async def sleep_then_refresh(self):
23 await asyncio.sleep(1)
24 self.cur_count = 0
25 self.refresh_event.set()
26 self.refresh_running = False
27
28 async def acquire(self):
29 # If we have remaining capacity in this second, dont block
30 if self.cur_count < self.per_second:
31 # Start timer for when our rate allocation refreshes
32 if not self.refresh_running:
33 self.refresh_running = True
34 asyncio.create_task(self.sleep_then_refresh())
35 self.cur_count += 1
36 return
37
38 # Otherwise we need to wait until current second is over
39 # and our rate allocation refreshes
40 self.refresh_event.clear()
41 await self.refresh_event.wait()
42
43 # Just recursively call after waiting
44 return await self.acquire()
45
46
def load_checkpoint(ckpt_dir: str) -> set[str]:
    """Return the set of account ids already processed.

    Each completed account is recorded as an ``<account>.gz`` file inside
    ``ckpt_dir``. The directory is created (including parents) if missing.
    Exits the process on unrecoverable filesystem errors, matching the
    module's CLI-style error handling.
    """
    # If checkpoint dir doesn't exist, try to create it
    if not os.path.isdir(ckpt_dir):
        logger.info("Checkpoint dir doesn't exist, creating...")
        try:
            # makedirs handles nested paths; plain mkdir fails when the
            # parent directory doesn't exist yet.
            os.makedirs(ckpt_dir, exist_ok=True)
        except OSError as e:
            logger.error(f"Failed to create checkpoint dir, {ckpt_dir}\n{e}")
            sys.exit(1)

    # Checkpoint folders contain one file per user
    completed_accounts = set()
    try:
        for file in os.listdir(ckpt_dir):
            # Strip the .gz extension to recover the account id.
            # (The original sliced file[:-3] unconditionally, which would
            # corrupt any stray non-.gz filename.)
            completed_accounts.add(file.removesuffix(".gz"))
    except OSError as e:
        logger.error(
            f"Failed to recover from checkpoint dir, {ckpt_dir}\n{e}",
            exc_info=True,
        )
        sys.exit(1)

    return completed_accounts
72
73
def get_accounts(graph_path: str, completed_accts: set[str]) -> list[tuple[str, int]]:
    """Rank not-yet-processed accounts by how many tracked users follow them.

    Args:
        graph_path: Path to a parquet file with a ``follows`` column holding
            each user's list of followed account ids.
        completed_accts: Account ids to skip (already checkpointed).

    Returns:
        ``(account_id, follower_count)`` tuples, most-followed first.
        Exits the process if the parquet file cannot be read.
    """
    try:
        logger.info("Parsing follower graph file...")
        follow_df = pd.read_parquet(graph_path)
        # Keep only users whose follow list has between 100 and 1000 entries
        follow_df = follow_df.loc[follow_df["follows"].str.len().between(100, 1000)]
    except Exception as e:
        logger.error(f"Failed to open follow graph file, {graph_path}\n{e}")
        sys.exit(1)

    # Count how many of the filtered users follow each unexplored account
    to_explore: dict[str, int] = {}
    for _, row in follow_df.iterrows():
        for acct in row["follows"]:
            if acct not in completed_accts:
                to_explore[acct] = to_explore.get(acct, 0) + 1

    # Most-followed accounts first (stable sort preserves insertion order
    # among ties, same as the original key=-count ascending sort)
    accts = list(to_explore.items())
    accts.sort(key=lambda pair: pair[1], reverse=True)
    return accts
96
97
def get_logger(name: str):
    """Return a logger that writes timestamped records to stdout.

    Idempotent: repeated calls with the same name return the same logger
    without attaching duplicate handlers. (The original added a fresh
    StreamHandler on every call, so each log line was emitted once per
    call to get_logger.)
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # getLogger returns one shared object per name — only configure it once.
    if not logger.handlers:
        # Create formatter
        formatter = logging.Formatter(
            "%(asctime)s.%(msecs)03d | %(levelname)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )

        # Console handler
        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(formatter)
        logger.addHandler(console_handler)
    return logger
113 return logger