Mirror of https://git.olaren.dev/Olaren/shitsky

Get DIDs of our wanted PDSes at startup time

So that we can use the DIDs to prune the saved firehose stream down to a nice subset! This PR doesn't include any re-checking; that could be done periodically to make sure we're getting the most up-to-date data from the PDSes (or by listening to the account-move events from the firehose).
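As a rough illustration of the pruning this enables (not part of this PR): once the startup DIDs are in a set, the firehose loop only needs a membership check per event. The sketch below assumes a `HashSet<String>` of the fetched DIDs and a hypothetical `repo_did()` accessor on `FirehoseEvent`; the real event type in src/firehose.rs may expose the repo's DID differently.

    use std::collections::HashSet;

    use crate::firehose::FirehoseEvent;

    /// Keep only events whose repo DID belongs to one of our wanted PDSes.
    /// `wanted` would be built from get_all_active_dids_from_pdses at startup;
    /// `repo_did()` is a hypothetical accessor, not something this PR adds.
    fn is_wanted(event: &FirehoseEvent, wanted: &HashSet<String>) -> bool {
        wanted.contains(event.repo_did())
    }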

Tangled d6c847c9 f02896b7

Changed files
+176 -6
+3
.env.example
+ # the main relay to listen to
+ RELAY_URL=wss://pds.upcloud.world
+
  # shitsky saves state here
  DATABASE_URL=postgres://postgres:postgres@localhost:5432/shitsky
+1 -1
Cargo.toml
···
  futures-util = "0.3.31"
  listenfd = "1.0.2"
  maud = { version = "0.27", features = ["axum"] }
- reqwest = "0.12.23"
+ reqwest = { version = "0.12", features = ["json"] }
  rs-car = "0.5.0"
  serde = { version = "1.0", features = ["derive"] }
  serde_bytes = "0.11.19"
+30 -5
src/main.rs
···
  pub mod firehose;
  use firehose::{FirehoseEvent, FirehoseOptions, subscribe_repos};

+ mod pds;
+ use pds::get_all_active_dids_from_pdses;
+
  type Db = PgPool;

  #[tokio::main]
···
      tracing_subscriber::fmt::init();
      dotenvy::dotenv().ok();

+     let pds_hosts_str = std::env::var("PDS_LIST")?;
+
+     let pds_hosts: Vec<String> = pds_hosts_str
+         .split(',')
+         .map(|s| s.trim().to_string())
+         .filter(|s| !s.is_empty())
+         .collect();
+
+     if pds_hosts.is_empty() {
+         tracing::error!("Error: PDS_LIST environment variable is empty or contains only commas.");
+         return Ok(());
+     }
+
+     tracing::info!("Querying {} PDS(es): {:?}", pds_hosts.len(), pds_hosts);
+
+     let _all_dids = get_all_active_dids_from_pdses(&pds_hosts).await?;
+
      let db_url = std::env::var("DATABASE_URL").expect("DATABASE_URL must be set");
      let pool = PgPoolOptions::new()
          .max_connections(5)
···
      let web_server_pool = pool.clone();
      tokio::spawn(async move { web_server(web_server_pool).await });

-     firehose_subscriber(pool).await;
+     let relay_url = std::env::var("RELAY_URL").unwrap_or_default();
+     firehose_subscriber(pool, relay_url).await;

      Ok(())
  }

- async fn firehose_subscriber(db: Db) {
+ async fn firehose_subscriber(db: Db, relay_url: String) {
      tracing::info!("Starting firehose subscriber...");

-     let options = FirehoseOptions {
-         relay_url: "wss://bsky.network".to_string(),
-         ..Default::default()
+     let options = if relay_url.is_empty() {
+         FirehoseOptions::default()
+     } else {
+         FirehoseOptions {
+             relay_url,
+             ..Default::default()
+         }
      };

      let mut stream = Box::pin(subscribe_repos(options));
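Side note: main.rs now requires PDS_LIST, while the .env.example change above only adds RELAY_URL. A hypothetical entry (the host names here are placeholders) would follow the comma-separated format the parsing code expects:

    # comma-separated list of PDS hosts to fetch DIDs from (hypothetical example)
    PDS_LIST=pds.upcloud.world,pds.example.com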
+142
src/pds.rs
+ use futures_util::future::join_all;
+ use reqwest::Client;
+ use serde::Deserialize;
+ use thiserror::Error;
+
+ #[derive(Debug, Error)]
+ pub enum PdsError {
+     #[error("Network request failed: {0}")]
+     RequestError(#[from] reqwest::Error),
+     #[error("Failed to join task: {0}")]
+     JoinError(#[from] tokio::task::JoinError),
+     #[error("Environment variable not found: {0}")]
+     EnvVarError(#[from] std::env::VarError),
+ }
+
+ #[derive(Deserialize, Debug)]
+ struct Repo {
+     did: String,
+ }
+
+ #[derive(Deserialize, Debug)]
+ #[serde(rename_all = "camelCase")]
+ struct ListReposResponse {
+     cursor: Option<String>,
+     repos: Vec<Repo>,
+ }
+
+ #[derive(Deserialize, Debug)]
+ #[serde(rename_all = "camelCase")]
+ struct DescribeRepoResponse {
+     handle_is_correct: bool,
+ }
+
+ pub async fn get_all_active_dids_from_pdses(pds_hosts: &[String]) -> Result<Vec<String>, PdsError> {
+     let client = Client::new();
+     let mut tasks = Vec::new();
+
+     for host in pds_hosts {
+         let host_clone = host.clone();
+         let client_clone = client.clone();
+         tasks.push(tokio::spawn(async move {
+             fetch_active_dids_from_single_pds(host_clone, client_clone).await
+         }));
+     }
+
+     let results = join_all(tasks).await;
+     let mut all_dids = Vec::new();
+
+     for result in results {
+         let pds_dids = result??;
+         all_dids.extend(pds_dids);
+     }
+
+     tracing::info!("--- Sample of fetched ACTIVE DIDs (first 10) ---");
+     for did in all_dids.iter().take(10) {
+         tracing::info!("{}", did);
+     }
+     tracing::info!("... and {} more.", all_dids.len().saturating_sub(10));
+     tracing::info!("--- Total Active DIDs fetched: {} ---", all_dids.len());
+
+     Ok(all_dids)
+ }
+
+ async fn fetch_active_dids_from_single_pds(
+     host: String,
+     client: Client,
+ ) -> Result<Vec<String>, PdsError> {
+     let mut active_dids = Vec::new();
+     let mut cursor: Option<String> = None;
+     let limit = 1000;
+
+     tracing::info!("Fetching active DIDs from PDS: {}", host);
+
+     loop {
+         let url = match &cursor {
+             Some(c) => format!(
+                 "https://{}/xrpc/com.atproto.sync.listRepos?limit={}&cursor={}",
+                 host, limit, c
+             ),
+             None => format!(
+                 "https://{}/xrpc/com.atproto.sync.listRepos?limit={}",
+                 host, limit
+             ),
+         };
+
+         let response: ListReposResponse = client.get(&url).send().await?.json().await?;
+         let dids_on_page: Vec<String> = response.repos.into_iter().map(|r| r.did).collect();
+
+         if !dids_on_page.is_empty() {
+             let mut check_tasks = Vec::new();
+             for did in dids_on_page {
+                 let host_clone = host.clone();
+                 let client_clone = client.clone();
+                 check_tasks.push(tokio::spawn(async move {
+                     let url = format!(
+                         "https://{}/xrpc/com.atproto.repo.describeRepo?repo={}",
+                         host_clone, did
+                     );
+
+                     let repo_info_result = async {
+                         client_clone
+                             .get(&url)
+                             .send()
+                             .await?
+                             .json::<DescribeRepoResponse>()
+                             .await
+                     }
+                     .await;
+
+                     if let Ok(repo_info) = repo_info_result
+                         && repo_info.handle_is_correct
+                     {
+                         Some(did)
+                     } else {
+                         None
+                     }
+                 }));
+             }
+
+             let checked_results = join_all(check_tasks).await;
+             for result in checked_results {
+                 if let Ok(Some(active_did)) = result {
+                     active_dids.push(active_did);
+                 }
+             }
+         }
+
+         if let Some(next_cursor) = response.cursor {
+             cursor = Some(next_cursor);
+         } else {
+             break;
+         }
+     }
+
+     tracing::info!(
+         "Finished fetching {} active DIDs from {}.",
+         active_dids.len(),
+         host
+     );
+
+     Ok(active_dids)
+ }
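For completeness, a minimal sketch of wiring the new function into the pruning described at the top, assuming the DIDs end up in a `HashSet` for cheap lookups inside the firehose loop (the `_all_dids` binding in main.rs is currently unused, so this is a hypothetical follow-up, not part of this PR):

    use std::collections::HashSet;

    use crate::pds::{get_all_active_dids_from_pdses, PdsError};

    /// Hypothetical follow-up: collect the startup DID list into a set so the
    /// firehose subscriber can filter events by repo DID.
    async fn wanted_dids(pds_hosts: &[String]) -> Result<HashSet<String>, PdsError> {
        Ok(get_all_active_dids_from_pdses(pds_hosts)
            .await?
            .into_iter()
            .collect())
    }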