A library for ATProtocol identities.

feature: atproto-tap crate and tools

+64
Cargo.lock
··· 322 322 ] 323 323 324 324 [[package]] 325 + name = "atproto-tap" 326 + version = "0.13.0" 327 + dependencies = [ 328 + "atproto-client", 329 + "atproto-identity", 330 + "base64", 331 + "clap", 332 + "compact_str", 333 + "futures", 334 + "http", 335 + "itoa", 336 + "reqwest", 337 + "serde", 338 + "serde_json", 339 + "thiserror 2.0.17", 340 + "tokio", 341 + "tokio-stream", 342 + "tokio-websockets", 343 + "tracing", 344 + "tracing-subscriber", 345 + ] 346 + 347 + [[package]] 325 348 name = "atproto-xrpcs" 326 349 version = "0.13.0" 327 350 dependencies = [ ··· 506 529 checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" 507 530 508 531 [[package]] 532 + name = "castaway" 533 + version = "0.2.4" 534 + source = "registry+https://github.com/rust-lang/crates.io-index" 535 + checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" 536 + dependencies = [ 537 + "rustversion", 538 + ] 539 + 540 + [[package]] 509 541 name = "cbor4ii" 510 542 version = "0.2.14" 511 543 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 607 639 version = "1.0.4" 608 640 source = "registry+https://github.com/rust-lang/crates.io-index" 609 641 checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" 642 + 643 + [[package]] 644 + name = "compact_str" 645 + version = "0.8.1" 646 + source = "registry+https://github.com/rust-lang/crates.io-index" 647 + checksum = "3b79c4069c6cad78e2e0cdfcbd26275770669fb39fd308a752dc110e83b9af32" 648 + dependencies = [ 649 + "castaway", 650 + "cfg-if", 651 + "itoa", 652 + "rustversion", 653 + "ryu", 654 + "serde", 655 + "static_assertions", 656 + ] 610 657 611 658 [[package]] 612 659 name = "const-oid" ··· 2385 2432 checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" 2386 2433 2387 2434 [[package]] 2435 + name = "static_assertions" 2436 + version = "1.1.0" 2437 + source = "registry+https://github.com/rust-lang/crates.io-index" 2438 + checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" 2439 + 2440 + [[package]] 2388 2441 name = "strsim" 2389 2442 version = "0.11.1" 2390 2443 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 2574 2627 checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" 2575 2628 dependencies = [ 2576 2629 "rustls", 2630 + "tokio", 2631 + ] 2632 + 2633 + [[package]] 2634 + name = "tokio-stream" 2635 + version = "0.1.17" 2636 + source = "registry+https://github.com/rust-lang/crates.io-index" 2637 + checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" 2638 + dependencies = [ 2639 + "futures-core", 2640 + "pin-project-lite", 2577 2641 "tokio", 2578 2642 ] 2579 2643
+2
Cargo.toml
··· 8 8 "crates/atproto-oauth-axum", 9 9 "crates/atproto-oauth", 10 10 "crates/atproto-record", 11 + "crates/atproto-tap", 11 12 "crates/atproto-xrpcs-helloworld", 12 13 "crates/atproto-xrpcs", 13 14 "crates/atproto-lexicon", ··· 34 35 atproto-oauth-aip = { version = "0.13.0", path = "crates/atproto-oauth-aip" } 35 36 atproto-oauth-axum = { version = "0.13.0", path = "crates/atproto-oauth-axum" } 36 37 atproto-record = { version = "0.13.0", path = "crates/atproto-record" } 38 + atproto-tap = { version = "0.13.0", path = "crates/atproto-tap" } 37 39 atproto-xrpcs = { version = "0.13.0", path = "crates/atproto-xrpcs" } 38 40 39 41 ammonia = "4.0"
-2
crates/atproto-jetstream/src/consumer.rs
··· 383 383 break; 384 384 }, 385 385 () = &mut sleeper => { 386 - // consumer_control_insert(&self.pool, &self.config.jetstream_hostname, time_usec).await?; 387 - 388 386 sleeper.as_mut().reset(Instant::now() + interval); 389 387 }, 390 388 item = client.next() => {
+53
crates/atproto-tap/Cargo.toml
··· 1 + [package] 2 + name = "atproto-tap" 3 + version = "0.13.0" 4 + description = "AT Protocol TAP (Trusted Attestation Protocol) service consumer" 5 + readme = "README.md" 6 + homepage = "https://tangled.sh/@smokesignal.events/atproto-identity-rs" 7 + documentation = "https://docs.rs/atproto-tap" 8 + 9 + edition.workspace = true 10 + rust-version.workspace = true 11 + authors.workspace = true 12 + repository.workspace = true 13 + license.workspace = true 14 + keywords.workspace = true 15 + categories.workspace = true 16 + 17 + [dependencies] 18 + tokio = { workspace = true, features = ["sync", "time"] } 19 + tokio-stream = "0.1" 20 + tokio-websockets = { workspace = true } 21 + futures = { workspace = true } 22 + reqwest = { workspace = true } 23 + serde = { workspace = true } 24 + serde_json = { workspace = true } 25 + thiserror = { workspace = true } 26 + tracing = { workspace = true } 27 + http = { workspace = true } 28 + base64 = { workspace = true } 29 + atproto-identity.workspace = true 30 + atproto-client = { workspace = true, optional = true } 31 + 32 + # Memory efficiency 33 + compact_str = { version = "0.8", features = ["serde"] } 34 + itoa = "1.0" 35 + 36 + # Optional for CLI 37 + clap = { workspace = true, optional = true } 38 + tracing-subscriber = { version = "0.3", features = ["env-filter"], optional = true } 39 + 40 + [features] 41 + default = [] 42 + cli = ["dep:clap", "dep:tracing-subscriber", "dep:atproto-client", "tokio/rt-multi-thread", "tokio/macros", "tokio/signal"] 43 + 44 + [[bin]] 45 + name = "atproto-tap-client" 46 + required-features = ["cli"] 47 + 48 + [[bin]] 49 + name = "atproto-tap-extras" 50 + required-features = ["cli"] 51 + 52 + [lints] 53 + workspace = true
+351
crates/atproto-tap/src/bin/atproto-tap-client.rs
··· 1 + //! Command-line client for TAP services. 2 + //! 3 + //! This tool provides commands for consuming TAP events and managing tracked repositories. 4 + //! 5 + //! # Usage 6 + //! 7 + //! ```bash 8 + //! # Stream events from a TAP service 9 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 read 10 + //! 11 + //! # Stream with authentication and filters 12 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 -p secret read --live-only 13 + //! 14 + //! # Add repositories to track 15 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 -p secret repos add did:plc:xyz did:plc:abc 16 + //! 17 + //! # Remove repositories from tracking 18 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 -p secret repos remove did:plc:xyz 19 + //! 20 + //! # Resolve a DID to its DID document 21 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 resolve did:plc:xyz 22 + //! 23 + //! # Resolve a DID and only output the handle 24 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 resolve did:plc:xyz --handle-only 25 + //! 26 + //! # Get repository tracking info 27 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 info did:plc:xyz 28 + //! ``` 29 + 30 + use atproto_tap::{TapClient, TapConfig, TapEvent, connect}; 31 + use clap::{Parser, Subcommand}; 32 + use std::time::Duration; 33 + use tokio_stream::StreamExt; 34 + 35 + /// TAP service client for consuming events and managing repositories. 36 + #[derive(Parser)] 37 + #[command( 38 + name = "atproto-tap-client", 39 + version, 40 + about = "TAP service client for AT Protocol", 41 + long_about = "Connect to a TAP service to stream repository/identity events or manage tracked repositories.\n\n\ 42 + Events are printed to stdout as JSON, one per line.\n\ 43 + Use Ctrl+C to gracefully stop the consumer." 44 + )] 45 + struct Args { 46 + /// TAP service hostname (e.g., localhost:2480) 47 + hostname: String, 48 + 49 + /// Admin password for authentication 50 + #[arg(short, long, global = true)] 51 + password: Option<String>, 52 + 53 + #[command(subcommand)] 54 + command: Command, 55 + } 56 + 57 + #[derive(Subcommand)] 58 + enum Command { 59 + /// Connect to TAP and stream events as JSON 60 + Read { 61 + /// Disable acknowledgments 62 + #[arg(long)] 63 + no_acks: bool, 64 + 65 + /// Maximum reconnection attempts (0 = unlimited) 66 + #[arg(long, default_value = "0")] 67 + max_reconnects: u32, 68 + 69 + /// Print debug information to stderr 70 + #[arg(short, long)] 71 + debug: bool, 72 + 73 + /// Filter to specific collections (comma-separated) 74 + #[arg(long)] 75 + collections: Option<String>, 76 + 77 + /// Only show live events (skip backfill) 78 + #[arg(long)] 79 + live_only: bool, 80 + }, 81 + 82 + /// Manage tracked repositories 83 + Repos { 84 + #[command(subcommand)] 85 + action: ReposAction, 86 + }, 87 + 88 + /// Resolve a DID to its DID document 89 + Resolve { 90 + /// DID to resolve (e.g., did:plc:xyz123) 91 + did: String, 92 + 93 + /// Only output the handle (instead of full DID document) 94 + #[arg(long)] 95 + handle_only: bool, 96 + }, 97 + 98 + /// Get tracking info for a repository 99 + Info { 100 + /// DID to get info for (e.g., did:plc:xyz123) 101 + did: String, 102 + }, 103 + } 104 + 105 + #[derive(Subcommand)] 106 + enum ReposAction { 107 + /// Add repositories to track 108 + Add { 109 + /// DIDs to add (e.g., did:plc:xyz123) 110 + #[arg(required = true)] 111 + dids: Vec<String>, 112 + }, 113 + 114 + /// Remove repositories from tracking 115 + Remove { 116 + /// DIDs to remove 117 + #[arg(required = true)] 118 + dids: Vec<String>, 119 + }, 120 + } 121 + 122 + #[tokio::main] 123 + async fn main() { 124 + let args = Args::parse(); 125 + 126 + match args.command { 127 + Command::Read { 128 + no_acks, 129 + max_reconnects, 130 + debug, 131 + collections, 132 + live_only, 133 + } => { 134 + run_read( 135 + &args.hostname, 136 + args.password, 137 + no_acks, 138 + max_reconnects, 139 + debug, 140 + collections, 141 + live_only, 142 + ) 143 + .await; 144 + } 145 + Command::Repos { action } => { 146 + run_repos(&args.hostname, args.password, action).await; 147 + } 148 + Command::Resolve { did, handle_only } => { 149 + run_resolve(&args.hostname, args.password, &did, handle_only).await; 150 + } 151 + Command::Info { did } => { 152 + run_info(&args.hostname, args.password, &did).await; 153 + } 154 + } 155 + } 156 + 157 + async fn run_read( 158 + hostname: &str, 159 + password: Option<String>, 160 + no_acks: bool, 161 + max_reconnects: u32, 162 + debug: bool, 163 + collections: Option<String>, 164 + live_only: bool, 165 + ) { 166 + // Initialize tracing if debug mode 167 + if debug { 168 + tracing_subscriber::fmt() 169 + .with_env_filter("atproto_tap=debug") 170 + .with_writer(std::io::stderr) 171 + .init(); 172 + } 173 + 174 + // Build configuration 175 + let mut config_builder = TapConfig::builder() 176 + .hostname(hostname) 177 + .send_acks(!no_acks); 178 + 179 + if let Some(password) = password { 180 + config_builder = config_builder.admin_password(password); 181 + } 182 + 183 + if max_reconnects > 0 { 184 + config_builder = config_builder.max_reconnect_attempts(Some(max_reconnects)); 185 + } 186 + 187 + // Set reasonable defaults for CLI usage 188 + config_builder = config_builder 189 + .initial_reconnect_delay(Duration::from_secs(1)) 190 + .max_reconnect_delay(Duration::from_secs(30)); 191 + 192 + let config = config_builder.build(); 193 + 194 + eprintln!("Connecting to TAP service at {}...", hostname); 195 + 196 + let mut stream = connect(config); 197 + 198 + // Parse collection filters 199 + let collection_filters: Vec<String> = collections 200 + .map(|c| c.split(',').map(|s| s.trim().to_string()).collect()) 201 + .unwrap_or_default(); 202 + 203 + // Handle Ctrl+C 204 + let ctrl_c = tokio::signal::ctrl_c(); 205 + tokio::pin!(ctrl_c); 206 + 207 + loop { 208 + tokio::select! { 209 + Some(result) = stream.next() => { 210 + match result { 211 + Ok(event) => { 212 + // Apply filters 213 + let should_print = match event.as_ref() { 214 + TapEvent::Record { record, .. } => { 215 + // Filter by live flag 216 + if live_only && !record.live { 217 + false 218 + } 219 + // Filter by collection 220 + else if !collection_filters.is_empty() { 221 + collection_filters.iter().any(|c| record.collection.as_ref() == c) 222 + } else { 223 + true 224 + } 225 + } 226 + TapEvent::Identity { .. } => !live_only, // Always show identity unless live_only 227 + }; 228 + 229 + if should_print { 230 + // Print as JSON to stdout 231 + match serde_json::to_string(event.as_ref()) { 232 + Ok(json) => println!("{}", json), 233 + Err(e) => { 234 + eprintln!("Failed to serialize event: {}", e); 235 + } 236 + } 237 + } 238 + } 239 + Err(e) => { 240 + eprintln!("Error: {}", e); 241 + 242 + // Exit on fatal errors 243 + if e.is_fatal() { 244 + eprintln!("Fatal error, exiting"); 245 + std::process::exit(1); 246 + } 247 + } 248 + } 249 + } 250 + _ = &mut ctrl_c => { 251 + eprintln!("\nReceived Ctrl+C, shutting down..."); 252 + stream.close().await; 253 + break; 254 + } 255 + } 256 + } 257 + 258 + eprintln!("Client stopped"); 259 + } 260 + 261 + async fn run_repos(hostname: &str, password: Option<String>, action: ReposAction) { 262 + let client = TapClient::new(hostname, password); 263 + 264 + match action { 265 + ReposAction::Add { dids } => { 266 + let did_refs: Vec<&str> = dids.iter().map(|s| s.as_str()).collect(); 267 + 268 + match client.add_repos(&did_refs).await { 269 + Ok(()) => { 270 + eprintln!("Added {} repository(ies) to tracking", dids.len()); 271 + for did in &dids { 272 + println!("{}", did); 273 + } 274 + } 275 + Err(e) => { 276 + eprintln!("Failed to add repositories: {}", e); 277 + std::process::exit(1); 278 + } 279 + } 280 + } 281 + ReposAction::Remove { dids } => { 282 + let did_refs: Vec<&str> = dids.iter().map(|s| s.as_str()).collect(); 283 + 284 + match client.remove_repos(&did_refs).await { 285 + Ok(()) => { 286 + eprintln!("Removed {} repository(ies) from tracking", dids.len()); 287 + for did in &dids { 288 + println!("{}", did); 289 + } 290 + } 291 + Err(e) => { 292 + eprintln!("Failed to remove repositories: {}", e); 293 + std::process::exit(1); 294 + } 295 + } 296 + } 297 + } 298 + } 299 + 300 + async fn run_resolve(hostname: &str, password: Option<String>, did: &str, handle_only: bool) { 301 + let client = TapClient::new(hostname, password); 302 + 303 + match client.resolve(did).await { 304 + Ok(doc) => { 305 + if handle_only { 306 + // Use the handles() method from atproto_identity::model::Document 307 + match doc.handles() { 308 + Some(handle) => println!("{}", handle), 309 + None => { 310 + eprintln!("No handle found in DID document"); 311 + std::process::exit(1); 312 + } 313 + } 314 + } else { 315 + // Print full DID document as JSON 316 + match serde_json::to_string_pretty(&doc) { 317 + Ok(json) => println!("{}", json), 318 + Err(e) => { 319 + eprintln!("Failed to serialize DID document: {}", e); 320 + std::process::exit(1); 321 + } 322 + } 323 + } 324 + } 325 + Err(e) => { 326 + eprintln!("Failed to resolve DID: {}", e); 327 + std::process::exit(1); 328 + } 329 + } 330 + } 331 + 332 + async fn run_info(hostname: &str, password: Option<String>, did: &str) { 333 + let client = TapClient::new(hostname, password); 334 + 335 + match client.info(did).await { 336 + Ok(info) => { 337 + // Print as JSON for easy parsing 338 + match serde_json::to_string_pretty(&info) { 339 + Ok(json) => println!("{}", json), 340 + Err(e) => { 341 + eprintln!("Failed to serialize info: {}", e); 342 + std::process::exit(1); 343 + } 344 + } 345 + } 346 + Err(e) => { 347 + eprintln!("Failed to get repository info: {}", e); 348 + std::process::exit(1); 349 + } 350 + } 351 + }
+214
crates/atproto-tap/src/bin/atproto-tap-extras.rs
··· 1 + //! Additional TAP client utilities for AT Protocol. 2 + //! 3 + //! This tool provides extra commands for managing TAP tracked repositories 4 + //! based on social graph data. 5 + //! 6 + //! # Usage 7 + //! 8 + //! ```bash 9 + //! # Add all accounts followed by a DID to TAP tracking 10 + //! cargo run --features cli --bin atproto-tap-extras -- localhost:2480 repos-add-followers did:plc:xyz 11 + //! 12 + //! # With authentication 13 + //! cargo run --features cli --bin atproto-tap-extras -- localhost:2480 -p secret repos-add-followers did:plc:xyz 14 + //! ``` 15 + 16 + use atproto_client::client::Auth; 17 + use atproto_client::com::atproto::repo::{ListRecordsParams, list_records}; 18 + use atproto_identity::plc::query as plc_query; 19 + use atproto_tap::TapClient; 20 + use clap::{Parser, Subcommand}; 21 + use serde::Deserialize; 22 + 23 + /// TAP extras utility for managing tracked repositories. 24 + #[derive(Parser)] 25 + #[command( 26 + name = "atproto-tap-extras", 27 + version, 28 + about = "TAP extras utility for AT Protocol", 29 + long_about = "Additional utilities for managing TAP tracked repositories based on social graph data." 30 + )] 31 + struct Args { 32 + /// TAP service hostname (e.g., localhost:2480) 33 + hostname: String, 34 + 35 + /// Admin password for TAP authentication 36 + #[arg(short, long, global = true)] 37 + password: Option<String>, 38 + 39 + /// PLC directory hostname for DID resolution 40 + #[arg(long, default_value = "plc.directory", global = true)] 41 + plc_hostname: String, 42 + 43 + #[command(subcommand)] 44 + command: Command, 45 + } 46 + 47 + #[derive(Subcommand)] 48 + enum Command { 49 + /// Add accounts followed by a DID to TAP tracking. 50 + /// 51 + /// Fetches all app.bsky.graph.follow records from the specified DID's repository 52 + /// and adds the followed DIDs to TAP for tracking. 53 + ReposAddFollowers { 54 + /// DID to read followers from (e.g., did:plc:xyz123) 55 + did: String, 56 + 57 + /// Batch size for adding repos to TAP 58 + #[arg(long, default_value = "100")] 59 + batch_size: usize, 60 + 61 + /// Dry run - print DIDs without adding to TAP 62 + #[arg(long)] 63 + dry_run: bool, 64 + }, 65 + } 66 + 67 + /// Follow record structure from app.bsky.graph.follow. 68 + #[derive(Debug, Deserialize)] 69 + struct FollowRecord { 70 + /// The DID of the account being followed. 71 + subject: String, 72 + } 73 + 74 + #[tokio::main] 75 + async fn main() { 76 + let args = Args::parse(); 77 + 78 + match args.command { 79 + Command::ReposAddFollowers { 80 + did, 81 + batch_size, 82 + dry_run, 83 + } => { 84 + run_repos_add_followers( 85 + &args.hostname, 86 + args.password, 87 + &args.plc_hostname, 88 + &did, 89 + batch_size, 90 + dry_run, 91 + ) 92 + .await; 93 + } 94 + } 95 + } 96 + 97 + async fn run_repos_add_followers( 98 + tap_hostname: &str, 99 + tap_password: Option<String>, 100 + plc_hostname: &str, 101 + did: &str, 102 + batch_size: usize, 103 + dry_run: bool, 104 + ) { 105 + let http_client = reqwest::Client::new(); 106 + 107 + // Resolve the DID to get the PDS endpoint 108 + eprintln!("Resolving DID: {}", did); 109 + let document = match plc_query(&http_client, plc_hostname, did).await { 110 + Ok(doc) => doc, 111 + Err(e) => { 112 + eprintln!("Failed to resolve DID: {}", e); 113 + std::process::exit(1); 114 + } 115 + }; 116 + 117 + let pds_endpoints = document.pds_endpoints(); 118 + if pds_endpoints.is_empty() { 119 + eprintln!("No PDS endpoint found in DID document"); 120 + std::process::exit(1); 121 + } 122 + let pds_url = pds_endpoints[0]; 123 + eprintln!("Using PDS: {}", pds_url); 124 + 125 + // Collect all followed DIDs 126 + let mut followed_dids: Vec<String> = Vec::new(); 127 + let mut cursor: Option<String> = None; 128 + let collection = "app.bsky.graph.follow".to_string(); 129 + 130 + eprintln!("Fetching follow records..."); 131 + 132 + loop { 133 + let params = if let Some(c) = cursor.take() { 134 + ListRecordsParams::new().limit(100).cursor(c) 135 + } else { 136 + ListRecordsParams::new().limit(100) 137 + }; 138 + 139 + let response = match list_records::<FollowRecord>( 140 + &http_client, 141 + &Auth::None, 142 + pds_url, 143 + did.to_string(), 144 + collection.clone(), 145 + params, 146 + ) 147 + .await 148 + { 149 + Ok(resp) => resp, 150 + Err(e) => { 151 + eprintln!("Failed to list records: {}", e); 152 + std::process::exit(1); 153 + } 154 + }; 155 + 156 + for record in &response.records { 157 + followed_dids.push(record.value.subject.clone()); 158 + } 159 + 160 + eprintln!( 161 + " Fetched {} records (total: {})", 162 + response.records.len(), 163 + followed_dids.len() 164 + ); 165 + 166 + match response.cursor { 167 + Some(c) if !response.records.is_empty() => { 168 + cursor = Some(c); 169 + } 170 + _ => break, 171 + } 172 + } 173 + 174 + if followed_dids.is_empty() { 175 + eprintln!("No follow records found"); 176 + return; 177 + } 178 + 179 + eprintln!("Found {} followed accounts", followed_dids.len()); 180 + 181 + if dry_run { 182 + eprintln!("\nDry run - would add these DIDs to TAP:"); 183 + for did in &followed_dids { 184 + println!("{}", did); 185 + } 186 + return; 187 + } 188 + 189 + // Add to TAP in batches 190 + let tap_client = TapClient::new(tap_hostname, tap_password); 191 + let mut added = 0; 192 + 193 + for chunk in followed_dids.chunks(batch_size) { 194 + let did_refs: Vec<&str> = chunk.iter().map(|s| s.as_str()).collect(); 195 + 196 + match tap_client.add_repos(&did_refs).await { 197 + Ok(()) => { 198 + added += chunk.len(); 199 + eprintln!("Added {} DIDs to TAP (total: {})", chunk.len(), added); 200 + } 201 + Err(e) => { 202 + eprintln!("Failed to add repos to TAP: {}", e); 203 + std::process::exit(1); 204 + } 205 + } 206 + } 207 + 208 + eprintln!("Successfully added {} DIDs to TAP", added); 209 + 210 + // Print all added DIDs 211 + for did in &followed_dids { 212 + println!("{}", did); 213 + } 214 + }
+371
crates/atproto-tap/src/client.rs
··· 1 + //! HTTP client for TAP management API. 2 + //! 3 + //! This module provides [`TapClient`] for interacting with the TAP service's 4 + //! HTTP management endpoints, including adding/removing tracked repositories. 5 + 6 + use crate::errors::TapError; 7 + use atproto_identity::model::Document; 8 + use base64::Engine; 9 + use base64::engine::general_purpose::STANDARD as BASE64; 10 + use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue}; 11 + use serde::{Deserialize, Serialize}; 12 + 13 + /// HTTP client for TAP management API. 14 + /// 15 + /// Provides methods for managing which repositories the TAP service tracks, 16 + /// checking service health, and querying repository status. 17 + /// 18 + /// # Example 19 + /// 20 + /// ```ignore 21 + /// use atproto_tap::TapClient; 22 + /// 23 + /// let client = TapClient::new("localhost:2480", Some("admin_password".to_string())); 24 + /// 25 + /// // Add repositories to track 26 + /// client.add_repos(&["did:plc:xyz123", "did:plc:abc456"]).await?; 27 + /// 28 + /// // Check health 29 + /// if client.health().await? { 30 + /// println!("TAP service is healthy"); 31 + /// } 32 + /// ``` 33 + #[derive(Debug, Clone)] 34 + pub struct TapClient { 35 + http_client: reqwest::Client, 36 + base_url: String, 37 + auth_header: Option<HeaderValue>, 38 + } 39 + 40 + impl TapClient { 41 + /// Create a new TAP management client. 42 + /// 43 + /// # Arguments 44 + /// 45 + /// * `hostname` - TAP service hostname (e.g., "localhost:2480") 46 + /// * `admin_password` - Optional admin password for authentication 47 + pub fn new(hostname: &str, admin_password: Option<String>) -> Self { 48 + let auth_header = admin_password.map(|password| { 49 + let credentials = format!("admin:{}", password); 50 + let encoded = BASE64.encode(credentials.as_bytes()); 51 + HeaderValue::from_str(&format!("Basic {}", encoded)) 52 + .expect("Invalid auth header value") 53 + }); 54 + 55 + Self { 56 + http_client: reqwest::Client::new(), 57 + base_url: format!("http://{}", hostname), 58 + auth_header, 59 + } 60 + } 61 + 62 + /// Create default headers for requests. 63 + fn default_headers(&self) -> HeaderMap { 64 + let mut headers = HeaderMap::new(); 65 + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); 66 + if let Some(auth) = &self.auth_header { 67 + headers.insert(AUTHORIZATION, auth.clone()); 68 + } 69 + headers 70 + } 71 + 72 + /// Add repositories to track. 73 + /// 74 + /// Sends a POST request to `/repos/add` with the list of DIDs. 75 + /// 76 + /// # Arguments 77 + /// 78 + /// * `dids` - Slice of DID strings to track 79 + /// 80 + /// # Example 81 + /// 82 + /// ```ignore 83 + /// client.add_repos(&[ 84 + /// "did:plc:z72i7hdynmk6r22z27h6tvur", 85 + /// "did:plc:ewvi7nxzyoun6zhxrhs64oiz", 86 + /// ]).await?; 87 + /// ``` 88 + pub async fn add_repos(&self, dids: &[&str]) -> Result<(), TapError> { 89 + let url = format!("{}/repos/add", self.base_url); 90 + let body = AddReposRequest { 91 + dids: dids.iter().map(|s| s.to_string()).collect(), 92 + }; 93 + 94 + let response = self 95 + .http_client 96 + .post(&url) 97 + .headers(self.default_headers()) 98 + .json(&body) 99 + .send() 100 + .await?; 101 + 102 + if response.status().is_success() { 103 + tracing::debug!(count = dids.len(), "Added repositories to TAP"); 104 + Ok(()) 105 + } else { 106 + let status = response.status().as_u16(); 107 + let message = response.text().await.unwrap_or_default(); 108 + Err(TapError::HttpResponseError { status, message }) 109 + } 110 + } 111 + 112 + /// Remove repositories from tracking. 113 + /// 114 + /// Sends a POST request to `/repos/remove` with the list of DIDs. 115 + /// 116 + /// # Arguments 117 + /// 118 + /// * `dids` - Slice of DID strings to stop tracking 119 + pub async fn remove_repos(&self, dids: &[&str]) -> Result<(), TapError> { 120 + let url = format!("{}/repos/remove", self.base_url); 121 + let body = AddReposRequest { 122 + dids: dids.iter().map(|s| s.to_string()).collect(), 123 + }; 124 + 125 + let response = self 126 + .http_client 127 + .post(&url) 128 + .headers(self.default_headers()) 129 + .json(&body) 130 + .send() 131 + .await?; 132 + 133 + if response.status().is_success() { 134 + tracing::debug!(count = dids.len(), "Removed repositories from TAP"); 135 + Ok(()) 136 + } else { 137 + let status = response.status().as_u16(); 138 + let message = response.text().await.unwrap_or_default(); 139 + Err(TapError::HttpResponseError { status, message }) 140 + } 141 + } 142 + 143 + /// Check service health. 144 + /// 145 + /// Sends a GET request to `/health`. 146 + /// 147 + /// # Returns 148 + /// 149 + /// `true` if the service is healthy, `false` otherwise. 150 + pub async fn health(&self) -> Result<bool, TapError> { 151 + let url = format!("{}/health", self.base_url); 152 + 153 + let response = self 154 + .http_client 155 + .get(&url) 156 + .headers(self.default_headers()) 157 + .send() 158 + .await?; 159 + 160 + Ok(response.status().is_success()) 161 + } 162 + 163 + /// Resolve a DID to its DID document. 164 + /// 165 + /// Sends a GET request to `/resolve/:did`. 166 + /// 167 + /// # Arguments 168 + /// 169 + /// * `did` - The DID to resolve 170 + /// 171 + /// # Returns 172 + /// 173 + /// The DID document for the identity. 174 + pub async fn resolve(&self, did: &str) -> Result<Document, TapError> { 175 + let url = format!("{}/resolve/{}", self.base_url, did); 176 + 177 + let response = self 178 + .http_client 179 + .get(&url) 180 + .headers(self.default_headers()) 181 + .send() 182 + .await?; 183 + 184 + if response.status().is_success() { 185 + let doc: Document = response.json().await?; 186 + Ok(doc) 187 + } else { 188 + let status = response.status().as_u16(); 189 + let message = response.text().await.unwrap_or_default(); 190 + Err(TapError::HttpResponseError { status, message }) 191 + } 192 + } 193 + 194 + /// Get info about a tracked repository. 195 + /// 196 + /// Sends a GET request to `/info/:did`. 197 + /// 198 + /// # Arguments 199 + /// 200 + /// * `did` - The DID to get info for 201 + /// 202 + /// # Returns 203 + /// 204 + /// Repository tracking information. 205 + pub async fn info(&self, did: &str) -> Result<RepoInfo, TapError> { 206 + let url = format!("{}/info/{}", self.base_url, did); 207 + 208 + let response = self 209 + .http_client 210 + .get(&url) 211 + .headers(self.default_headers()) 212 + .send() 213 + .await?; 214 + 215 + if response.status().is_success() { 216 + let info: RepoInfo = response.json().await?; 217 + Ok(info) 218 + } else { 219 + let status = response.status().as_u16(); 220 + let message = response.text().await.unwrap_or_default(); 221 + Err(TapError::HttpResponseError { status, message }) 222 + } 223 + } 224 + } 225 + 226 + /// Request body for adding/removing repositories. 227 + #[derive(Debug, Serialize)] 228 + struct AddReposRequest { 229 + dids: Vec<String>, 230 + } 231 + 232 + /// Repository tracking information. 233 + #[derive(Debug, Clone, Serialize, Deserialize)] 234 + pub struct RepoInfo { 235 + /// The repository DID. 236 + pub did: Box<str>, 237 + /// Current sync state. 238 + pub state: RepoState, 239 + /// The handle for the repository. 240 + #[serde(default)] 241 + pub handle: Option<Box<str>>, 242 + /// Number of records in the repository. 243 + #[serde(default)] 244 + pub records: u64, 245 + /// Current repository revision. 246 + #[serde(default)] 247 + pub rev: Option<Box<str>>, 248 + /// Number of retries for syncing. 249 + #[serde(default)] 250 + pub retries: u32, 251 + /// Error message if any. 252 + #[serde(default)] 253 + pub error: Option<Box<str>>, 254 + /// Additional fields may be present depending on TAP version. 255 + #[serde(flatten)] 256 + pub extra: serde_json::Value, 257 + } 258 + 259 + /// Repository sync state. 260 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] 261 + #[serde(rename_all = "lowercase")] 262 + pub enum RepoState { 263 + /// Repository is active and synced. 264 + Active, 265 + /// Repository is currently syncing. 266 + Syncing, 267 + /// Repository is fully synced. 268 + Synced, 269 + /// Sync failed for this repository. 270 + Failed, 271 + /// Repository is queued for sync. 272 + Queued, 273 + /// Unknown state. 274 + #[serde(other)] 275 + Unknown, 276 + } 277 + 278 + /// Deprecated alias for RepoState. 279 + #[deprecated(since = "0.13.0", note = "Use RepoState instead")] 280 + pub type RepoStatus = RepoState; 281 + 282 + impl std::fmt::Display for RepoState { 283 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 284 + match self { 285 + RepoState::Active => write!(f, "active"), 286 + RepoState::Syncing => write!(f, "syncing"), 287 + RepoState::Synced => write!(f, "synced"), 288 + RepoState::Failed => write!(f, "failed"), 289 + RepoState::Queued => write!(f, "queued"), 290 + RepoState::Unknown => write!(f, "unknown"), 291 + } 292 + } 293 + } 294 + 295 + #[cfg(test)] 296 + mod tests { 297 + use super::*; 298 + 299 + #[test] 300 + fn test_client_creation() { 301 + let client = TapClient::new("localhost:2480", None); 302 + assert_eq!(client.base_url, "http://localhost:2480"); 303 + assert!(client.auth_header.is_none()); 304 + 305 + let client = TapClient::new("localhost:2480", Some("secret".to_string())); 306 + assert!(client.auth_header.is_some()); 307 + } 308 + 309 + #[test] 310 + fn test_repo_state_display() { 311 + assert_eq!(RepoState::Active.to_string(), "active"); 312 + assert_eq!(RepoState::Syncing.to_string(), "syncing"); 313 + assert_eq!(RepoState::Synced.to_string(), "synced"); 314 + assert_eq!(RepoState::Failed.to_string(), "failed"); 315 + assert_eq!(RepoState::Queued.to_string(), "queued"); 316 + assert_eq!(RepoState::Unknown.to_string(), "unknown"); 317 + } 318 + 319 + #[test] 320 + fn test_repo_state_deserialize() { 321 + let json = r#""active""#; 322 + let state: RepoState = serde_json::from_str(json).unwrap(); 323 + assert_eq!(state, RepoState::Active); 324 + 325 + let json = r#""syncing""#; 326 + let state: RepoState = serde_json::from_str(json).unwrap(); 327 + assert_eq!(state, RepoState::Syncing); 328 + 329 + let json = r#""some_new_state""#; 330 + let state: RepoState = serde_json::from_str(json).unwrap(); 331 + assert_eq!(state, RepoState::Unknown); 332 + } 333 + 334 + #[test] 335 + fn test_repo_info_deserialize() { 336 + let json = r#"{"did":"did:plc:cbkjy5n7bk3ax2wplmtjofq2","error":"","handle":"ngerakines.me","records":21382,"retries":0,"rev":"3mam4aazabs2m","state":"active"}"#; 337 + let info: RepoInfo = serde_json::from_str(json).unwrap(); 338 + assert_eq!(&*info.did, "did:plc:cbkjy5n7bk3ax2wplmtjofq2"); 339 + assert_eq!(info.state, RepoState::Active); 340 + assert_eq!(info.handle.as_deref(), Some("ngerakines.me")); 341 + assert_eq!(info.records, 21382); 342 + assert_eq!(info.retries, 0); 343 + assert_eq!(info.rev.as_deref(), Some("3mam4aazabs2m")); 344 + // Empty string deserializes as Some("") 345 + assert_eq!(info.error.as_deref(), Some("")); 346 + } 347 + 348 + #[test] 349 + fn test_repo_info_deserialize_minimal() { 350 + // Test with only required fields 351 + let json = r#"{"did":"did:plc:test","state":"syncing"}"#; 352 + let info: RepoInfo = serde_json::from_str(json).unwrap(); 353 + assert_eq!(&*info.did, "did:plc:test"); 354 + assert_eq!(info.state, RepoState::Syncing); 355 + assert_eq!(info.handle, None); 356 + assert_eq!(info.records, 0); 357 + assert_eq!(info.retries, 0); 358 + assert_eq!(info.rev, None); 359 + assert_eq!(info.error, None); 360 + } 361 + 362 + #[test] 363 + fn test_add_repos_request_serialize() { 364 + let req = AddReposRequest { 365 + dids: vec!["did:plc:xyz".to_string(), "did:plc:abc".to_string()], 366 + }; 367 + let json = serde_json::to_string(&req).unwrap(); 368 + assert!(json.contains("dids")); 369 + assert!(json.contains("did:plc:xyz")); 370 + } 371 + }
+220
crates/atproto-tap/src/config.rs
··· 1 + //! Configuration for TAP stream connections. 2 + //! 3 + //! This module provides the [`TapConfig`] struct for configuring TAP stream 4 + //! connections, including hostname, authentication, and reconnection behavior. 5 + 6 + use std::time::Duration; 7 + 8 + /// Configuration for a TAP stream connection. 9 + /// 10 + /// Use [`TapConfig::builder()`] for ergonomic construction with defaults. 11 + /// 12 + /// # Example 13 + /// 14 + /// ``` 15 + /// use atproto_tap::TapConfig; 16 + /// use std::time::Duration; 17 + /// 18 + /// let config = TapConfig::builder() 19 + /// .hostname("localhost:2480") 20 + /// .admin_password("secret") 21 + /// .send_acks(true) 22 + /// .max_reconnect_attempts(Some(10)) 23 + /// .build(); 24 + /// ``` 25 + #[derive(Debug, Clone)] 26 + pub struct TapConfig { 27 + /// TAP service hostname (e.g., "localhost:2480"). 28 + /// 29 + /// The WebSocket URL is constructed as `ws://{hostname}/channel`. 30 + pub hostname: String, 31 + 32 + /// Optional admin password for authentication. 33 + /// 34 + /// If set, HTTP Basic Auth is used with username "admin". 35 + pub admin_password: Option<String>, 36 + 37 + /// Whether to send acknowledgments for received messages. 38 + /// 39 + /// Default: `true`. Set to `false` if the TAP service has acks disabled. 40 + pub send_acks: bool, 41 + 42 + /// User-Agent header value for WebSocket connections. 43 + pub user_agent: String, 44 + 45 + /// Maximum reconnection attempts before giving up. 46 + /// 47 + /// `None` means unlimited reconnection attempts (default). 48 + pub max_reconnect_attempts: Option<u32>, 49 + 50 + /// Initial delay before first reconnection attempt. 51 + /// 52 + /// Default: 1 second. 53 + pub initial_reconnect_delay: Duration, 54 + 55 + /// Maximum delay between reconnection attempts. 56 + /// 57 + /// Default: 60 seconds. 58 + pub max_reconnect_delay: Duration, 59 + 60 + /// Multiplier for exponential backoff between reconnections. 61 + /// 62 + /// Default: 2.0 (doubles the delay each attempt). 63 + pub reconnect_backoff_multiplier: f64, 64 + } 65 + 66 + impl Default for TapConfig { 67 + fn default() -> Self { 68 + Self { 69 + hostname: "localhost:2480".to_string(), 70 + admin_password: None, 71 + send_acks: true, 72 + user_agent: format!("atproto-tap/{}", env!("CARGO_PKG_VERSION")), 73 + max_reconnect_attempts: None, 74 + initial_reconnect_delay: Duration::from_secs(1), 75 + max_reconnect_delay: Duration::from_secs(60), 76 + reconnect_backoff_multiplier: 2.0, 77 + } 78 + } 79 + } 80 + 81 + impl TapConfig { 82 + /// Create a new configuration builder with defaults. 83 + pub fn builder() -> TapConfigBuilder { 84 + TapConfigBuilder::default() 85 + } 86 + 87 + /// Create a minimal configuration for the given hostname. 88 + pub fn new(hostname: impl Into<String>) -> Self { 89 + Self { 90 + hostname: hostname.into(), 91 + ..Default::default() 92 + } 93 + } 94 + 95 + /// Returns the WebSocket URL for the TAP channel. 96 + pub fn ws_url(&self) -> String { 97 + format!("ws://{}/channel", self.hostname) 98 + } 99 + 100 + /// Returns the HTTP base URL for the TAP management API. 101 + pub fn http_base_url(&self) -> String { 102 + format!("http://{}", self.hostname) 103 + } 104 + } 105 + 106 + /// Builder for [`TapConfig`]. 107 + #[derive(Debug, Clone, Default)] 108 + pub struct TapConfigBuilder { 109 + config: TapConfig, 110 + } 111 + 112 + impl TapConfigBuilder { 113 + /// Set the TAP service hostname. 114 + pub fn hostname(mut self, hostname: impl Into<String>) -> Self { 115 + self.config.hostname = hostname.into(); 116 + self 117 + } 118 + 119 + /// Set the admin password for authentication. 120 + pub fn admin_password(mut self, password: impl Into<String>) -> Self { 121 + self.config.admin_password = Some(password.into()); 122 + self 123 + } 124 + 125 + /// Set whether to send acknowledgments. 126 + pub fn send_acks(mut self, send_acks: bool) -> Self { 127 + self.config.send_acks = send_acks; 128 + self 129 + } 130 + 131 + /// Set the User-Agent header value. 132 + pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self { 133 + self.config.user_agent = user_agent.into(); 134 + self 135 + } 136 + 137 + /// Set the maximum reconnection attempts. 138 + /// 139 + /// `None` means unlimited attempts. 140 + pub fn max_reconnect_attempts(mut self, max: Option<u32>) -> Self { 141 + self.config.max_reconnect_attempts = max; 142 + self 143 + } 144 + 145 + /// Set the initial reconnection delay. 146 + pub fn initial_reconnect_delay(mut self, delay: Duration) -> Self { 147 + self.config.initial_reconnect_delay = delay; 148 + self 149 + } 150 + 151 + /// Set the maximum reconnection delay. 152 + pub fn max_reconnect_delay(mut self, delay: Duration) -> Self { 153 + self.config.max_reconnect_delay = delay; 154 + self 155 + } 156 + 157 + /// Set the reconnection backoff multiplier. 158 + pub fn reconnect_backoff_multiplier(mut self, multiplier: f64) -> Self { 159 + self.config.reconnect_backoff_multiplier = multiplier; 160 + self 161 + } 162 + 163 + /// Build the configuration. 164 + pub fn build(self) -> TapConfig { 165 + self.config 166 + } 167 + } 168 + 169 + #[cfg(test)] 170 + mod tests { 171 + use super::*; 172 + 173 + #[test] 174 + fn test_default_config() { 175 + let config = TapConfig::default(); 176 + assert_eq!(config.hostname, "localhost:2480"); 177 + assert!(config.admin_password.is_none()); 178 + assert!(config.send_acks); 179 + assert!(config.max_reconnect_attempts.is_none()); 180 + assert_eq!(config.initial_reconnect_delay, Duration::from_secs(1)); 181 + assert_eq!(config.max_reconnect_delay, Duration::from_secs(60)); 182 + assert!((config.reconnect_backoff_multiplier - 2.0).abs() < f64::EPSILON); 183 + } 184 + 185 + #[test] 186 + fn test_builder() { 187 + let config = TapConfig::builder() 188 + .hostname("tap.example.com:2480") 189 + .admin_password("secret123") 190 + .send_acks(false) 191 + .max_reconnect_attempts(Some(5)) 192 + .initial_reconnect_delay(Duration::from_millis(500)) 193 + .max_reconnect_delay(Duration::from_secs(30)) 194 + .reconnect_backoff_multiplier(1.5) 195 + .build(); 196 + 197 + assert_eq!(config.hostname, "tap.example.com:2480"); 198 + assert_eq!(config.admin_password, Some("secret123".to_string())); 199 + assert!(!config.send_acks); 200 + assert_eq!(config.max_reconnect_attempts, Some(5)); 201 + assert_eq!(config.initial_reconnect_delay, Duration::from_millis(500)); 202 + assert_eq!(config.max_reconnect_delay, Duration::from_secs(30)); 203 + assert!((config.reconnect_backoff_multiplier - 1.5).abs() < f64::EPSILON); 204 + } 205 + 206 + #[test] 207 + fn test_ws_url() { 208 + let config = TapConfig::new("localhost:2480"); 209 + assert_eq!(config.ws_url(), "ws://localhost:2480/channel"); 210 + 211 + let config = TapConfig::new("tap.example.com:8080"); 212 + assert_eq!(config.ws_url(), "ws://tap.example.com:8080/channel"); 213 + } 214 + 215 + #[test] 216 + fn test_http_base_url() { 217 + let config = TapConfig::new("localhost:2480"); 218 + assert_eq!(config.http_base_url(), "http://localhost:2480"); 219 + } 220 + }
+168
crates/atproto-tap/src/connection.rs
··· 1 + //! WebSocket connection management for TAP streams. 2 + //! 3 + //! This module handles the low-level WebSocket connection to a TAP service, 4 + //! including authentication and message sending/receiving. 5 + 6 + use crate::config::TapConfig; 7 + use crate::errors::TapError; 8 + use base64::Engine; 9 + use base64::engine::general_purpose::STANDARD as BASE64; 10 + use futures::{SinkExt, StreamExt}; 11 + use http::Uri; 12 + use std::str::FromStr; 13 + use tokio_websockets::{ClientBuilder, Message, WebSocketStream}; 14 + use tokio_websockets::MaybeTlsStream; 15 + use tokio::net::TcpStream; 16 + 17 + /// WebSocket connection to a TAP service. 18 + pub(crate) struct TapConnection { 19 + /// The underlying WebSocket stream. 20 + ws: WebSocketStream<MaybeTlsStream<TcpStream>>, 21 + /// Pre-allocated buffer for acknowledgment messages. 22 + ack_buffer: Vec<u8>, 23 + } 24 + 25 + impl TapConnection { 26 + /// Establish a new WebSocket connection to the TAP service. 27 + pub async fn connect(config: &TapConfig) -> Result<Self, TapError> { 28 + let uri = Uri::from_str(&config.ws_url()) 29 + .map_err(|e| TapError::InvalidUrl(e.to_string()))?; 30 + 31 + let mut builder = ClientBuilder::from_uri(uri); 32 + 33 + // Add User-Agent header 34 + builder = builder 35 + .add_header( 36 + http::header::USER_AGENT, 37 + http::HeaderValue::from_str(&config.user_agent) 38 + .map_err(|e| TapError::ConnectionFailed(format!("Invalid user agent: {}", e)))?, 39 + ) 40 + .map_err(|e| TapError::ConnectionFailed(format!("Failed to add header: {}", e)))?; 41 + 42 + // Add Basic Auth header if password is configured 43 + if let Some(password) = &config.admin_password { 44 + let credentials = format!("admin:{}", password); 45 + let encoded = BASE64.encode(credentials.as_bytes()); 46 + let auth_value = format!("Basic {}", encoded); 47 + 48 + builder = builder 49 + .add_header( 50 + http::header::AUTHORIZATION, 51 + http::HeaderValue::from_str(&auth_value) 52 + .map_err(|e| TapError::ConnectionFailed(format!("Invalid auth header: {}", e)))?, 53 + ) 54 + .map_err(|e| TapError::ConnectionFailed(format!("Failed to add auth header: {}", e)))?; 55 + } 56 + 57 + // Connect 58 + let (ws, _response) = builder 59 + .connect() 60 + .await 61 + .map_err(|e| TapError::ConnectionFailed(e.to_string()))?; 62 + 63 + tracing::debug!(hostname = %config.hostname, "Connected to TAP service"); 64 + 65 + Ok(Self { 66 + ws, 67 + ack_buffer: Vec::with_capacity(48), // {"type":"ack","id":18446744073709551615} is 40 bytes max 68 + }) 69 + } 70 + 71 + /// Receive the next message from the WebSocket. 72 + /// 73 + /// Returns `None` if the connection was closed cleanly. 74 + pub async fn recv(&mut self) -> Result<Option<String>, TapError> { 75 + match self.ws.next().await { 76 + Some(Ok(msg)) => { 77 + if msg.is_text() { 78 + msg.as_text() 79 + .map(|s| Some(s.to_string())) 80 + .ok_or_else(|| TapError::ParseError("Failed to get text from message".into())) 81 + } else if msg.is_close() { 82 + tracing::debug!("Received close frame from TAP service"); 83 + Ok(None) 84 + } else { 85 + // Ignore ping/pong and binary messages 86 + tracing::trace!("Received non-text message, ignoring"); 87 + // Recurse to get the next text message 88 + Box::pin(self.recv()).await 89 + } 90 + } 91 + Some(Err(e)) => Err(TapError::ConnectionFailed(e.to_string())), 92 + None => { 93 + tracing::debug!("WebSocket stream ended"); 94 + Ok(None) 95 + } 96 + } 97 + } 98 + 99 + /// Send an acknowledgment for the given event ID. 100 + /// 101 + /// Uses a pre-allocated buffer and itoa for allocation-free formatting. 102 + /// Format: `{"type":"ack","id":12345}` 103 + pub async fn send_ack(&mut self, id: u64) -> Result<(), TapError> { 104 + self.ack_buffer.clear(); 105 + self.ack_buffer.extend_from_slice(b"{\"type\":\"ack\",\"id\":"); 106 + let mut itoa_buf = itoa::Buffer::new(); 107 + self.ack_buffer.extend_from_slice(itoa_buf.format(id).as_bytes()); 108 + self.ack_buffer.push(b'}'); 109 + 110 + // All bytes are ASCII so this is always valid UTF-8 111 + let msg = std::str::from_utf8(&self.ack_buffer) 112 + .expect("ack buffer contains only ASCII"); 113 + 114 + self.ws 115 + .send(Message::text(msg.to_string())) 116 + .await 117 + .map_err(|e| TapError::AckFailed(e.to_string()))?; 118 + 119 + // Flush to ensure the ack is sent immediately 120 + self.ws 121 + .flush() 122 + .await 123 + .map_err(|e| TapError::AckFailed(format!("Failed to flush ack: {}", e)))?; 124 + 125 + tracing::trace!(id, "Sent ack"); 126 + Ok(()) 127 + } 128 + 129 + /// Close the WebSocket connection gracefully. 130 + pub async fn close(&mut self) -> Result<(), TapError> { 131 + self.ws 132 + .close() 133 + .await 134 + .map_err(|e| TapError::ConnectionFailed(format!("Failed to close: {}", e)))?; 135 + Ok(()) 136 + } 137 + } 138 + 139 + #[cfg(test)] 140 + mod tests { 141 + #[test] 142 + fn test_ack_buffer_format() { 143 + // Test that our manual JSON formatting is correct 144 + // Format: {"type":"ack","id":12345} 145 + let mut buffer = Vec::with_capacity(64); 146 + 147 + let id: u64 = 12345; 148 + buffer.clear(); 149 + buffer.extend_from_slice(b"{\"type\":\"ack\",\"id\":"); 150 + let mut itoa_buf = itoa::Buffer::new(); 151 + buffer.extend_from_slice(itoa_buf.format(id).as_bytes()); 152 + buffer.push(b'}'); 153 + 154 + let result = std::str::from_utf8(&buffer).unwrap(); 155 + assert_eq!(result, r#"{"type":"ack","id":12345}"#); 156 + 157 + // Test max u64 158 + let id: u64 = u64::MAX; 159 + buffer.clear(); 160 + buffer.extend_from_slice(b"{\"type\":\"ack\",\"id\":"); 161 + buffer.extend_from_slice(itoa_buf.format(id).as_bytes()); 162 + buffer.push(b'}'); 163 + 164 + let result = std::str::from_utf8(&buffer).unwrap(); 165 + assert_eq!(result, r#"{"type":"ack","id":18446744073709551615}"#); 166 + assert!(buffer.len() <= 64); // Fits in our pre-allocated buffer 167 + } 168 + }
+143
crates/atproto-tap/src/errors.rs
··· 1 + //! Error types for TAP operations. 2 + //! 3 + //! This module defines the error types returned by TAP stream and client operations. 4 + 5 + use thiserror::Error; 6 + 7 + /// Errors that can occur during TAP operations. 8 + #[derive(Debug, Error)] 9 + pub enum TapError { 10 + /// WebSocket connection failed. 11 + #[error("error-atproto-tap-connection-1 WebSocket connection failed: {0}")] 12 + ConnectionFailed(String), 13 + 14 + /// Connection was closed unexpectedly. 15 + #[error("error-atproto-tap-connection-2 Connection closed unexpectedly")] 16 + ConnectionClosed, 17 + 18 + /// Maximum reconnection attempts exceeded. 19 + #[error("error-atproto-tap-connection-3 Maximum reconnection attempts exceeded after {0} attempts")] 20 + MaxReconnectAttemptsExceeded(u32), 21 + 22 + /// Authentication failed. 23 + #[error("error-atproto-tap-auth-1 Authentication failed: {0}")] 24 + AuthenticationFailed(String), 25 + 26 + /// Failed to parse a message from the server. 27 + #[error("error-atproto-tap-parse-1 Failed to parse message: {0}")] 28 + ParseError(String), 29 + 30 + /// Failed to send an acknowledgment. 31 + #[error("error-atproto-tap-ack-1 Failed to send acknowledgment: {0}")] 32 + AckFailed(String), 33 + 34 + /// HTTP request failed. 35 + #[error("error-atproto-tap-http-1 HTTP request failed: {0}")] 36 + HttpError(String), 37 + 38 + /// HTTP response indicated an error. 39 + #[error("error-atproto-tap-http-2 HTTP error response: {status} - {message}")] 40 + HttpResponseError { 41 + /// HTTP status code. 42 + status: u16, 43 + /// Error message from response. 44 + message: String, 45 + }, 46 + 47 + /// Invalid URL. 48 + #[error("error-atproto-tap-url-1 Invalid URL: {0}")] 49 + InvalidUrl(String), 50 + 51 + /// I/O error. 52 + #[error("error-atproto-tap-io-1 I/O error: {0}")] 53 + IoError(#[from] std::io::Error), 54 + 55 + /// JSON serialization/deserialization error. 56 + #[error("error-atproto-tap-json-1 JSON error: {0}")] 57 + JsonError(#[from] serde_json::Error), 58 + 59 + /// Stream has been closed and cannot be used. 60 + #[error("error-atproto-tap-stream-1 Stream is closed")] 61 + StreamClosed, 62 + 63 + /// Operation timed out. 64 + #[error("error-atproto-tap-timeout-1 Operation timed out")] 65 + Timeout, 66 + } 67 + 68 + impl TapError { 69 + /// Returns true if this error indicates a connection issue that may be recoverable. 70 + pub fn is_connection_error(&self) -> bool { 71 + matches!( 72 + self, 73 + TapError::ConnectionFailed(_) 74 + | TapError::ConnectionClosed 75 + | TapError::IoError(_) 76 + | TapError::Timeout 77 + ) 78 + } 79 + 80 + /// Returns true if this error is a parse error that doesn't affect connection state. 81 + pub fn is_parse_error(&self) -> bool { 82 + matches!(self, TapError::ParseError(_) | TapError::JsonError(_)) 83 + } 84 + 85 + /// Returns true if this error is fatal and the stream should not attempt recovery. 86 + pub fn is_fatal(&self) -> bool { 87 + matches!( 88 + self, 89 + TapError::MaxReconnectAttemptsExceeded(_) 90 + | TapError::AuthenticationFailed(_) 91 + | TapError::StreamClosed 92 + ) 93 + } 94 + } 95 + 96 + impl From<reqwest::Error> for TapError { 97 + fn from(err: reqwest::Error) -> Self { 98 + if err.is_timeout() { 99 + TapError::Timeout 100 + } else if err.is_connect() { 101 + TapError::ConnectionFailed(err.to_string()) 102 + } else { 103 + TapError::HttpError(err.to_string()) 104 + } 105 + } 106 + } 107 + 108 + #[cfg(test)] 109 + mod tests { 110 + use super::*; 111 + 112 + #[test] 113 + fn test_error_classification() { 114 + assert!(TapError::ConnectionFailed("test".into()).is_connection_error()); 115 + assert!(TapError::ConnectionClosed.is_connection_error()); 116 + assert!(TapError::Timeout.is_connection_error()); 117 + 118 + assert!(TapError::ParseError("test".into()).is_parse_error()); 119 + assert!(TapError::JsonError(serde_json::from_str::<()>("invalid").unwrap_err()).is_parse_error()); 120 + 121 + assert!(TapError::MaxReconnectAttemptsExceeded(5).is_fatal()); 122 + assert!(TapError::AuthenticationFailed("test".into()).is_fatal()); 123 + assert!(TapError::StreamClosed.is_fatal()); 124 + 125 + // Non-fatal errors 126 + assert!(!TapError::ConnectionFailed("test".into()).is_fatal()); 127 + assert!(!TapError::ParseError("test".into()).is_fatal()); 128 + } 129 + 130 + #[test] 131 + fn test_error_display() { 132 + let err = TapError::ConnectionFailed("refused".to_string()); 133 + assert!(err.to_string().contains("error-atproto-tap-connection-1")); 134 + assert!(err.to_string().contains("refused")); 135 + 136 + let err = TapError::HttpResponseError { 137 + status: 404, 138 + message: "Not Found".to_string(), 139 + }; 140 + assert!(err.to_string().contains("404")); 141 + assert!(err.to_string().contains("Not Found")); 142 + } 143 + }
+376
crates/atproto-tap/src/events.rs
··· 1 + //! TAP event types for AT Protocol record and identity events. 2 + //! 3 + //! This module defines the event structures received from a TAP service. 4 + //! Events are optimized for memory efficiency using: 5 + //! - `CompactString` for small strings (SSO for ≤24 bytes) 6 + //! - `Box<str>` for immutable strings (no capacity overhead) 7 + //! - `serde_json::Value` for record payloads (allows lazy access) 8 + 9 + use compact_str::CompactString; 10 + use serde::{Deserialize, Serialize, de::DeserializeOwned}; 11 + 12 + /// A TAP event received from the stream. 13 + /// 14 + /// TAP delivers two types of events: 15 + /// - `Record`: Repository record changes (create, update, delete) 16 + /// - `Identity`: Identity/handle changes for accounts 17 + #[derive(Debug, Clone, Serialize, Deserialize)] 18 + #[serde(tag = "type", rename_all = "lowercase")] 19 + pub enum TapEvent { 20 + /// A repository record event (create, update, or delete). 21 + Record { 22 + /// Sequential event identifier. 23 + id: u64, 24 + /// The record event data. 25 + record: RecordEvent, 26 + }, 27 + /// An identity change event. 28 + Identity { 29 + /// Sequential event identifier. 30 + id: u64, 31 + /// The identity event data. 32 + identity: IdentityEvent, 33 + }, 34 + } 35 + 36 + impl TapEvent { 37 + /// Returns the event ID. 38 + pub fn id(&self) -> u64 { 39 + match self { 40 + TapEvent::Record { id, .. } => *id, 41 + TapEvent::Identity { id, .. } => *id, 42 + } 43 + } 44 + } 45 + 46 + /// A repository record event from TAP. 47 + /// 48 + /// Contains information about a record change in a user's repository, 49 + /// including the action taken and the record data (for creates/updates). 50 + #[derive(Debug, Clone, Serialize, Deserialize)] 51 + pub struct RecordEvent { 52 + /// True if from live firehose, false if from backfill/resync. 53 + /// 54 + /// During initial sync or recovery, TAP delivers historical events 55 + /// with `live: false`. Once caught up, live events have `live: true`. 56 + pub live: bool, 57 + 58 + /// Repository revision identifier. 59 + /// 60 + /// Typically 13 characters, stored inline via CompactString SSO. 61 + pub rev: CompactString, 62 + 63 + /// Actor DID (e.g., "did:plc:xyz123"). 64 + pub did: Box<str>, 65 + 66 + /// Collection NSID (e.g., "app.bsky.feed.post"). 67 + pub collection: Box<str>, 68 + 69 + /// Record key within the collection. 70 + /// 71 + /// Typically a TID (13 characters), stored inline via CompactString SSO. 72 + pub rkey: CompactString, 73 + 74 + /// The action performed on the record. 75 + pub action: RecordAction, 76 + 77 + /// Content identifier (CID) of the record. 78 + /// 79 + /// Present for create and update actions, absent for delete. 80 + #[serde(skip_serializing_if = "Option::is_none")] 81 + pub cid: Option<CompactString>, 82 + 83 + /// Record data as JSON value. 84 + /// 85 + /// Present for create and update actions, absent for delete. 86 + /// Use [`parse_record`](Self::parse_record) to deserialize on demand. 87 + #[serde(skip_serializing_if = "Option::is_none")] 88 + pub record: Option<serde_json::Value>, 89 + } 90 + 91 + impl RecordEvent { 92 + /// Parse the record payload into a typed structure. 93 + /// 94 + /// This method deserializes the raw JSON on demand, avoiding 95 + /// unnecessary allocations when the record data isn't needed. 96 + /// 97 + /// # Errors 98 + /// 99 + /// Returns an error if the record is absent (delete events) or 100 + /// if deserialization fails. 101 + /// 102 + /// # Example 103 + /// 104 + /// ```ignore 105 + /// use serde::Deserialize; 106 + /// 107 + /// #[derive(Deserialize)] 108 + /// struct Post { 109 + /// text: String, 110 + /// #[serde(rename = "createdAt")] 111 + /// created_at: String, 112 + /// } 113 + /// 114 + /// let post: Post = record_event.parse_record()?; 115 + /// println!("Post text: {}", post.text); 116 + /// ``` 117 + pub fn parse_record<T: DeserializeOwned>(&self) -> Result<T, serde_json::Error> { 118 + match &self.record { 119 + Some(value) => serde_json::from_value(value.clone()), 120 + None => Err(serde::de::Error::custom("no record data (delete event)")), 121 + } 122 + } 123 + 124 + /// Returns the record as a JSON Value reference, if present. 125 + pub fn record_value(&self) -> Option<&serde_json::Value> { 126 + self.record.as_ref() 127 + } 128 + 129 + /// Returns true if this is a delete event. 130 + pub fn is_delete(&self) -> bool { 131 + self.action == RecordAction::Delete 132 + } 133 + 134 + /// Returns the AT-URI for this record. 135 + /// 136 + /// Format: `at://{did}/{collection}/{rkey}` 137 + pub fn at_uri(&self) -> String { 138 + format!("at://{}/{}/{}", self.did, self.collection, self.rkey) 139 + } 140 + } 141 + 142 + /// The action performed on a record. 143 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] 144 + #[serde(rename_all = "lowercase")] 145 + pub enum RecordAction { 146 + /// A new record was created. 147 + Create, 148 + /// An existing record was updated. 149 + Update, 150 + /// A record was deleted. 151 + Delete, 152 + } 153 + 154 + impl std::fmt::Display for RecordAction { 155 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 156 + match self { 157 + RecordAction::Create => write!(f, "create"), 158 + RecordAction::Update => write!(f, "update"), 159 + RecordAction::Delete => write!(f, "delete"), 160 + } 161 + } 162 + } 163 + 164 + /// An identity change event from TAP. 165 + /// 166 + /// Contains information about handle or account status changes. 167 + #[derive(Debug, Clone, Serialize, Deserialize)] 168 + pub struct IdentityEvent { 169 + /// Actor DID. 170 + pub did: Box<str>, 171 + 172 + /// Current handle for the account. 173 + pub handle: Box<str>, 174 + 175 + /// Whether the account is currently active. 176 + #[serde(default)] 177 + pub is_active: bool, 178 + 179 + /// Account status. 180 + #[serde(default)] 181 + pub status: IdentityStatus, 182 + } 183 + 184 + /// Account status in an identity event. 185 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)] 186 + #[serde(rename_all = "lowercase")] 187 + pub enum IdentityStatus { 188 + /// Account is active and in good standing. 189 + #[default] 190 + Active, 191 + /// Account has been deactivated by the user. 192 + Deactivated, 193 + /// Account has been suspended. 194 + Suspended, 195 + /// Account has been deleted. 196 + Deleted, 197 + /// Account has been taken down. 198 + Takendown, 199 + } 200 + 201 + impl std::fmt::Display for IdentityStatus { 202 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 203 + match self { 204 + IdentityStatus::Active => write!(f, "active"), 205 + IdentityStatus::Deactivated => write!(f, "deactivated"), 206 + IdentityStatus::Suspended => write!(f, "suspended"), 207 + IdentityStatus::Deleted => write!(f, "deleted"), 208 + IdentityStatus::Takendown => write!(f, "takendown"), 209 + } 210 + } 211 + } 212 + 213 + #[cfg(test)] 214 + mod tests { 215 + use super::*; 216 + 217 + #[test] 218 + fn test_parse_record_event() { 219 + let json = r#"{ 220 + "id": 12345, 221 + "type": "record", 222 + "record": { 223 + "live": true, 224 + "rev": "3lyileto4q52k", 225 + "did": "did:plc:z72i7hdynmk6r22z27h6tvur", 226 + "collection": "app.bsky.feed.post", 227 + "rkey": "3lyiletddxt2c", 228 + "action": "create", 229 + "cid": "bafyreigroo6vhxt62ufcndhaxzas6btq4jmniuz4egszbwuqgiyisqwqoy", 230 + "record": {"$type": "app.bsky.feed.post", "text": "Hello world!", "createdAt": "2025-01-01T00:00:00Z"} 231 + } 232 + }"#; 233 + 234 + let event: TapEvent = serde_json::from_str(json).expect("Failed to parse"); 235 + 236 + match event { 237 + TapEvent::Record { id, record } => { 238 + assert_eq!(id, 12345); 239 + assert!(record.live); 240 + assert_eq!(record.rev.as_str(), "3lyileto4q52k"); 241 + assert_eq!(&*record.did, "did:plc:z72i7hdynmk6r22z27h6tvur"); 242 + assert_eq!(&*record.collection, "app.bsky.feed.post"); 243 + assert_eq!(record.rkey.as_str(), "3lyiletddxt2c"); 244 + assert_eq!(record.action, RecordAction::Create); 245 + assert!(record.cid.is_some()); 246 + assert!(record.record.is_some()); 247 + 248 + // Test lazy parsing 249 + #[derive(Deserialize)] 250 + struct Post { 251 + text: String, 252 + } 253 + let post: Post = record.parse_record().expect("Failed to parse record"); 254 + assert_eq!(post.text, "Hello world!"); 255 + } 256 + _ => panic!("Expected Record event"), 257 + } 258 + } 259 + 260 + #[test] 261 + fn test_parse_delete_event() { 262 + let json = r#"{ 263 + "id": 12346, 264 + "type": "record", 265 + "record": { 266 + "live": true, 267 + "rev": "3lyileto4q52k", 268 + "did": "did:plc:z72i7hdynmk6r22z27h6tvur", 269 + "collection": "app.bsky.feed.post", 270 + "rkey": "3lyiletddxt2c", 271 + "action": "delete" 272 + } 273 + }"#; 274 + 275 + let event: TapEvent = serde_json::from_str(json).expect("Failed to parse"); 276 + 277 + match event { 278 + TapEvent::Record { id, record } => { 279 + assert_eq!(id, 12346); 280 + assert_eq!(record.action, RecordAction::Delete); 281 + assert!(record.is_delete()); 282 + assert!(record.cid.is_none()); 283 + assert!(record.record.is_none()); 284 + } 285 + _ => panic!("Expected Record event"), 286 + } 287 + } 288 + 289 + #[test] 290 + fn test_parse_identity_event() { 291 + let json = r#"{ 292 + "id": 12347, 293 + "type": "identity", 294 + "identity": { 295 + "did": "did:plc:z72i7hdynmk6r22z27h6tvur", 296 + "handle": "user.bsky.social", 297 + "is_active": true, 298 + "status": "active" 299 + } 300 + }"#; 301 + 302 + let event: TapEvent = serde_json::from_str(json).expect("Failed to parse"); 303 + 304 + match event { 305 + TapEvent::Identity { id, identity } => { 306 + assert_eq!(id, 12347); 307 + assert_eq!(&*identity.did, "did:plc:z72i7hdynmk6r22z27h6tvur"); 308 + assert_eq!(&*identity.handle, "user.bsky.social"); 309 + assert!(identity.is_active); 310 + assert_eq!(identity.status, IdentityStatus::Active); 311 + } 312 + _ => panic!("Expected Identity event"), 313 + } 314 + } 315 + 316 + #[test] 317 + fn test_record_action_display() { 318 + assert_eq!(RecordAction::Create.to_string(), "create"); 319 + assert_eq!(RecordAction::Update.to_string(), "update"); 320 + assert_eq!(RecordAction::Delete.to_string(), "delete"); 321 + } 322 + 323 + #[test] 324 + fn test_identity_status_display() { 325 + assert_eq!(IdentityStatus::Active.to_string(), "active"); 326 + assert_eq!(IdentityStatus::Deactivated.to_string(), "deactivated"); 327 + assert_eq!(IdentityStatus::Suspended.to_string(), "suspended"); 328 + assert_eq!(IdentityStatus::Deleted.to_string(), "deleted"); 329 + assert_eq!(IdentityStatus::Takendown.to_string(), "takendown"); 330 + } 331 + 332 + #[test] 333 + fn test_at_uri() { 334 + let record = RecordEvent { 335 + live: true, 336 + rev: "3lyileto4q52k".into(), 337 + did: "did:plc:xyz".into(), 338 + collection: "app.bsky.feed.post".into(), 339 + rkey: "abc123".into(), 340 + action: RecordAction::Create, 341 + cid: None, 342 + record: None, 343 + }; 344 + 345 + assert_eq!(record.at_uri(), "at://did:plc:xyz/app.bsky.feed.post/abc123"); 346 + } 347 + 348 + #[test] 349 + fn test_event_id() { 350 + let record_event = TapEvent::Record { 351 + id: 100, 352 + record: RecordEvent { 353 + live: true, 354 + rev: "rev".into(), 355 + did: "did".into(), 356 + collection: "col".into(), 357 + rkey: "rkey".into(), 358 + action: RecordAction::Create, 359 + cid: None, 360 + record: None, 361 + }, 362 + }; 363 + assert_eq!(record_event.id(), 100); 364 + 365 + let identity_event = TapEvent::Identity { 366 + id: 200, 367 + identity: IdentityEvent { 368 + did: "did".into(), 369 + handle: "handle".into(), 370 + is_active: true, 371 + status: IdentityStatus::Active, 372 + }, 373 + }; 374 + assert_eq!(identity_event.id(), 200); 375 + } 376 + }
+119
crates/atproto-tap/src/lib.rs
··· 1 + //! TAP (Trusted Attestation Protocol) service consumer for AT Protocol. 2 + //! 3 + //! This crate provides a client for consuming events from a TAP service, 4 + //! which delivers filtered, verified AT Protocol repository events. 5 + //! 6 + //! # Overview 7 + //! 8 + //! TAP is a single-tenant service that subscribes to an AT Protocol Relay and 9 + //! outputs filtered, verified events. Key features include: 10 + //! 11 + //! - **Verified Events**: MST integrity checks and signature verification 12 + //! - **Automatic Backfill**: Historical events delivered with `live: false` 13 + //! - **Repository Filtering**: Track specific DIDs or collections 14 + //! - **Acknowledgment Protocol**: At-least-once delivery semantics 15 + //! 16 + //! # Quick Start 17 + //! 18 + //! ```ignore 19 + //! use atproto_tap::{connect_to, TapEvent}; 20 + //! use tokio_stream::StreamExt; 21 + //! 22 + //! #[tokio::main] 23 + //! async fn main() { 24 + //! let mut stream = connect_to("localhost:2480"); 25 + //! 26 + //! while let Some(result) = stream.next().await { 27 + //! match result { 28 + //! Ok(event) => match event.as_ref() { 29 + //! TapEvent::Record { record, .. } => { 30 + //! println!("{} {} {}", record.action, record.collection, record.did); 31 + //! } 32 + //! TapEvent::Identity { identity, .. } => { 33 + //! println!("Identity: {} = {}", identity.did, identity.handle); 34 + //! } 35 + //! }, 36 + //! Err(e) => eprintln!("Error: {}", e), 37 + //! } 38 + //! } 39 + //! } 40 + //! ``` 41 + //! 42 + //! # Using with `tokio::select!` 43 + //! 44 + //! The stream integrates naturally with Tokio's select macro: 45 + //! 46 + //! ```ignore 47 + //! use atproto_tap::{connect, TapConfig}; 48 + //! use tokio_stream::StreamExt; 49 + //! use tokio::signal; 50 + //! 51 + //! #[tokio::main] 52 + //! async fn main() { 53 + //! let config = TapConfig::builder() 54 + //! .hostname("localhost:2480") 55 + //! .admin_password("secret") 56 + //! .build(); 57 + //! 58 + //! let mut stream = connect(config); 59 + //! 60 + //! loop { 61 + //! tokio::select! { 62 + //! Some(result) = stream.next() => { 63 + //! // Process event 64 + //! } 65 + //! _ = signal::ctrl_c() => { 66 + //! break; 67 + //! } 68 + //! } 69 + //! } 70 + //! } 71 + //! ``` 72 + //! 73 + //! # Management API 74 + //! 75 + //! Use [`TapClient`] to manage tracked repositories: 76 + //! 77 + //! ```ignore 78 + //! use atproto_tap::TapClient; 79 + //! 80 + //! let client = TapClient::new("localhost:2480", Some("password".to_string())); 81 + //! 82 + //! // Add repositories to track 83 + //! client.add_repos(&["did:plc:xyz123"]).await?; 84 + //! 85 + //! // Check service health 86 + //! if client.health().await? { 87 + //! println!("TAP service is healthy"); 88 + //! } 89 + //! ``` 90 + //! 91 + //! # Memory Efficiency 92 + //! 93 + //! This crate is optimized for high-throughput event processing: 94 + //! 95 + //! - **Arc-wrapped events**: Events are shared via `Arc` for zero-cost sharing 96 + //! - **CompactString**: Small strings use inline storage (no heap allocation) 97 + //! - **Box<str>**: Immutable strings without capacity overhead 98 + //! - **RawValue**: Record payloads are lazily parsed on demand 99 + //! - **Pre-allocated buffers**: Ack messages avoid per-message allocations 100 + 101 + #![forbid(unsafe_code)] 102 + #![warn(missing_docs)] 103 + 104 + mod client; 105 + mod config; 106 + mod connection; 107 + mod errors; 108 + mod events; 109 + mod stream; 110 + 111 + // Re-export public types 112 + pub use atproto_identity::model::{Document, Service, VerificationMethod}; 113 + pub use client::{RepoInfo, RepoState, TapClient}; 114 + #[allow(deprecated)] 115 + pub use client::RepoStatus; 116 + pub use config::{TapConfig, TapConfigBuilder}; 117 + pub use errors::TapError; 118 + pub use events::{IdentityEvent, IdentityStatus, RecordAction, RecordEvent, TapEvent}; 119 + pub use stream::{TapStream, connect, connect_to};
+316
crates/atproto-tap/src/stream.rs
··· 1 + //! TAP event stream implementation. 2 + //! 3 + //! This module provides [`TapStream`], an async stream that yields TAP events 4 + //! with automatic connection management and reconnection handling. 5 + //! 6 + //! # Design 7 + //! 8 + //! The stream encapsulates all connection logic, allowing consumers to simply 9 + //! iterate over events using standard stream combinators or `tokio::select!`. 10 + //! 11 + //! Reconnection is handled automatically with exponential backoff. Parse errors 12 + //! are yielded as `Err` items but don't affect connection state - only connection 13 + //! errors trigger reconnection attempts. 14 + 15 + use crate::config::TapConfig; 16 + use crate::connection::TapConnection; 17 + use crate::errors::TapError; 18 + use crate::events::TapEvent; 19 + use futures::Stream; 20 + use std::pin::Pin; 21 + use std::sync::Arc; 22 + use std::task::{Context, Poll}; 23 + use std::time::Duration; 24 + use tokio::sync::mpsc; 25 + 26 + /// An async stream of TAP events with automatic reconnection. 27 + /// 28 + /// `TapStream` implements [`Stream`] and yields `Result<Arc<TapEvent>, TapError>`. 29 + /// Events are wrapped in `Arc` for efficient zero-cost sharing across consumers. 30 + /// 31 + /// # Connection Management 32 + /// 33 + /// The stream automatically: 34 + /// - Connects on first poll 35 + /// - Reconnects with exponential backoff on connection errors 36 + /// - Sends acknowledgments after parsing each message (if enabled) 37 + /// - Yields parse errors without affecting connection state 38 + /// 39 + /// # Example 40 + /// 41 + /// ```ignore 42 + /// use atproto_tap::{TapConfig, TapStream}; 43 + /// use tokio_stream::StreamExt; 44 + /// 45 + /// let config = TapConfig::builder() 46 + /// .hostname("localhost:2480") 47 + /// .build(); 48 + /// 49 + /// let mut stream = TapStream::new(config); 50 + /// 51 + /// while let Some(result) = stream.next().await { 52 + /// match result { 53 + /// Ok(event) => println!("Event: {:?}", event), 54 + /// Err(e) => eprintln!("Error: {}", e), 55 + /// } 56 + /// } 57 + /// ``` 58 + pub struct TapStream { 59 + /// Receiver for events from the background task. 60 + receiver: mpsc::Receiver<Result<Arc<TapEvent>, TapError>>, 61 + /// Handle to request stream closure. 62 + close_sender: Option<mpsc::Sender<()>>, 63 + /// Whether the stream has been closed. 64 + closed: bool, 65 + } 66 + 67 + impl TapStream { 68 + /// Create a new TAP stream with the given configuration. 69 + /// 70 + /// The stream will start connecting immediately in a background task. 71 + pub fn new(config: TapConfig) -> Self { 72 + // Channel for events - buffer a few to handle bursts 73 + let (event_tx, event_rx) = mpsc::channel(32); 74 + // Channel for close signal 75 + let (close_tx, close_rx) = mpsc::channel(1); 76 + 77 + // Spawn background task to manage connection 78 + tokio::spawn(connection_task(config, event_tx, close_rx)); 79 + 80 + Self { 81 + receiver: event_rx, 82 + close_sender: Some(close_tx), 83 + closed: false, 84 + } 85 + } 86 + 87 + /// Close the stream and release resources. 88 + /// 89 + /// After calling this, the stream will yield `None` on the next poll. 90 + pub async fn close(&mut self) { 91 + if let Some(sender) = self.close_sender.take() { 92 + // Signal the background task to close 93 + let _ = sender.send(()).await; 94 + } 95 + self.closed = true; 96 + } 97 + 98 + /// Returns true if the stream is closed. 99 + pub fn is_closed(&self) -> bool { 100 + self.closed 101 + } 102 + } 103 + 104 + impl Stream for TapStream { 105 + type Item = Result<Arc<TapEvent>, TapError>; 106 + 107 + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 108 + if self.closed { 109 + return Poll::Ready(None); 110 + } 111 + 112 + self.receiver.poll_recv(cx) 113 + } 114 + } 115 + 116 + impl Drop for TapStream { 117 + fn drop(&mut self) { 118 + // Drop the close_sender to signal the background task 119 + self.close_sender.take(); 120 + tracing::debug!("TapStream dropped"); 121 + } 122 + } 123 + 124 + /// Background task that manages the WebSocket connection. 125 + async fn connection_task( 126 + config: TapConfig, 127 + event_tx: mpsc::Sender<Result<Arc<TapEvent>, TapError>>, 128 + mut close_rx: mpsc::Receiver<()>, 129 + ) { 130 + let mut current_reconnect_delay = config.initial_reconnect_delay; 131 + let mut attempt: u32 = 0; 132 + 133 + loop { 134 + // Check for close signal 135 + if close_rx.try_recv().is_ok() { 136 + tracing::debug!("Connection task received close signal"); 137 + break; 138 + } 139 + 140 + // Try to connect 141 + tracing::debug!(attempt, hostname = %config.hostname, "Connecting to TAP service"); 142 + let conn_result = TapConnection::connect(&config).await; 143 + 144 + match conn_result { 145 + Ok(mut conn) => { 146 + tracing::info!(hostname = %config.hostname, "TAP stream connected"); 147 + // Reset reconnection state on successful connect 148 + current_reconnect_delay = config.initial_reconnect_delay; 149 + attempt = 0; 150 + 151 + // Event loop for this connection 152 + loop { 153 + tokio::select! { 154 + biased; 155 + 156 + _ = close_rx.recv() => { 157 + tracing::debug!("Connection task received close signal during receive"); 158 + let _ = conn.close().await; 159 + return; 160 + } 161 + 162 + recv_result = conn.recv() => { 163 + match recv_result { 164 + Ok(Some(msg)) => { 165 + // Parse the message 166 + match serde_json::from_str::<TapEvent>(&msg) { 167 + Ok(event) => { 168 + let event_id = event.id(); 169 + 170 + // Send ack if enabled (before sending event to channel) 171 + if config.send_acks 172 + && let Err(err) = conn.send_ack(event_id).await 173 + { 174 + tracing::warn!(error = %err, "Failed to send ack"); 175 + // Don't break connection for ack errors 176 + } 177 + 178 + // Send event to channel 179 + let event = Arc::new(event); 180 + if event_tx.send(Ok(event)).await.is_err() { 181 + // Receiver dropped, exit task 182 + tracing::debug!("Event receiver dropped, closing connection"); 183 + let _ = conn.close().await; 184 + return; 185 + } 186 + } 187 + Err(err) => { 188 + // Parse errors don't affect connection 189 + tracing::warn!(error = %err, "Failed to parse TAP message"); 190 + if event_tx.send(Err(TapError::ParseError(err.to_string()))).await.is_err() { 191 + tracing::debug!("Event receiver dropped, closing connection"); 192 + let _ = conn.close().await; 193 + return; 194 + } 195 + } 196 + } 197 + } 198 + Ok(None) => { 199 + // Connection closed by server 200 + tracing::debug!("TAP connection closed by server"); 201 + break; 202 + } 203 + Err(err) => { 204 + // Connection error 205 + tracing::warn!(error = %err, "TAP connection error"); 206 + break; 207 + } 208 + } 209 + } 210 + } 211 + } 212 + } 213 + Err(err) => { 214 + tracing::warn!(error = %err, attempt, "Failed to connect to TAP service"); 215 + } 216 + } 217 + 218 + // Increment attempt counter 219 + attempt += 1; 220 + 221 + // Check if we've exceeded max attempts 222 + if let Some(max) = config.max_reconnect_attempts 223 + && attempt >= max 224 + { 225 + tracing::error!(attempts = attempt, "Max reconnection attempts exceeded"); 226 + let _ = event_tx 227 + .send(Err(TapError::MaxReconnectAttemptsExceeded(attempt))) 228 + .await; 229 + break; 230 + } 231 + 232 + // Wait before reconnecting with exponential backoff 233 + tracing::debug!( 234 + delay_ms = current_reconnect_delay.as_millis(), 235 + attempt, 236 + "Waiting before reconnection" 237 + ); 238 + 239 + tokio::select! { 240 + _ = close_rx.recv() => { 241 + tracing::debug!("Connection task received close signal during backoff"); 242 + return; 243 + } 244 + _ = tokio::time::sleep(current_reconnect_delay) => { 245 + // Update delay for next attempt 246 + current_reconnect_delay = Duration::from_secs_f64( 247 + (current_reconnect_delay.as_secs_f64() * config.reconnect_backoff_multiplier) 248 + .min(config.max_reconnect_delay.as_secs_f64()), 249 + ); 250 + } 251 + } 252 + } 253 + 254 + tracing::debug!("Connection task exiting"); 255 + } 256 + 257 + /// Create a new TAP stream with the given configuration. 258 + pub fn connect(config: TapConfig) -> TapStream { 259 + TapStream::new(config) 260 + } 261 + 262 + /// Create a new TAP stream connected to the given hostname. 263 + /// 264 + /// Uses default configuration values. 265 + pub fn connect_to(hostname: &str) -> TapStream { 266 + TapStream::new(TapConfig::new(hostname)) 267 + } 268 + 269 + #[cfg(test)] 270 + mod tests { 271 + use super::*; 272 + 273 + #[test] 274 + fn test_stream_initial_state() { 275 + // Note: This test doesn't actually poll the stream, just checks initial state 276 + // Creating a TapStream requires a tokio runtime for the spawn 277 + } 278 + 279 + #[tokio::test] 280 + async fn test_stream_close() { 281 + let mut stream = TapStream::new(TapConfig::new("localhost:9999")); 282 + assert!(!stream.is_closed()); 283 + stream.close().await; 284 + assert!(stream.is_closed()); 285 + } 286 + 287 + #[test] 288 + fn test_connect_functions() { 289 + // These just create configs, actual connection happens in background task 290 + // We can't test without a runtime, so just verify the types compile 291 + let _ = TapConfig::new("localhost:2480"); 292 + } 293 + 294 + #[test] 295 + fn test_reconnect_delay_calculation() { 296 + // Test the delay calculation logic 297 + let initial = Duration::from_secs(1); 298 + let max = Duration::from_secs(10); 299 + let multiplier = 2.0; 300 + 301 + let mut delay = initial; 302 + assert_eq!(delay, Duration::from_secs(1)); 303 + 304 + delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64())); 305 + assert_eq!(delay, Duration::from_secs(2)); 306 + 307 + delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64())); 308 + assert_eq!(delay, Duration::from_secs(4)); 309 + 310 + delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64())); 311 + assert_eq!(delay, Duration::from_secs(8)); 312 + 313 + delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64())); 314 + assert_eq!(delay, Duration::from_secs(10)); // Capped at max 315 + } 316 + }