ML-based recommendation feed for Bluesky posts
import asyncio
import logging
import os
import sys

import pandas as pd


logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

BSKY_API_LIMIT = 10


class RateLimit:
    def __init__(self, per_second: int):
        self.per_second = per_second
        self.cur_count = 0
        self.refresh_event = asyncio.Event()
        self.refresh_running = False

    async def sleep_then_refresh(self):
        await asyncio.sleep(1)
        self.cur_count = 0
        self.refresh_event.set()
        self.refresh_running = False

    async def acquire(self):
        # If we have remaining capacity in this second, don't block
        if self.cur_count < self.per_second:
            # Start the timer for when our rate allocation refreshes
            if not self.refresh_running:
                self.refresh_running = True
                asyncio.create_task(self.sleep_then_refresh())
            self.cur_count += 1
            return

        # Otherwise we need to wait until the current second is over
        # and our rate allocation refreshes
        self.refresh_event.clear()
        await self.refresh_event.wait()

        # Just call recursively after waiting
        return await self.acquire()


def load_checkpoint(ckpt_dir: str) -> set[str]:
    # If the checkpoint dir doesn't exist, try to create it
    if not os.path.isdir(ckpt_dir):
        logger.info("Checkpoint dir doesn't exist, creating...")
        try:
            os.mkdir(ckpt_dir)
        except Exception as e:
            logger.error(f"Failed to create checkpoint dir, {ckpt_dir}\n{e}")
            sys.exit(1)

    # The checkpoint dir contains one file per completed account
    completed_accounts = set()
    try:
        files = os.listdir(ckpt_dir)
        for file in files:
            # Grab the entire file name except for the .gz extension
            completed_accounts.add(file[:-3])
    except Exception as e:
        logger.error(
            f"Failed to recover from checkpoint dir, {ckpt_dir}\n{e}",
            exc_info=True,
        )
        sys.exit(1)

    return completed_accounts


def get_accounts(graph_path: str, completed_accts: set[str]) -> list[tuple[str, int]]:
    # Load the follow graph parquet file
    to_explore = dict()
    try:
        logger.info("Parsing follower graph file...")
        follow_df = pd.read_parquet(graph_path)
        # Keep only accounts that follow between 100 and 1000 other accounts
        follow_df = follow_df.loc[follow_df["follows"].str.len().between(100, 1000)]
    except Exception as e:
        logger.error(f"Failed to open follow graph file, {graph_path}\n{e}")
        sys.exit(1)

    # Count how many accounts in the graph follow each not-yet-completed account
    for _, row in follow_df.iterrows():
        for acct in row["follows"]:
            if acct not in completed_accts:
                if acct not in to_explore:
                    to_explore[acct] = 0
                to_explore[acct] += 1

    # Sort candidate accounts by in-graph follower count, descending
    accts = [(acct, follows) for acct, follows in to_explore.items()]
    accts.sort(key=lambda x: -1 * x[1])
    return accts


def get_logger(name: str):
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    # Create formatter
    formatter = logging.Formatter(
        "%(asctime)s.%(msecs)03d | %(levelname)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    # Console handler
    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)
    return logger
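

# ---------------------------------------------------------------------------
# Hedged usage sketch (not part of the original script): one way the helpers
# above could be wired together. The paths ("follow_graph.parquet",
# "checkpoints/") and the logged "would fetch" placeholder are assumptions for
# illustration only; the real pipeline's entry point and its Bluesky API calls
# are not shown in this file.
# ---------------------------------------------------------------------------
async def _demo_crawl(graph_path: str, ckpt_dir: str) -> None:
    # Recover accounts already processed, then rank the remaining accounts by
    # how many accounts in the follow graph point to them.
    completed = load_checkpoint(ckpt_dir)
    accounts = get_accounts(graph_path, completed)

    # Gate per-account work behind the rate limiter so we never make more than
    # BSKY_API_LIMIT acquisitions per second.
    limiter = RateLimit(per_second=BSKY_API_LIMIT)
    for acct, in_graph_followers in accounts[:25]:
        await limiter.acquire()
        # Placeholder for a real per-account Bluesky API request.
        logger.info(f"Would fetch {acct} ({in_graph_followers} in-graph followers)")


if __name__ == "__main__":
    get_logger(__name__)  # attach a console handler to this module's logger
    asyncio.run(_demo_crawl("follow_graph.parquet", "checkpoints/"))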