ML-based recommendation feed for Bluesky posts
5
fork

Configure Feed

Select the types of activity you want to include in your feed.

Trainin' stuff

+228 -73
+45 -14
scripts/split_follows.py
··· 5 5 import os 6 6 import random 7 7 import sys 8 + from typing import Optional 8 9 9 10 import numpy as np 10 11 12 + from scripts.utils import get_logger 11 13 12 - logger = logging.getLogger(__name__) 13 - logger.setLevel(logging.INFO) 14 + logger = get_logger(__name__) 14 15 15 - # Create formatter 16 - formatter = logging.Formatter( 17 - "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 18 - ) 19 16 20 - # Console handler 21 - console_handler = logging.StreamHandler(sys.stdout) 22 - console_handler.setFormatter(formatter) 23 - logger.addHandler(console_handler) 24 - 25 - 26 - def main(follow_dir: str, output_dir: str, val_split: float, file_size: int): 17 + def main( 18 + follow_dir: str, 19 + output_dir: str, 20 + val_split: float, 21 + file_size: int, 22 + max_accounts: Optional[int], 23 + exclude_unfetched: bool = False, 24 + ): 27 25 if val_split >= 1.0 or val_split <= 0.0: 28 26 raise ValueError("Validation split must be between 0 and 1 exclusive") 29 27 ··· 32 30 33 31 did_id_map: dict[str, int] = dict() 34 32 33 + files = os.listdir(follow_dir) 34 + if max_accounts: 35 + files = files[:max_accounts] 36 + 37 + if exclude_unfetched: 38 + for file in files: 39 + did_id_map[file[-3]] = len(did_id_map) 40 + 35 41 logger.info("Reading follows...") 36 42 train_file_idx = 1 37 43 val_file_idx = 1 ··· 39 45 val_follows: list[tuple[int, int]] = [] 40 46 train_files = [] 41 47 val_files = [] 42 - for file in os.listdir(follow_dir): 48 + for file in files: 43 49 source_did = file[:-3] # Remove .gz extension 50 + if max_accounts and len(did_id_map) >= max_accounts: 51 + continue 52 + 44 53 if source_did not in did_id_map: 45 54 did_id_map[source_did] = len(did_id_map) 46 55 47 56 for line in gzip.open(follow_dir + file, "rt"): 48 57 target_did = json.loads(line)["value"]["subject"] 58 + if exclude_unfetched and target_did not in did_id_map: 59 + continue 60 + 61 + if max_accounts and len(did_id_map) >= max_accounts: 62 + continue 63 + 49 64 if target_did not in did_id_map: 50 65 did_id_map[target_did] = len(did_id_map) 51 66 ··· 190 205 type=int, 191 206 help="Max rows per output file", 192 207 ) 208 + parser.add_argument( 209 + "--exclude-unfetched", 210 + dest="exclude_unfetched", 211 + required=False, 212 + type=bool, 213 + help="Whether to include accounts whose follows haven't been retrieved", 214 + ) 215 + parser.add_argument( 216 + "--max-accounts", 217 + dest="max_accounts", 218 + required=False, 219 + type=int, 220 + help="Maximum number of accounts to include", 221 + ) 193 222 args = parser.parse_args() 194 223 195 224 main( ··· 197 226 output_dir=args.output_dir, 198 227 val_split=args.val_split, 199 228 file_size=args.file_size, 229 + exclude_unfetched=args.exclude_unfetched, 230 + max_accounts=args.max_accounts, 200 231 )
+76 -24
scripts/training/data.py
··· 1 + from functools import reduce 1 2 import json 2 3 import glob 4 + from multiprocessing import Pool 3 5 import os 4 6 import random 5 7 ··· 8 10 from torch.utils.data import Dataset 9 11 10 12 13 + from scripts.utils import get_logger 14 + 15 + logger = get_logger(__name__) 16 + 17 + 18 + def get_frequencies(array: np.ndarray, user_count: int) -> np.ndarray: 19 + frequencies = np.zeros((user_count, 2)) 20 + for source_id, target_id in array: 21 + frequencies[source_id][0] += 1 22 + frequencies[target_id][1] += 1 23 + return frequencies 24 + 25 + 11 26 class FollowDataset(Dataset): 12 27 def __init__( 13 28 self, 14 29 dataset_path: str, 15 30 split: str, 16 - negative_sample_count: int, 31 + negative_sample_chance: float, 17 32 ): 18 33 with open(os.path.join(dataset_path, "metadata.json"), "r") as in_file: 19 34 metadata = json.load(in_file) ··· 25 40 raise ValueError(f"Could not find split in metadata file: {split}") 26 41 split_files = metadata["splits"][split] 27 42 28 - self.parquet_files: list[tuple[str, str, int, int]] = [] 43 + self.numpy_files: list[tuple[str, str, int, int]] = [] 29 44 for file in split_files: 30 45 file_idx = file["filename"].split("_")[2].split(".")[0] 31 - self.parquet_files.append((file["filename"], file["dtype"], file_idx, file["shape"][0])) # type: ignore 46 + self.numpy_files.append((file["filename"], file["dtype"], file_idx, file["shape"][0])) # type: ignore 32 47 33 - self.parquet_files.sort(key=lambda x: x[1]) 48 + self.numpy_files.sort(key=lambda x: x[1]) 34 49 self.dataframes: list[np.ndarray] = [] 35 - for filename, dtype, _, row_count in self.parquet_files: 50 + for filename, dtype, _, row_count in self.numpy_files: 36 51 self.dataframes.append( 37 52 np.memmap( 38 53 os.path.join(dataset_path, filename), ··· 42 57 ) 43 58 ) 44 59 45 - self.negative_sample_count = negative_sample_count 60 + logger.info("Calculating node frequency...") 61 + with Pool(7) as p: 62 + self.cumulative_freq = reduce( 63 + np.add, 64 + p.starmap( 65 + get_frequencies, 66 + [ 67 + (dataframe, len(self.did_id_map)) 68 + for dataframe in self.dataframes 69 + ], 70 + ), 71 + ) 72 + self.cumulative_freq[0] = self.cumulative_freq[0].cumsum() 73 + self.cumulative_freq[1] = self.cumulative_freq[1].cumsum() 74 + 75 + self.negative_sample_chance = negative_sample_chance 46 76 47 77 def __len__(self) -> int: 48 - return sum((row_count for _, _, _, row_count in self.parquet_files)) 78 + return sum((row_count for _, _, _, row_count in self.numpy_files)) 49 79 50 80 def num_users(self) -> int: 51 81 return len(self.did_id_map) ··· 58 88 row_index_total = 0 59 89 effective_idx = idx 60 90 i = 0 61 - for i, (_, _, _, row_count) in enumerate(self.parquet_files): 91 + for i, (_, _, _, row_count) in enumerate(self.numpy_files): 62 92 row_index_total += row_count 63 93 if idx < row_index_total: 64 94 break ··· 67 97 row = self.dataframes[i][effective_idx] 68 98 return (row[0].item(), row[1].item()) 69 99 70 - def __getitem__(self, idx: int) -> tuple[list[tuple[int, int]], list[int]]: 100 + def __getitem__(self, idx: int) -> tuple[tuple[int, int], int]: 71 101 """ 72 - Grab follow connection and create negative samples (follows that dont exist) 102 + Grab follow connection and corrupt it at defined frequency to another id 103 + weighted by that id's prevalence in the dataset 73 104 """ 74 - samples = [self._idx_to_row(idx)] 75 - 76 - for _ in range(self.negative_sample_count): 77 - rand_idx = random.randrange(0, len(self.did_id_map)) 78 - # Assume that users arent connected (roughly 1/700 chance they are) 79 - samples.append((samples[0][0], rand_idx)) 80 - 81 - labels = [-1] * len(samples) 82 - labels[0] = 1 83 - return samples, labels 105 + sample = self._idx_to_row(idx) 106 + return (sample, 1) 84 107 85 - @staticmethod 86 108 def collate_rows( 87 - batch: list[tuple[list[tuple[int, int]], list[int]]], 109 + self, 110 + batch: list[tuple[tuple[int, int], int]], 88 111 ) -> tuple[torch.Tensor, torch.Tensor]: 89 - follows = torch.concat(tuple(torch.IntTensor(follow) for follow, _ in batch)) 90 - labels = torch.concat(tuple(torch.IntTensor(label) for _, label in batch)) 112 + # Corrupt some rows into negative edges 113 + corrupted_sources = [] 114 + corrupted_targets = [] 115 + for i in range(len(batch)): 116 + if random.random() < self.negative_sample_chance: 117 + if random.random() < 0.5: 118 + corrupted_sources.append(i) 119 + else: 120 + corrupted_targets.append(i) 121 + 122 + new_sources = self.cumulative_freq[:, 0].searchsorted( 123 + np.random.sample(len(corrupted_sources)) * self.cumulative_freq.shape[0] 124 + ) 125 + new_targets = self.cumulative_freq[:, 1].searchsorted( 126 + np.random.sample(len(corrupted_targets)) * self.cumulative_freq.shape[0] 127 + ) 128 + new_sources[new_sources >= self.cumulative_freq.shape[0]] = ( 129 + self.cumulative_freq.shape[0] - 1 130 + ) 131 + new_targets[new_targets >= self.cumulative_freq.shape[0]] = ( 132 + self.cumulative_freq.shape[0] - 1 133 + ) 134 + 135 + for i, idx in enumerate(corrupted_sources): 136 + batch[idx] = ((new_sources[i], batch[idx][0][1]), -1) 137 + 138 + for i, idx in enumerate(corrupted_targets): 139 + batch[idx] = ((batch[idx][0][0], new_targets[i]), -1) 140 + 141 + follows = torch.concat(tuple(torch.IntTensor([follow]) for follow, _ in batch)) 142 + labels = torch.concat(tuple(torch.IntTensor([label]) for _, label in batch)) 91 143 return (follows, labels)
+6 -7
scripts/training/models.py
··· 1 1 import torch 2 2 from torch import nn 3 - from torch.utils.data import DataLoader 4 3 import torch.nn.functional as F 5 4 import lightning as pyl 6 - 7 - from scripts.training.data import FollowDataset 8 5 9 6 10 7 class UserEmbedding(nn.Module): ··· 19 16 20 17 21 18 class FollowEmbedModule(pyl.LightningModule): 22 - def __init__(self, source_embedding, target_embedding): 19 + def __init__(self, num_embeds: int, embed_dim: int, learning_rate: float): 23 20 super().__init__() 24 - self.source_embed = source_embedding 25 - self.target_embed = target_embedding 21 + self.save_hyperparameters() 22 + self.source_embed = UserEmbedding(num_embeds, embed_dim) 23 + self.target_embed = UserEmbedding(num_embeds, embed_dim) 24 + self.learning_rate = learning_rate 26 25 27 26 def training_step(self, batch, _): 28 27 x, y = batch ··· 53 52 print(f"\nTraining Loss: {avg_loss:.4f}\n") 54 53 55 54 def configure_optimizers(self): 56 - optimizers = torch.optim.Adagrad(self.parameters(), lr=1e-1) 55 + optimizers = torch.optim.Adagrad(self.parameters(), lr=self.learning_rate) 57 56 return optimizers
+38
scripts/training/serve.py
··· 1 + from argparse import ArgumentParser 2 + import json 3 + 4 + from scripts.training.models import UserEmbedding, FollowEmbedModule 5 + 6 + 7 + def main(checkpoint_path: str, id_map_path: str): 8 + model = FollowEmbedModule.load_from_checkpoint(checkpoint_path) 9 + 10 + print(f"Shape of source embed weights: {model.source_embed.embed.weight.shape}") 11 + with open(id_map_path, "r") as in_file: 12 + did_id_map = json.load(in_file) 13 + 14 + did = "did:plc:uxmy3zztxyhfk6mxrkun5tpr" 15 + user_id = did_id_map[did] 16 + print(f"did: {did}") 17 + print(f"id: {did_id_map[did]}") 18 + 19 + source_tensor = model.source_embed.embed.weight[user_id, :] 20 + print(f"Embedding: {source_tensor}") 21 + 22 + 23 + if __name__ == "__main__": 24 + parser = ArgumentParser() 25 + parser.add_argument( 26 + "--checkpoint", 27 + required=True, 28 + help="Path to model checkpoint file", 29 + ) 30 + parser.add_argument( 31 + "--did-id-map", 32 + dest="did_id_map", 33 + required=True, 34 + help="Path to did_id_map file", 35 + ) 36 + args = parser.parse_args() 37 + 38 + main(checkpoint_path=args.checkpoint, id_map_path=args.did_id_map)
+46 -28
scripts/training/train.py
··· 10 10 from torchdata.stateful_dataloader import StatefulDataLoader 11 11 12 12 from scripts.training.data import FollowDataset 13 - from scripts.training.models import UserEmbedding, FollowEmbedModule 13 + from scripts.training.models import FollowEmbedModule 14 + from scripts.utils import get_logger 14 15 16 + logger = get_logger(__name__) 15 17 16 18 PROJECT = "goodposts-followers" 17 19 SAVE_DIR = "./data/training" 18 20 NAME = "following-embedding" 19 21 20 22 21 - def main(run_id: Optional[str], model_checkpoint: Optional[str]): 23 + def main(run_id: Optional[str], model_checkpoint: Optional[str], dataset_path: str): 22 24 if run_id is None: 23 25 run_id = NAME + str(int(datetime.now().timestamp())) 24 26 25 - print("Loading validation dataset") 26 - val_dataset = FollowDataset( 27 - dataset_path="data/processed/numpy/", 28 - split="validation", 29 - negative_sample_count=0, 30 - ) 31 - 32 - print("Loading training dataset") 27 + logger.info("Loading training dataset") 33 28 train_dataset = FollowDataset( 34 - dataset_path="data/processed/numpy/", 29 + dataset_path=dataset_path, 35 30 split="training", 36 - negative_sample_count=2, 31 + negative_sample_chance=0.5, 37 32 ) 38 33 39 - val_dataloader = StatefulDataLoader( 40 - val_dataset, 41 - pin_memory=True, 42 - collate_fn=val_dataset.collate_rows, 43 - batch_size=1024, 44 - num_workers=7, 34 + logger.info("Loading validation dataset") 35 + val_dataset = FollowDataset( 36 + dataset_path=dataset_path, 37 + split="validation", 38 + negative_sample_chance=0.0, 45 39 ) 46 40 47 41 train_dataloader = StatefulDataLoader( ··· 50 44 collate_fn=train_dataset.collate_rows, 51 45 batch_size=1024, 52 46 num_workers=7, 47 + shuffle=True, 53 48 ) 54 49 55 - print("Instantiating embedding model") 50 + val_dataloader = StatefulDataLoader( 51 + val_dataset, 52 + pin_memory=True, 53 + collate_fn=val_dataset.collate_rows, 54 + batch_size=1024, 55 + num_workers=7, 56 + ) 57 + 58 + logger.info("Instantiating embedding model") 56 59 wandb_logger = WandbLogger( 57 60 project=PROJECT, 58 61 save_dir=SAVE_DIR, ··· 69 72 monitor="val_loss", 70 73 ) 71 74 72 - embedding = FollowEmbedModule( 73 - source_embedding=UserEmbedding(train_dataset.num_users(), 64), 74 - target_embedding=UserEmbedding(train_dataset.num_users(), 64), 75 - ) 75 + if model_checkpoint: 76 + embedding = FollowEmbedModule.load_from_checkpoint( 77 + checkpoint_path=model_checkpoint 78 + ) 79 + else: 80 + embedding = FollowEmbedModule( 81 + num_embeds=train_dataset.num_users(), 82 + embed_dim=256, 83 + learning_rate=1e-1, 84 + ) 76 85 77 86 torch.set_float32_matmul_precision("medium") 78 87 trainer = pyl.Trainer( ··· 81 90 accelerator="gpu", 82 91 devices=1, 83 92 precision="bf16", 84 - max_epochs=2, 85 - limit_train_batches=0.01, 86 - limit_val_batches=1_000, 87 - val_check_interval=0.05, 93 + gradient_clip_val=0.5, 94 + max_epochs=10, 95 + val_check_interval=0.5, 88 96 ) 89 97 trainer.fit( 90 98 model=embedding, ··· 108 116 dest="model_checkpoint", 109 117 help="Path to checkpoint with run data", 110 118 ) 119 + parser.add_argument( 120 + "-d", 121 + "--dataset-path", 122 + dest="dataset_path", 123 + help="Path to dataset directory", 124 + ) 111 125 112 126 args = parser.parse_args() 113 - main(run_id=args.run_id, model_checkpoint=args.model_checkpoint) 127 + main( 128 + run_id=args.run_id, 129 + model_checkpoint=args.model_checkpoint, 130 + dataset_path=args.dataset_path, 131 + )
+17
scripts/utils.py
··· 93 93 accts = [(acct, follows) for acct, follows in to_explore.items()] 94 94 accts.sort(key=lambda x: -1 * x[1]) 95 95 return accts 96 + 97 + 98 + def get_logger(name: str): 99 + 100 + logger = logging.getLogger(name) 101 + logger.setLevel(logging.INFO) 102 + 103 + # Create formatter 104 + formatter = logging.Formatter( 105 + "%(asctime)s | %(levelname)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S" 106 + ) 107 + 108 + # Console handler 109 + console_handler = logging.StreamHandler(sys.stdout) 110 + console_handler.setFormatter(formatter) 111 + logger.addHandler(console_handler) 112 + return logger