Personal ATProto tools.
at main 17 kB view raw
1//! Requests to the atproto API. 2use std::env; 3use std::fs; 4use std::sync::Arc; 5 6use base64::Engine; 7use reqwest::Client; 8use tokio::sync::Mutex; 9 10use crate::types::{LabelsVecWithSeq, RetrievedLabelResponse, SignatureBytes}; 11enum ApiEndpoint { 12 Authorized, 13 Public, 14} 15/// Agent for interactions with atproto. 16#[derive(Clone)] 17pub struct Agent { 18 /// The access JWT. 19 pub access_jwt: Arc<Mutex<String>>, 20 /// The refresh JWT. 21 pub refresh_jwt: Arc<Mutex<String>>, 22 /// The reqwest client. 23 pub client: Client, 24 /// The DID of the labeler. 25 pub self_did: Arc<String>, 26} 27impl Default for Agent { 28 fn default() -> Self { 29 drop(dotenvy::dotenv().expect("Failed to load .env file")); 30 Self { 31 access_jwt: Arc::new(Mutex::new(env::var("ACCESS_JWT").expect("ACCESS_JWT must be set"))), 32 refresh_jwt: Arc::new(Mutex::new(env::var("REFRESH_JWT").expect("REFRESH_JWT must be set"))), 33 client: Client::new(), 34 self_did: Arc::new(env::var("SELF_DID").expect("SELF_DID must be set")), 35 } 36 } 37} 38impl Agent { 39 /// The base URL of the atproto API's XRPC endpoint. 40 /// Rate limit: 3_000 per 5 minutes 41 const AUTH_URL: &'static str = "https://bsky.social/xrpc/"; 42 const PUBLIC_URL: &'static str = "https://public.api.bsky.app/xrpc/"; 43 async fn client_get( 44 &self, 45 path: &str, 46 parameters: &[(&str, &str)], 47 api_endpoint: &ApiEndpoint, 48 ) -> reqwest::Response { 49 self.client 50 .get(format!("{}{}", match api_endpoint { 51 ApiEndpoint::Authorized => Self::AUTH_URL, 52 ApiEndpoint::Public => Self::PUBLIC_URL, 53 }, &path)) 54 .header("Content-Type", "application/json") 55 .header("Authorization", format!("Bearer {}", self.access_jwt.lock().await)) 56 .header("atproto-accept-labelers", self.self_did.as_str()) 57 .query(parameters) 58 .send() 59 .await.expect("Expected to be able to send request, but failed.") 60 } 61 async fn client_refresh(&self) { 62 tracing::warn!("Token expired, refreshing"); 63 let response = self.client 64 .post(format!( 65 "{}{}", 66 Self::AUTH_URL, 67 "com.atproto.server.refreshSession" 68 )) 69 .header("Content-Type", "application/json") 70 .header("Authorization", format!("Bearer {}", self.refresh_jwt.lock().await)) 71 .header("atproto-accept-labelers", self.self_did.as_str()) 72 .send() 73 .await.expect("Expected to be able to send request, but failed."); 74 let json = response.json::<serde_json::Value>().await.expect("Expected to be able to read response as JSON, but failed."); 75 if let Some(error) = json["error"].as_str() { 76 match error { 77 "InvalidRequest" => { 78 tracing::warn!("Invalid request"); 79 return; 80 }, 81 "ExpiredToken" => { 82 tracing::warn!("Token expired"); 83 return; 84 }, 85 "AccountDeactivated" => { 86 tracing::warn!("Account deactivated"); 87 return; 88 }, 89 "AccountTakedown" => { 90 tracing::warn!("Account has been suspended (Takedown)"); 91 return; 92 }, 93 _ => { 94 tracing::warn!("Unknown error from HTTP response: {:?}", json); 95 return; 96 } 97 } 98 } 99 *self.refresh_jwt.lock().await = json["refreshJwt"].as_str().expect("Expected to be able to read refreshJwt as str, but failed.").to_owned(); 100 *self.access_jwt.lock().await = json["accessJwt"].as_str().expect("Expected to be able to read accessJwt as str, but failed.").to_owned(); 101 let new_env = format!( 102 "ACCESS_JWT={}\nREFRESH_JWT={}\n", 103 self.access_jwt.lock().await, self.refresh_jwt.lock().await 104 ); 105 fs::write(".env", new_env).expect("Failed to write to .env"); 106 tracing::info!("Token refreshed"); 107 } 108 /// Get a JSON response from the atproto API. Used internal to this struct. 109 async fn get( 110 &self, 111 path: &str, 112 parameters: &[(&str, &str)], 113 api_endpoint: ApiEndpoint, 114 ) -> Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>> { 115 let response = self.client_get(path, parameters, &api_endpoint).await; 116 if response.status() == reqwest::StatusCode::TOO_MANY_REQUESTS { 117 tracing::warn!("Rate limited, sleeping for 5 minutes"); 118 tracing::warn!("We were working on {} with parameters {:?}", path, parameters); 119 tokio::time::sleep(std::time::Duration::from_secs(305)).await; // 5 minutes and 5 seconds 120 let response = self.client_get(path, parameters, &api_endpoint).await; 121 return Ok(response.json::<serde_json::Value>().await.expect("Expected to be able to read response as JSON, but failed.")); 122 } 123 if response.status() == reqwest::StatusCode::BAD_REQUEST { 124 let json = &response.json::<serde_json::Value>().await.expect("Expected to be able to read response as JSON, but failed."); 125 match json["error"].as_str().expect("Expected to be able to read error as str, but failed.") { 126 "ExpiredToken" => { 127 self.client_refresh().await; 128 let response = self.client_get(path, parameters, &api_endpoint).await; 129 return Ok(response.json::<serde_json::Value>().await.expect("Expected to be able to read response as JSON, but failed.")); 130 }, 131 "AccountDeactivated" => { 132 tracing::warn!("Account deactivated"); 133 return Err(Box::new(std::io::Error::new( 134 std::io::ErrorKind::Other, 135 "Account deactivated", 136 ))); 137 }, 138 "AccountTakedown" => { 139 tracing::warn!("Account has been suspended (Takedown)"); 140 return Err(Box::new(std::io::Error::new( 141 std::io::ErrorKind::Other, 142 "Account deactivated", 143 ))); 144 }, 145 "InvalidRequest" => { 146 // Check if the message is "Profile not found" 147 if json["message"].as_str().expect("Expected to be able to read message as str, but failed.") == "Profile not found" { 148 tracing::warn!("Profile not found"); 149 return Err(Box::new(std::io::Error::new( 150 std::io::ErrorKind::NotFound, 151 "Profile not found", 152 ))); 153 } 154 tracing::warn!("Unknown invalid request: {:?}", json); 155 return Err(Box::new(std::io::Error::new( 156 std::io::ErrorKind::Other, 157 "Unknown invalid request", 158 ))); 159 }, 160 _ => { 161 tracing::warn!("Unknown error from HTTP response: {:?}", json); 162 return Err(Box::new(std::io::Error::new( 163 std::io::ErrorKind::Other, 164 "Unknown bad request", 165 ))); 166 } 167 }; 168 } 169 if response.status() != reqwest::StatusCode::OK { 170 return Err(Box::new(std::io::Error::new( 171 std::io::ErrorKind::Other, 172 "Unknown HTTP error", 173 ))); 174 } 175 let json = response.json::<serde_json::Value>().await.expect("Expected to be able to read response as JSON, but failed."); 176 Ok(json) 177 } 178 /// Get a profile from the atproto API. 179 pub async fn get_profile( 180 &mut self, 181 profile_id: &str, 182 ) -> Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>> { 183 let path = "app.bsky.actor.getProfile"; 184 let parameters = [("actor", profile_id)]; 185 self.get(path, &parameters, ApiEndpoint::Public).await 186 } 187 /// Get multiple profiles. 188 pub async fn get_profiles( 189 &mut self, 190 profile_ids: &[String], 191 ) -> Result<serde_json::Value, Box<dyn std::error::Error + Send + Sync>> { 192 let path = "app.bsky.actor.getProfiles"; 193 let mut parameters = Vec::new(); 194 for profile_id in profile_ids { 195 parameters.push(("actors", profile_id.as_str())); 196 } 197 self.get(path, parameters.as_slice(), ApiEndpoint::Authorized).await 198 } 199 /// Check if a list of profiles has a label from us. 200 pub async fn check_profiles( 201 &mut self, 202 profile_ids: &[(String, i64)], 203 ) -> Result<Vec<(bool, (String, i64))>, Box<dyn std::error::Error + Send + Sync>> { 204 let mut found_labels: Vec<(bool, (String, i64))> = Vec::new(); 205 let profile_ids_uris = profile_ids.iter().map(|(profile_id, _)| profile_id.clone()).collect::<Vec<String>>(); 206 let profile_ids_seqs = profile_ids.iter().map(|(_, seq)| seq).collect::<Vec<&i64>>(); 207 let profiles = self.get_profiles(profile_ids_uris.as_slice()).await?; 208 let profiles_array = profiles["profiles"].as_array(); 209 if profiles_array.is_none() { 210 tracing::warn!("No profiles json found for profiles: {:?}", profiles); 211 return Ok(vec![]); 212 } 213 for profile in profiles_array.unwrap_or_else(|| panic!("Expected to be able to read profiles as array, but failed. Profiles: {:?}", profiles)) { 214 let labels = &profile["labels"]; 215 let mut found = false; 216 let label_array = labels.as_array(); 217 if label_array.is_none() { 218 tracing::warn!("No labels json found for profile: {:?}", profile); 219 continue; 220 } 221 let did = profile["did"].as_str().expect("Expected to be able to read did as str, but failed."); 222 let seq = profile_ids_seqs[profile_ids_uris.iter().position(|x| x == did).expect("Expected to be able to find the index of the uri.")]; 223 for label in label_array.unwrap_or_else(|| panic!("Expected to be able to read labels as array, but failed. Profile: {:?}", profile)) { 224 if label["src"].as_str().expect("Expected to be able to read src as str, but failed.") == self.self_did.as_str() { 225 found = true; 226 break; 227 } 228 } 229 found_labels.push((found, (did.to_owned(), *seq))); 230 } 231 Ok(found_labels) 232 } 233 /// After getting a profile, check the labels on it, and see if one from us ("src:") is there. 234 pub async fn check_profile( 235 &mut self, 236 profile_did: &str, 237 ) -> Result<bool, Box<dyn std::error::Error + Send + Sync>> { 238 let profile = self.get_profile(profile_did).await?; 239 let labels = &profile["labels"]; 240 let label_array = labels.as_array(); 241 if label_array.is_none() { 242 tracing::warn!("No labels json found for profile: {:?}", profile); 243 return Ok(false); 244 } 245 for label in label_array.unwrap_or_else(|| panic!("Expected to be able to read labels as array, but failed. Profile: {:?}", profile)) { 246 if label["src"].as_str().expect("Expected to be able to read src as str, but failed.") == self.self_did.as_str() { 247 return Ok(true); 248 } 249 } 250 Ok(false) 251 } 252 /// Get a label from the provided URL, then validate the signature. 253 pub async fn get_label_and_validate( 254 &self, 255 url: &str, 256 ) -> Result<(), Box<dyn std::error::Error>> { 257 tracing::debug!("Getting label from {}", url); 258 let response = reqwest::get(url).await.expect("Expected to be able to get response, but failed."); 259 tracing::debug!("Response: {:?}", response); 260 let response_json = response.json::<serde_json::Value>().await.expect("Expected to be able to read response as JSON, but failed."); 261 tracing::debug!("Response JSON: {:?}", response_json); 262 let sig = &response_json["labels"][0]["sig"]; 263 tracing::debug!("Signature: {:?}", sig); 264 let retrieved_label = RetrievedLabelResponse { 265 // id: response_json["labels"][0]["id"].as_u64().unwrap(), 266 cts: response_json["labels"][0]["cts"].as_str().expect("Expected to be able to read cts as str, but failed.").to_owned(), 267 neg: response_json["labels"][0]["neg"].as_str() == Some("true"), 268 src: response_json["labels"][0]["src"].as_str().expect("Expected to be able to read src as str, but failed.").to_owned().parse().expect("Failed to parse DID"), 269 uri: response_json["labels"][0]["uri"].as_str().expect("Expected to be able to read uri as str, but failed.").to_owned().parse().expect("Failed to parse URI"), 270 val: response_json["labels"][0]["val"].as_str().expect("Expected to be able to read val as str, but failed.").to_owned(), 271 ver: response_json["labels"][0]["ver"].as_u64().expect("Expected to be able to read ver as u64, but failed."), 272 }; 273 let crypto = crate::crypto::Crypto::new(); 274 let pub_key = "zQ3shreqyXEdouQeEQSFKfoSEN5eig74BXuqQyTaiE9uzADqZ"; 275 let sig_string = sig["$bytes"].as_str().expect("Expected to be able to read sig as str, but failed."); 276 if crypto.validate(retrieved_label, sig_string, pub_key) { 277 tracing::info!("Valid signature"); 278 Ok(()) 279 } else { 280 tracing::info!("Invalid signature"); 281 Err(Box::new(std::io::Error::new( 282 std::io::ErrorKind::Other, 283 "Invalid signature", 284 ))) 285 } 286 } 287 /// Get a label from a websocket URL, then validate the signature. 288 /// Similar to what's done in webserve.rs, but in reverse, we'll need to decode the message. 289 pub async fn get_label_and_validate_ws( 290 &self, 291 // url: &str, 292 ) -> Result<(), Box<dyn std::error::Error>> { 293 // For now, use this mock response, represented in base64: 294 let response = "omF0ZyNsYWJlbHNib3ABomNzZXEYG2ZsYWJlbHOBp2NjdHN4GzIwMjUtMDItMDlUMDM6MjU6MjcuOTI4MDIzWmNuZWf0Y3NpZ1hAXLIRXAG5mF5bCWWCwEhbYvC8YYVP9fWwbVVL6IBXXlIrZ6sr6MQ4DfNdpGhwRWawA4Mq44HlEDsJ7OvcGsDCDWNzcmN4IGRpZDpwbGM6bTZhZHB0bjYyZGNhaGZhcTM0dGNlM2o1Y3VyaXggZGlkOnBsYzptNmFkcHRuNjJkY2FoZmFxMzR0Y2UzajVjdmFsbmpvaW5lZC0yMDI1LTAyY3ZlcgE="; 295 tracing::debug!("Response: {:?}", response); 296 let response_bytes = base64::engine::GeneralPurpose::new( 297 &base64::alphabet::STANDARD, 298 base64::engine::general_purpose::PAD).decode(response).expect("Expected to be able to decode base64 response."); 299 tracing::debug!("Response bytes: {:?}", response_bytes); 300 let reponse_bytes_in_hex = hex::encode(&response_bytes); 301 tracing::debug!("Response bytes in hex: {:?}", reponse_bytes_in_hex); 302 let response_0 = &response_bytes[0..response_bytes.iter().position(|&r| r == 0x01).expect("Expected to find 0x01 in response bytes.")]; 303 let response_1 = &response_bytes[response_bytes.iter().position(|&r| r == 0x01).expect("Expected to find 0x01 in response bytes.") + 1..]; 304 tracing::debug!("Response 0: {:?}", hex::encode(response_0)); 305 tracing::debug!("Response 1: {:?}", hex::encode(response_1)); 306 let response_cbor: LabelsVecWithSeq = serde_cbor::from_slice(response_1).expect("Expected to be able to deserialize response 1 as LabelsVecWithSeq, but failed."); 307 tracing::debug!("Response CBOR: {:?}", response_cbor); 308 let unsigned_response = RetrievedLabelResponse { 309 cts: response_cbor.labels[0].cts.clone(), 310 neg: response_cbor.labels[0].neg, 311 src: response_cbor.labels[0].src.clone(), 312 uri: response_cbor.labels[0].uri.clone(), 313 val: response_cbor.labels[0].val.clone(), 314 ver: response_cbor.labels[0].ver, 315 }; 316 let sig_base64 = SignatureBytes::from_bytes(response_cbor.labels[0].sig).as_base64(); 317 tracing::debug!("Retrieved label: {:?}", response_cbor); 318 let crypto = crate::crypto::Crypto::new(); 319 let public_key = "zQ3shreqyXEdouQeEQSFKfoSEN5eig74BXuqQyTaiE9uzADqZ"; 320 if crypto.validate(unsigned_response, &sig_base64, public_key) { 321 tracing::info!("Valid signature"); 322 Ok(()) 323 } else { 324 tracing::info!("Invalid signature"); 325 Err(Box::new(std::io::Error::new( 326 std::io::ErrorKind::Other, 327 "Invalid signature", 328 ))) 329 } 330 } 331 /// getLikes 332 pub async fn get_likes( 333 &mut self, 334 uri: &str, 335 ) -> Result<Vec<serde_json::Value>, Box<dyn std::error::Error + Send + Sync>> { 336 let path = "app.bsky.feed.getLikes"; 337 let parameters = [("uri", uri)]; 338 self.get(path, &parameters, ApiEndpoint::Public).await.map(|response| response["likes"].as_array().expect("Expected to be able to read likes as array, but failed.").to_owned()) 339 } 340}