Personal ATProto tools.
1//! Requests to the atproto API.
2use std::env;
3use std::fs;
4use std::sync::Arc;
5
6use base64::Engine;
7use reqwest::Client;
8use tokio::sync::Mutex;
9
10use crate::types::{LabelsVecWithSeq, RetrievedLabelResponse, SignatureBytes};
/// Which XRPC base URL a request is sent to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum ApiEndpoint {
    /// The `bsky.social` XRPC endpoint (`Agent::AUTH_URL`).
    Authorized,
    /// The `public.api.bsky.app` XRPC endpoint (`Agent::PUBLIC_URL`).
    Public,
}
15/// Agent for interactions with atproto.
16#[derive(Clone)]
17pub struct Agent {
18 /// The access JWT.
19 pub access_jwt: Arc<Mutex<String>>,
20 /// The refresh JWT.
21 pub refresh_jwt: Arc<Mutex<String>>,
22 /// The reqwest client.
23 pub client: Client,
24 /// The DID of the labeler.
25 pub self_did: Arc<String>,
26}
27impl Default for Agent {
28 fn default() -> Self {
29 drop(dotenvy::dotenv().expect("Failed to load .env file"));
30 Self {
31 access_jwt: Arc::new(Mutex::new(env::var("ACCESS_JWT").expect("ACCESS_JWT must be set"))),
32 refresh_jwt: Arc::new(Mutex::new(env::var("REFRESH_JWT").expect("REFRESH_JWT must be set"))),
33 client: Client::new(),
34 self_did: Arc::new(env::var("SELF_DID").expect("SELF_DID must be set")),
35 }
36 }
37}
38impl Agent {
39 /// The base URL of the atproto API's XRPC endpoint.
40 /// Rate limit: 3_000 per 5 minutes
41 const AUTH_URL: &'static str = "https://bsky.social/xrpc/";
42 const PUBLIC_URL: &'static str = "https://public.api.bsky.app/xrpc/";
43 async fn client_get(
44 &self,
45 path: &str,
46 parameters: &[(&str, &str)],
47 api_endpoint: &ApiEndpoint,
48 ) -> reqwest::Response {
49 self.client
50 .get(format!("{}{}", match api_endpoint {
51 ApiEndpoint::Authorized => Self::AUTH_URL,
52 ApiEndpoint::Public => Self::PUBLIC_URL,
53 }, &path))
54 .header("Content-Type", "application/json")
55 .header("Authorization", format!("Bearer {}", self.access_jwt.lock().await))
56 .header("atproto-accept-labelers", self.self_did.as_str())
57 .query(parameters)
58 .send()
59 .await.expect("Expected to be able to send request, but failed.")
60 }
61 async fn client_refresh(&self) {
62 tracing::warn!("Token expired, refreshing");
63 let response = self.client
64 .post(format!(
65 "{}{}",
66 Self::AUTH_URL,
67 "com.atproto.server.refreshSession"
68 ))
69 .header("Content-Type", "application/json")
70 .header("Authorization", format!("Bearer {}", self.refresh_jwt.lock().await))
71 .header("atproto-accept-labelers", self.self_did.as_str())
72 .send()
73 .await.expect("Expected to be able to send request, but failed.");
74 let json = response.json::<serde_json::Value>().await.expect("Expected to be able to read response as JSON, but failed.");
75 if let Some(error) = json["error"].as_str() {
76 match error {
77 "InvalidRequest" => {
78 tracing::warn!("Invalid request");
79 return;
80 },
81 "ExpiredToken" => {
82 tracing::warn!("Token expired");
83 return;
84 },
85 "AccountDeactivated" => {
86 tracing::warn!("Account deactivated");
87 return;
88 },
89 "AccountTakedown" => {
90 tracing::warn!("Account has been suspended (Takedown)");
91 return;
92 },
93 _ => {
94 tracing::warn!("Unknown error from HTTP response: {:?}", json);
95 return;
96 }
97 }
98 }
99 *self.refresh_jwt.lock().await = json["refreshJwt"].as_str().expect("Expected to be able to read refreshJwt as str, but failed.").to_owned();
100 *self.access_jwt.lock().await = json["accessJwt"].as_str().expect("Expected to be able to read accessJwt as str, but failed.").to_owned();
101 let new_env = format!(
102 "ACCESS_JWT={}\nREFRESH_JWT={}\n",
103 self.access_jwt.lock().await, self.refresh_jwt.lock().await
104 );
105 fs::write(".env", new_env).expect("Failed to write to .env");
106 tracing::info!("Token refreshed");
107 }
108 /// Get a JSON response from the atproto API. Used internal to this struct.
109 async fn get(
110 &self,
111 path: &str,
112 parameters: &[(&str, &str)],
113 api_endpoint: ApiEndpoint,
114 ) -> Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>> {
115 let response = self.client_get(path, parameters, &api_endpoint).await;
116 if response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS {
117 tracing::warn!("Rate limited, sleeping for 5 minutes");
118 tracing::warn!("We were working on {} with parameters {:?}", path, parameters);
119 tokio::time::sleep(std::time::Duration::from_secs(305)).await; // 5 minutes and 5 seconds
120 let response = self.client_get(path, parameters, &api_endpoint).await;
121 return Ok(response.json::<serde_json::Value>().await.expect("Expected to be able to read response as JSON, but failed."));
122 }
123 if response.status() == reqwest::StatusCode::BAD_REQUEST {
124 let json = &response.json::<serde_json::Value>().await.expect("Expected to be able to read response as JSON, but failed.");
125 match json["error"].as_str().expect("Expected to be able to read error as str, but failed.") {
126 "ExpiredToken" => {
127 self.client_refresh().await;
128 let response = self.client_get(path, parameters, &api_endpoint).await;
129 return Ok(response.json::<serde_json::Value>().await.expect("Expected to be able to read response as JSON, but failed."));
130 },
131 "AccountDeactivated" => {
132 tracing::warn!("Account deactivated");
133 return Err(Box::new(std::io::Error::new(
134 std::io::ErrorKind::Other,
135 "Account deactivated",
136 )));
137 },
138 "AccountTakedown" => {
139 tracing::warn!("Account has been suspended (Takedown)");
140 return Err(Box::new(std::io::Error::new(
141 std::io::ErrorKind::Other,
142 "Account deactivated",
143 )));
144 },
145 "InvalidRequest" => {
146 // Check if the message is "Profile not found"
147 if json["message"].as_str().expect("Expected to be able to read message as str, but failed.") == "Profile not found" {
148 tracing::warn!("Profile not found");
149 return Err(Box::new(std::io::Error::new(
150 std::io::ErrorKind::NotFound,
151 "Profile not found",
152 )));
153 }
154 tracing::warn!("Unknown invalid request: {:?}", json);
155 return Err(Box::new(std::io::Error::new(
156 std::io::ErrorKind::Other,
157 "Unknown invalid request",
158 )));
159 },
160 _ => {
161 tracing::warn!("Unknown error from HTTP response: {:?}", json);
162 return Err(Box::new(std::io::Error::new(
163 std::io::ErrorKind::Other,
164 "Unknown bad request",
165 )));
166 }
167 };
168 }
169 if response.status() != reqwest::StatusCode::OK {
170 return Err(Box::new(std::io::Error::new(
171 std::io::ErrorKind::Other,
172 "Unknown HTTP error",
173 )));
174 }
175 let json = response.json::<serde_json::Value>().await.expect("Expected to be able to read response as JSON, but failed.");
176 Ok(json)
177 }
178 /// Get a profile from the atproto API.
179 pub async fn get_profile(
180 &mut self,
181 profile_id: &str,
182 ) -> Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>> {
183 let path = "app.bsky.actor.getProfile";
184 let parameters = [("actor", profile_id)];
185 self.get(path, ¶meters, ApiEndpoint::Public).await
186 }
187 /// Get multiple profiles.
188 pub async fn get_profiles(
189 &mut self,
190 profile_ids: &[String],
191 ) -> Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>> {
192 let path = "app.bsky.actor.getProfiles";
193 let mut parameters = Vec::new();
194 for profile_id in profile_ids {
195 parameters.push(("actors", profile_id.as_str()));
196 }
197 self.get(path, parameters.as_slice(), ApiEndpoint::Authorized).await
198 }
199 /// Check if a list of profiles has a label from us.
200 pub async fn check_profiles(
201 &mut self,
202 profile_ids: &[(String, i64)],
203 ) -> Result<Vec<(bool, (String, i64))>, Box<dyn std::error::Error + Send + Sync>> {
204 let mut found_labels: Vec<(bool, (String, i64))> = Vec::new();
205 let profile_ids_uris = profile_ids.iter().map(|(profile_id, _)| profile_id.clone()).collect::<Vec<String>>();
206 let profile_ids_seqs = profile_ids.iter().map(|(_, seq)| seq).collect::<Vec<&i64>>();
207 let profiles = self.get_profiles(profile_ids_uris.as_slice()).await?;
208 let profiles_array = profiles["profiles"].as_array();
209 if profiles_array.is_none() {
210 tracing::warn!("No profiles json found for profiles: {:?}", profiles);
211 return Ok(vec![]);
212 }
213 for profile in profiles_array.unwrap_or_else(|| panic!("Expected to be able to read profiles as array, but failed. Profiles: {:?}", profiles)) {
214 let labels = &profile["labels"];
215 let mut found = false;
216 let label_array = labels.as_array();
217 if label_array.is_none() {
218 tracing::warn!("No labels json found for profile: {:?}", profile);
219 continue;
220 }
221 let did = profile["did"].as_str().expect("Expected to be able to read did as str, but failed.");
222 let seq = profile_ids_seqs[profile_ids_uris.iter().position(|x| x == did).expect("Expected to be able to find the index of the uri.")];
223 for label in label_array.unwrap_or_else(|| panic!("Expected to be able to read labels as array, but failed. Profile: {:?}", profile)) {
224 if label["src"].as_str().expect("Expected to be able to read src as str, but failed.") == self.self_did.as_str() {
225 found = true;
226 break;
227 }
228 }
229 found_labels.push((found, (did.to_owned(), *seq)));
230 }
231 Ok(found_labels)
232 }
233 /// After getting a profile, check the labels on it, and see if one from us ("src:") is there.
234 pub async fn check_profile(
235 &mut self,
236 profile_did: &str,
237 ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> {
238 let profile = self.get_profile(profile_did).await?;
239 let labels = &profile["labels"];
240 let label_array = labels.as_array();
241 if label_array.is_none() {
242 tracing::warn!("No labels json found for profile: {:?}", profile);
243 return Ok(false);
244 }
245 for label in label_array.unwrap_or_else(|| panic!("Expected to be able to read labels as array, but failed. Profile: {:?}", profile)) {
246 if label["src"].as_str().expect("Expected to be able to read src as str, but failed.") == self.self_did.as_str() {
247 return Ok(true);
248 }
249 }
250 Ok(false)
251 }
252 /// Get a label from the provided URL, then validate the signature.
253 pub async fn get_label_and_validate(
254 &self,
255 url: &str,
256 ) -> Result<(), Box<dyn std::error::Error>> {
257 tracing::debug!("Getting label from {}", url);
258 let response = reqwest::get(url).await.expect("Expected to be able to get response, but failed.");
259 tracing::debug!("Response: {:?}", response);
260 let response_json = response.json::<serde_json::Value>().await.expect("Expected to be able to read response as JSON, but failed.");
261 tracing::debug!("Response JSON: {:?}", response_json);
262 let sig = &response_json["labels"][0]["sig"];
263 tracing::debug!("Signature: {:?}", sig);
264 let retrieved_label = RetrievedLabelResponse {
265 // id: response_json["labels"][0]["id"].as_u64().unwrap(),
266 cts: response_json["labels"][0]["cts"].as_str().expect("Expected to be able to read cts as str, but failed.").to_owned(),
267 neg: response_json["labels"][0]["neg"].as_str() == Some("true"),
268 src: response_json["labels"][0]["src"].as_str().expect("Expected to be able to read src as str, but failed.").to_owned().parse().expect("Failed to parse DID"),
269 uri: response_json["labels"][0]["uri"].as_str().expect("Expected to be able to read uri as str, but failed.").to_owned().parse().expect("Failed to parse URI"),
270 val: response_json["labels"][0]["val"].as_str().expect("Expected to be able to read val as str, but failed.").to_owned(),
271 ver: response_json["labels"][0]["ver"].as_u64().expect("Expected to be able to read ver as u64, but failed."),
272 };
273 let crypto = crate::crypto::Crypto::new();
274 let pub_key = "zQ3shreqyXEdouQeEQSFKfoSEN5eig74BXuqQyTaiE9uzADqZ";
275 let sig_string = sig["$bytes"].as_str().expect("Expected to be able to read sig as str, but failed.");
276 if crypto.validate(retrieved_label, sig_string, pub_key) {
277 tracing::info!("Valid signature");
278 Ok(())
279 } else {
280 tracing::info!("Invalid signature");
281 Err(Box::new(std::io::Error::new(
282 std::io::ErrorKind::Other,
283 "Invalid signature",
284 )))
285 }
286 }
287 /// Get a label from a websocket URL, then validate the signature.
288 /// Similar to what's done in webserve.rs, but in reverse, we'll need to decode the message.
289 pub async fn get_label_and_validate_ws(
290 &self,
291 // url: &str,
292 ) -> Result<(), Box<dyn std::error::Error>> {
293 // For now, use this mock response, represented in base64:
294 let response = "omF0ZyNsYWJlbHNib3ABomNzZXEYG2ZsYWJlbHOBp2NjdHN4GzIwMjUtMDItMDlUMDM6MjU6MjcuOTI4MDIzWmNuZWf0Y3NpZ1hAXLIRXAG5mF5bCWWCwEhbYvC8YYVP9fWwbVVL6IBXXlIrZ6sr6MQ4DfNdpGhwRWawA4Mq44HlEDsJ7OvcGsDCDWNzcmN4IGRpZDpwbGM6bTZhZHB0bjYyZGNhaGZhcTM0dGNlM2o1Y3VyaXggZGlkOnBsYzptNmFkcHRuNjJkY2FoZmFxMzR0Y2UzajVjdmFsbmpvaW5lZC0yMDI1LTAyY3ZlcgE=";
295 tracing::debug!("Response: {:?}", response);
296 let response_bytes = base64::engine::GeneralPurpose::new(
297 &base64::alphabet::STANDARD,
298 base64::engine::general_purpose::PAD).decode(response).expect("Expected to be able to decode base64 response.");
299 tracing::debug!("Response bytes: {:?}", response_bytes);
300 let reponse_bytes_in_hex = hex::encode(&response_bytes);
301 tracing::debug!("Response bytes in hex: {:?}", reponse_bytes_in_hex);
302 let response_0 = &response_bytes[0..response_bytes.iter().position(|&r| r == 0x01).expect("Expected to find 0x01 in response bytes.")];
303 let response_1 = &response_bytes[response_bytes.iter().position(|&r| r == 0x01).expect("Expected to find 0x01 in response bytes.") + 1..];
304 tracing::debug!("Response 0: {:?}", hex::encode(response_0));
305 tracing::debug!("Response 1: {:?}", hex::encode(response_1));
306 let response_cbor: LabelsVecWithSeq = serde_cbor::from_slice(response_1).expect("Expected to be able to deserialize response 1 as LabelsVecWithSeq, but failed.");
307 tracing::debug!("Response CBOR: {:?}", response_cbor);
308 let unsigned_response = RetrievedLabelResponse {
309 cts: response_cbor.labels[0].cts.clone(),
310 neg: response_cbor.labels[0].neg,
311 src: response_cbor.labels[0].src.clone(),
312 uri: response_cbor.labels[0].uri.clone(),
313 val: response_cbor.labels[0].val.clone(),
314 ver: response_cbor.labels[0].ver,
315 };
316 let sig_base64 = SignatureBytes::from_bytes(response_cbor.labels[0].sig).as_base64();
317 tracing::debug!("Retrieved label: {:?}", response_cbor);
318 let crypto = crate::crypto::Crypto::new();
319 let public_key = "zQ3shreqyXEdouQeEQSFKfoSEN5eig74BXuqQyTaiE9uzADqZ";
320 if crypto.validate(unsigned_response, &sig_base64, public_key) {
321 tracing::info!("Valid signature");
322 Ok(())
323 } else {
324 tracing::info!("Invalid signature");
325 Err(Box::new(std::io::Error::new(
326 std::io::ErrorKind::Other,
327 "Invalid signature",
328 )))
329 }
330 }
331 /// getLikes
332 pub async fn get_likes(
333 &mut self,
334 uri: &str,
335 ) -> Result<Vec<serde_json::Value>, Box<dyn std::error::Error + Send + Sync>> {
336 let path = "app.bsky.feed.getLikes";
337 let parameters = [("uri", uri)];
338 self.get(path, ¶meters, ApiEndpoint::Public).await.map(|response| response["likes"].as_array().expect("Expected to be able to read likes as array, but failed.").to_owned())
339 }
340}