// Mirror of https://git.olaren.dev/Olaren/shitsky
1use futures_util::future::join_all;
2use reqwest::Client;
3use serde::Deserialize;
4use thiserror::Error;
5
6#[derive(Debug, Error)]
7pub enum PdsError {
8 #[error("Network request failed: {0}")]
9 RequestError(#[from] reqwest::Error),
10 #[error("Failed to join task: {0}")]
11 JoinError(#[from] tokio::task::JoinError),
12 #[error("Environment variable not found: {0}")]
13 EnvVarError(#[from] std::env::VarError),
14}
15
16#[derive(Deserialize, Debug)]
17struct Repo {
18 did: String,
19}
20
21#[derive(Deserialize, Debug)]
22#[serde(rename_all = "camelCase")]
23struct ListReposResponse {
24 cursor: Option<String>,
25 repos: Vec<Repo>,
26}
27
28#[derive(Deserialize, Debug)]
29#[serde(rename_all = "camelCase")]
30struct DescribeRepoResponse {
31 handle_is_correct: bool,
32}
33
34pub async fn get_all_active_dids_from_pdses(pds_hosts: &[String]) -> Result<Vec<String>, PdsError> {
35 let client = Client::new();
36 let mut tasks = Vec::new();
37
38 for host in pds_hosts {
39 let host_clone = host.clone();
40 let client_clone = client.clone();
41 tasks.push(tokio::spawn(async move {
42 fetch_active_dids_from_single_pds(host_clone, client_clone).await
43 }));
44 }
45
46 let results = join_all(tasks).await;
47 let mut all_dids = Vec::new();
48
49 for result in results {
50 let pds_dids = result??;
51 all_dids.extend(pds_dids);
52 }
53
54 tracing::info!("--- Sample of fetched ACTIVE DIDs (first 10) ---");
55 for did in all_dids.iter().take(10) {
56 tracing::info!("{}", did);
57 }
58 tracing::info!("... and {} more.", all_dids.len().saturating_sub(10));
59 tracing::info!("--- Total Active DIDs fetched: {} ---", all_dids.len());
60
61 Ok(all_dids)
62}
63
64async fn fetch_active_dids_from_single_pds(
65 host: String,
66 client: Client,
67) -> Result<Vec<String>, PdsError> {
68 let mut active_dids = Vec::new();
69 let mut cursor: Option<String> = None;
70 let limit = 1000;
71
72 tracing::info!("Fetching active DIDs from PDS: {}", host);
73
74 loop {
75 let url = match &cursor {
76 Some(c) => format!(
77 "https://{}/xrpc/com.atproto.sync.listRepos?limit={}&cursor={}",
78 host, limit, c
79 ),
80 None => format!(
81 "https://{}/xrpc/com.atproto.sync.listRepos?limit={}",
82 host, limit
83 ),
84 };
85
86 let response: ListReposResponse = client.get(&url).send().await?.json().await?;
87 let dids_on_page: Vec<String> = response.repos.into_iter().map(|r| r.did).collect();
88
89 if !dids_on_page.is_empty() {
90 let mut check_tasks = Vec::new();
91 for did in dids_on_page {
92 let host_clone = host.clone();
93 let client_clone = client.clone();
94 check_tasks.push(tokio::spawn(async move {
95 let url = format!(
96 "https://{}/xrpc/com.atproto.repo.describeRepo?repo={}",
97 host_clone, did
98 );
99
100 let repo_info_result = async {
101 client_clone
102 .get(&url)
103 .send()
104 .await?
105 .json::<DescribeRepoResponse>()
106 .await
107 }
108 .await;
109
110 if let Ok(repo_info) = repo_info_result
111 && repo_info.handle_is_correct
112 {
113 Some(did)
114 } else {
115 None
116 }
117 }));
118 }
119
120 let checked_results = join_all(check_tasks).await;
121 for result in checked_results {
122 if let Ok(Some(active_did)) = result {
123 active_dids.push(active_did);
124 }
125 }
126 }
127
128 if let Some(next_cursor) = response.cursor {
129 cursor = Some(next_cursor);
130 } else {
131 break;
132 }
133 }
134
135 tracing::info!(
136 "Finished fetching {} active DIDs from {}.",
137 active_dids.len(),
138 host
139 );
140
141 Ok(active_dids)
142}