A library for ATProtocol identities.

Compare changes

Choose any two refs to compare.

Changed files
+4669 -248
crates
atproto-attestation
atproto-client
atproto-extras
atproto-identity
src
atproto-jetstream
atproto-oauth
src
atproto-tap
atproto-xrpcs
atproto-xrpcs-helloworld
src
+91
Cargo.lock
··· 141 141 ] 142 142 143 143 [[package]] 144 + name = "atproto-extras" 145 + version = "0.13.0" 146 + dependencies = [ 147 + "anyhow", 148 + "async-trait", 149 + "atproto-identity", 150 + "atproto-record", 151 + "clap", 152 + "regex", 153 + "reqwest", 154 + "serde_json", 155 + "tokio", 156 + ] 157 + 158 + [[package]] 144 159 name = "atproto-identity" 145 160 version = "0.13.0" 146 161 dependencies = [ ··· 307 322 ] 308 323 309 324 [[package]] 325 + name = "atproto-tap" 326 + version = "0.13.0" 327 + dependencies = [ 328 + "atproto-client", 329 + "atproto-identity", 330 + "base64", 331 + "clap", 332 + "compact_str", 333 + "futures", 334 + "http", 335 + "itoa", 336 + "reqwest", 337 + "serde", 338 + "serde_json", 339 + "thiserror 2.0.17", 340 + "tokio", 341 + "tokio-stream", 342 + "tokio-websockets", 343 + "tracing", 344 + "tracing-subscriber", 345 + ] 346 + 347 + [[package]] 310 348 name = "atproto-xrpcs" 311 349 version = "0.13.0" 312 350 dependencies = [ ··· 491 529 checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a" 492 530 493 531 [[package]] 532 + name = "castaway" 533 + version = "0.2.4" 534 + source = "registry+https://github.com/rust-lang/crates.io-index" 535 + checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a" 536 + dependencies = [ 537 + "rustversion", 538 + ] 539 + 540 + [[package]] 494 541 name = "cbor4ii" 495 542 version = "0.2.14" 496 543 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 592 639 version = "1.0.4" 593 640 source = "registry+https://github.com/rust-lang/crates.io-index" 594 641 checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" 642 + 643 + [[package]] 644 + name = "compact_str" 645 + version = "0.8.1" 646 + source = "registry+https://github.com/rust-lang/crates.io-index" 647 + checksum = "3b79c4069c6cad78e2e0cdfcbd26275770669fb39fd308a752dc110e83b9af32" 648 + dependencies = [ 649 + "castaway", 650 + "cfg-if", 651 + "itoa", 652 + "rustversion", 653 + "ryu", 654 + "serde", 655 + "static_assertions", 656 + ] 595 657 596 658 [[package]] 597 659 name = "const-oid" ··· 1879 1941 ] 1880 1942 1881 1943 [[package]] 1944 + name = "regex" 1945 + version = "1.12.2" 1946 + source = "registry+https://github.com/rust-lang/crates.io-index" 1947 + checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4" 1948 + dependencies = [ 1949 + "aho-corasick", 1950 + "memchr", 1951 + "regex-automata", 1952 + "regex-syntax", 1953 + ] 1954 + 1955 + [[package]] 1882 1956 name = "regex-automata" 1883 1957 version = "0.4.13" 1884 1958 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 2358 2432 checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596" 2359 2433 2360 2434 [[package]] 2435 + name = "static_assertions" 2436 + version = "1.1.0" 2437 + source = "registry+https://github.com/rust-lang/crates.io-index" 2438 + checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" 2439 + 2440 + [[package]] 2361 2441 name = "strsim" 2362 2442 version = "0.11.1" 2363 2443 source = "registry+https://github.com/rust-lang/crates.io-index" ··· 2547 2627 checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61" 2548 2628 dependencies = [ 2549 2629 "rustls", 2630 + "tokio", 2631 + ] 2632 + 2633 + [[package]] 2634 + name = "tokio-stream" 2635 + version = "0.1.17" 2636 + source = "registry+https://github.com/rust-lang/crates.io-index" 2637 + checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047" 2638 + dependencies = [ 2639 + "futures-core", 2640 + "pin-project-lite", 2550 2641 "tokio", 2551 2642 ] 2552 2643
+26 -20
Cargo.toml
··· 1 1 [workspace] 2 2 members = [ 3 3 "crates/atproto-client", 4 + "crates/atproto-extras", 4 5 "crates/atproto-identity", 5 6 "crates/atproto-jetstream", 6 7 "crates/atproto-oauth-aip", 7 8 "crates/atproto-oauth-axum", 8 9 "crates/atproto-oauth", 9 10 "crates/atproto-record", 11 + "crates/atproto-tap", 10 12 "crates/atproto-xrpcs-helloworld", 11 13 "crates/atproto-xrpcs", 12 14 "crates/atproto-lexicon", ··· 24 26 categories = ["command-line-utilities", "web-programming"] 25 27 26 28 [workspace.dependencies] 29 + atproto-attestation = { version = "0.13.0", path = "crates/atproto-attestation" } 27 30 atproto-client = { version = "0.13.0", path = "crates/atproto-client" } 31 + atproto-extras = { version = "0.13.0", path = "crates/atproto-extras" } 28 32 atproto-identity = { version = "0.13.0", path = "crates/atproto-identity" } 33 + atproto-jetstream = { version = "0.13.0", path = "crates/atproto-jetstream" } 29 34 atproto-oauth = { version = "0.13.0", path = "crates/atproto-oauth" } 30 - atproto-oauth-axum = { version = "0.13.0", path = "crates/atproto-oauth-axum" } 31 35 atproto-oauth-aip = { version = "0.13.0", path = "crates/atproto-oauth-aip" } 36 + atproto-oauth-axum = { version = "0.13.0", path = "crates/atproto-oauth-axum" } 32 37 atproto-record = { version = "0.13.0", path = "crates/atproto-record" } 38 + atproto-tap = { version = "0.13.0", path = "crates/atproto-tap" } 33 39 atproto-xrpcs = { version = "0.13.0", path = "crates/atproto-xrpcs" } 34 - atproto-jetstream = { version = "0.13.0", path = "crates/atproto-jetstream" } 35 - atproto-attestation = { version = "0.13.0", path = "crates/atproto-attestation" } 36 40 41 + ammonia = "4.0" 37 42 anyhow = "1.0" 38 - async-trait = "0.1.88" 39 - base64 = "0.22.1" 40 - chrono = {version = "0.4.41", default-features = false, features = ["std", "now"]} 43 + async-trait = "0.1" 44 + base64 = "0.22" 45 + chrono = {version = "0.4", default-features = false, features = ["std", "now"]} 41 46 clap = { version = "4.5", features = ["derive", "env"] } 42 - ecdsa = { version = "0.16.9", features = ["std"] } 43 - elliptic-curve = { version = "0.13.8", features = ["jwk", "serde"] } 47 + ecdsa = { version = "0.16", features = ["std"] } 48 + elliptic-curve = { version = "0.13", features = ["jwk", "serde"] } 44 49 futures = "0.3" 45 50 hickory-resolver = { version = "0.25" } 46 - http = "1.3.1" 47 - k256 = "0.13.4" 51 + http = "1.3" 52 + k256 = "0.13" 48 53 lru = "0.12" 49 - multibase = "0.9.1" 50 - p256 = "0.13.2" 51 - p384 = "0.13.0" 54 + multibase = "0.9" 55 + p256 = "0.13" 56 + p384 = "0.13" 52 57 rand = "0.8" 58 + regex = "1.11" 53 59 reqwest = { version = "0.12", default-features = false, features = ["charset", "http2", "system-proxy", "json", "rustls-tls"] } 54 - reqwest-chain = "1.0.0" 55 - reqwest-middleware = { version = "0.4.2", features = ["json", "multipart"]} 60 + reqwest-chain = "1.0" 61 + reqwest-middleware = { version = "0.4", features = ["json", "multipart"]} 56 62 rpassword = "7.3" 57 63 secrecy = { version = "0.10", features = ["serde"] } 58 64 serde = { version = "1.0", features = ["derive"] } 59 - serde_ipld_dagcbor = "0.6.3" 60 - serde_json = "1.0" 61 - sha2 = "0.10.9" 65 + serde_ipld_dagcbor = "0.6" 66 + serde_json = { version = "1.0", features = ["unbounded_depth"] } 67 + sha2 = "0.10" 62 68 thiserror = "2.0" 63 69 tokio = { version = "1.41", features = ["macros", "rt", "rt-multi-thread"] } 64 70 tokio-websockets = { version = "0.11.4", features = ["client", "rustls-native-roots", "rand", "ring"] } 65 71 tokio-util = "0.7" 66 72 tracing = { version = "0.1", features = ["async-await"] } 67 - ulid = "1.2.1" 73 + ulid = "1.2" 68 74 zstd = "0.13" 69 75 url = "2.5" 70 76 urlencoding = "2.1" 71 77 72 - zeroize = { version = "1.8.1", features = ["zeroize_derive"] } 78 + zeroize = { version = "1.8", features = ["zeroize_derive"] } 73 79 74 80 [workspace.lints.rust] 75 81 unsafe_code = "forbid"
+4 -4
README.md
··· 131 131 ### XRPC Service 132 132 133 133 ```rust 134 - use atproto_xrpcs::authorization::ResolvingAuthorization; 134 + use atproto_xrpcs::authorization::Authorization; 135 135 use axum::{Json, Router, extract::Query, routing::get}; 136 136 use serde::Deserialize; 137 137 use serde_json::json; ··· 143 143 144 144 async fn handle_hello( 145 145 params: Query<HelloParams>, 146 - authorization: Option<ResolvingAuthorization>, 146 + authorization: Option<Authorization>, 147 147 ) -> Json<serde_json::Value> { 148 148 let subject = params.subject.as_deref().unwrap_or("World"); 149 - 149 + 150 150 let message = if let Some(auth) = authorization { 151 151 format!("Hello, authenticated {}! (caller: {})", subject, auth.subject()) 152 152 } else { 153 153 format!("Hello, {}!", subject) 154 154 }; 155 - 155 + 156 156 Json(json!({ "message": message })) 157 157 } 158 158
+19 -6
crates/atproto-attestation/src/bin/atproto-attestation-verify.rs
··· 47 47 48 48 use anyhow::{Context, Result, anyhow}; 49 49 use atproto_attestation::AnyInput; 50 - use atproto_identity::key::{KeyData, KeyResolver}; 50 + use atproto_identity::key::{KeyData, KeyResolver, identify_key}; 51 51 use clap::Parser; 52 52 use serde_json::Value; 53 53 use std::{ ··· 115 115 attestation: Option<String>, 116 116 } 117 117 118 - struct FakeKeyResolver {} 118 + /// A key resolver that supports `did:key:` identifiers directly. 119 + /// 120 + /// This resolver handles key references that are encoded as `did:key:` strings, 121 + /// parsing them to extract the cryptographic key data. For other DID methods, 122 + /// it returns an error since those would require fetching DID documents. 123 + struct DidKeyResolver {} 119 124 120 125 #[async_trait::async_trait] 121 - impl KeyResolver for FakeKeyResolver { 122 - async fn resolve(&self, _subject: &str) -> Result<KeyData> { 123 - todo!() 126 + impl KeyResolver for DidKeyResolver { 127 + async fn resolve(&self, subject: &str) -> Result<KeyData> { 128 + if subject.starts_with("did:key:") { 129 + identify_key(subject) 130 + .map_err(|e| anyhow!("Failed to parse did:key '{}': {}", subject, e)) 131 + } else { 132 + Err(anyhow!( 133 + "Subject '{}' is not a did:key: identifier. Only did:key: subjects are supported by this resolver.", 134 + subject 135 + )) 136 + } 124 137 } 125 138 } 126 139 ··· 175 188 identity_resolver, 176 189 }; 177 190 178 - let key_resolver = FakeKeyResolver {}; 191 + let key_resolver = DidKeyResolver {}; 179 192 180 193 atproto_attestation::verify_record( 181 194 AnyInput::Serialize(record.clone()),
+6
crates/atproto-client/Cargo.toml
··· 35 35 doc = true 36 36 required-features = ["clap"] 37 37 38 + [[bin]] 39 + name = "atproto-client-put-record" 40 + test = false 41 + bench = false 42 + doc = true 43 + 38 44 [dependencies] 39 45 atproto-identity.workspace = true 40 46 atproto-oauth.workspace = true
+165
crates/atproto-client/src/bin/atproto-client-put-record.rs
··· 1 + //! AT Protocol client tool for writing records to a repository. 2 + //! 3 + //! This binary tool creates or updates records in an AT Protocol repository 4 + //! using app password authentication. It resolves the subject to a DID, 5 + //! creates a session, and writes the record using the putRecord XRPC method. 6 + //! 7 + //! # Usage 8 + //! 9 + //! ```text 10 + //! ATPROTO_PASSWORD=<password> atproto-client-put-record <subject> <record_key> <record_json> 11 + //! ``` 12 + //! 13 + //! # Environment Variables 14 + //! 15 + //! - `ATPROTO_PASSWORD` - Required. App password for authentication. 16 + //! - `CERTIFICATE_BUNDLES` - Custom CA certificate bundles. 17 + //! - `USER_AGENT` - Custom user agent string. 18 + //! - `DNS_NAMESERVERS` - Custom DNS nameservers. 19 + //! - `PLC_HOSTNAME` - Override PLC hostname (default: plc.directory). 20 + 21 + use anyhow::Result; 22 + use atproto_client::{ 23 + client::{AppPasswordAuth, Auth}, 24 + com::atproto::{ 25 + repo::{put_record, PutRecordRequest, PutRecordResponse}, 26 + server::create_session, 27 + }, 28 + errors::CliError, 29 + }; 30 + use atproto_identity::{ 31 + config::{CertificateBundles, DnsNameservers, default_env, optional_env, version}, 32 + plc, 33 + resolve::{HickoryDnsResolver, resolve_subject}, 34 + web, 35 + }; 36 + use std::env; 37 + 38 + fn print_usage() { 39 + eprintln!("Usage: atproto-client-put-record <subject> <record_key> <record_json>"); 40 + eprintln!(); 41 + eprintln!("Arguments:"); 42 + eprintln!(" <subject> Handle or DID of the repository owner"); 43 + eprintln!(" <record_key> Record key (rkey) for the record"); 44 + eprintln!(" <record_json> JSON record data (must include $type field)"); 45 + eprintln!(); 46 + eprintln!("Environment Variables:"); 47 + eprintln!(" ATPROTO_PASSWORD Required. App password for authentication."); 48 + eprintln!(" CERTIFICATE_BUNDLES Custom CA certificate bundles."); 49 + eprintln!(" USER_AGENT Custom user agent string."); 50 + eprintln!(" DNS_NAMESERVERS Custom DNS nameservers."); 51 + eprintln!(" PLC_HOSTNAME Override PLC hostname (default: plc.directory)."); 52 + } 53 + 54 + #[tokio::main] 55 + async fn main() -> Result<()> { 56 + let args: Vec<String> = env::args().collect(); 57 + 58 + if args.len() != 4 { 59 + print_usage(); 60 + std::process::exit(1); 61 + } 62 + 63 + let subject = &args[1]; 64 + let record_key = &args[2]; 65 + let record_json = &args[3]; 66 + 67 + // Get password from environment variable 68 + let password = env::var("ATPROTO_PASSWORD").map_err(|_| { 69 + anyhow::anyhow!("ATPROTO_PASSWORD environment variable is required") 70 + })?; 71 + 72 + // Set up HTTP client configuration 73 + let certificate_bundles: CertificateBundles = optional_env("CERTIFICATE_BUNDLES").try_into()?; 74 + let default_user_agent = format!( 75 + "atproto-identity-rs ({}; +https://tangled.sh/@smokesignal.events/atproto-identity-rs)", 76 + version()? 77 + ); 78 + let user_agent = default_env("USER_AGENT", &default_user_agent); 79 + let dns_nameservers: DnsNameservers = optional_env("DNS_NAMESERVERS").try_into()?; 80 + let plc_hostname = default_env("PLC_HOSTNAME", "plc.directory"); 81 + 82 + let mut client_builder = reqwest::Client::builder(); 83 + for ca_certificate in certificate_bundles.as_ref() { 84 + let cert = std::fs::read(ca_certificate)?; 85 + let cert = reqwest::Certificate::from_pem(&cert)?; 86 + client_builder = client_builder.add_root_certificate(cert); 87 + } 88 + 89 + client_builder = client_builder.user_agent(user_agent); 90 + let http_client = client_builder.build()?; 91 + 92 + let dns_resolver = HickoryDnsResolver::create_resolver(dns_nameservers.as_ref()); 93 + 94 + // Parse the record JSON 95 + let record: serde_json::Value = serde_json::from_str(record_json).map_err(|err| { 96 + tracing::error!(error = ?err, "Failed to parse record JSON"); 97 + anyhow::anyhow!("Failed to parse record JSON: {}", err) 98 + })?; 99 + 100 + // Extract collection from $type field 101 + let collection = record 102 + .get("$type") 103 + .and_then(|v| v.as_str()) 104 + .ok_or_else(|| anyhow::anyhow!("Record must contain a $type field for the collection"))? 105 + .to_string(); 106 + 107 + // Resolve subject to DID 108 + let did = resolve_subject(&http_client, &dns_resolver, subject).await?; 109 + 110 + // Get DID document to find PDS endpoint 111 + let document = if did.starts_with("did:plc:") { 112 + plc::query(&http_client, &plc_hostname, &did).await? 113 + } else if did.starts_with("did:web:") { 114 + web::query(&http_client, &did).await? 115 + } else { 116 + anyhow::bail!("Unsupported DID method: {}", did); 117 + }; 118 + 119 + // Get PDS endpoint from the DID document 120 + let pds_endpoints = document.pds_endpoints(); 121 + let pds_endpoint = pds_endpoints 122 + .first() 123 + .ok_or_else(|| CliError::NoPdsEndpointFound { did: did.clone() })?; 124 + 125 + // Create session 126 + let session = create_session(&http_client, pds_endpoint, &did, &password, None).await?; 127 + 128 + // Set up app password authentication 129 + let auth = Auth::AppPassword(AppPasswordAuth { 130 + access_token: session.access_jwt.clone(), 131 + }); 132 + 133 + // Create put record request 134 + let put_request = PutRecordRequest { 135 + repo: session.did.clone(), 136 + collection, 137 + record_key: record_key.clone(), 138 + validate: true, 139 + record, 140 + swap_commit: None, 141 + swap_record: None, 142 + }; 143 + 144 + // Execute put record 145 + let response = put_record(&http_client, &auth, pds_endpoint, put_request).await?; 146 + 147 + match response { 148 + PutRecordResponse::StrongRef { uri, cid, .. } => { 149 + println!( 150 + "{}", 151 + serde_json::to_string_pretty(&serde_json::json!({ 152 + "uri": uri, 153 + "cid": cid 154 + }))? 155 + ); 156 + } 157 + PutRecordResponse::Error(err) => { 158 + let error_message = err.error_message(); 159 + tracing::error!(error = %error_message, "putRecord failed"); 160 + anyhow::bail!("putRecord failed: {}", error_message); 161 + } 162 + } 163 + 164 + Ok(()) 165 + }
+31 -5
crates/atproto-client/src/record_resolver.rs
··· 1 1 //! Helpers for resolving AT Protocol records referenced by URI. 2 2 3 3 use std::str::FromStr; 4 + use std::sync::Arc; 4 5 5 6 use anyhow::{Result, anyhow, bail}; 6 7 use async_trait::async_trait; 8 + use atproto_identity::traits::IdentityResolver; 7 9 use atproto_record::aturi::ATURI; 8 10 9 11 use crate::{ ··· 24 26 } 25 27 26 28 /// Resolver that fetches records using public XRPC endpoints. 29 + /// 30 + /// Uses an identity resolver to dynamically determine the PDS endpoint for each record. 27 31 #[derive(Clone)] 28 32 pub struct HttpRecordResolver { 29 33 http_client: reqwest::Client, 30 - base_url: String, 34 + identity_resolver: Arc<dyn IdentityResolver>, 31 35 } 32 36 33 37 impl HttpRecordResolver { 34 - /// Create a new resolver using the provided HTTP client and PDS base URL. 35 - pub fn new(http_client: reqwest::Client, base_url: impl Into<String>) -> Self { 38 + /// Create a new resolver using the provided HTTP client and identity resolver. 39 + /// 40 + /// The identity resolver is used to dynamically determine the PDS endpoint for each record 41 + /// based on the authority (DID or handle) in the AT URI. 42 + pub fn new( 43 + http_client: reqwest::Client, 44 + identity_resolver: Arc<dyn IdentityResolver>, 45 + ) -> Self { 36 46 Self { 37 47 http_client, 38 - base_url: base_url.into(), 48 + identity_resolver, 39 49 } 40 50 } 41 51 } ··· 47 57 T: serde::de::DeserializeOwned + Send, 48 58 { 49 59 let parsed = ATURI::from_str(aturi).map_err(|error| anyhow!(error))?; 60 + 61 + // Resolve the authority (DID or handle) to get the DID document 62 + let document = self 63 + .identity_resolver 64 + .resolve(&parsed.authority) 65 + .await 66 + .map_err(|error| { 67 + anyhow!("Failed to resolve identity for {}: {}", parsed.authority, error) 68 + })?; 69 + 70 + // Extract PDS endpoint from the DID document 71 + let pds_endpoints = document.pds_endpoints(); 72 + let base_url = pds_endpoints 73 + .first() 74 + .ok_or_else(|| anyhow!("No PDS endpoint found for {}", parsed.authority))?; 75 + 50 76 let auth = Auth::None; 51 77 52 78 let response = get_record( 53 79 &self.http_client, 54 80 &auth, 55 - &self.base_url, 81 + base_url, 56 82 &parsed.authority, 57 83 &parsed.collection, 58 84 &parsed.record_key,
+43
crates/atproto-extras/Cargo.toml
··· 1 + [package] 2 + name = "atproto-extras" 3 + version = "0.13.0" 4 + description = "AT Protocol extras - facet parsing and rich text utilities" 5 + readme = "README.md" 6 + homepage = "https://tangled.sh/@smokesignal.events/atproto-identity-rs" 7 + documentation = "https://docs.rs/atproto-extras" 8 + 9 + edition.workspace = true 10 + rust-version.workspace = true 11 + authors.workspace = true 12 + repository.workspace = true 13 + license.workspace = true 14 + keywords.workspace = true 15 + categories.workspace = true 16 + 17 + [dependencies] 18 + atproto-identity.workspace = true 19 + atproto-record.workspace = true 20 + 21 + anyhow.workspace = true 22 + async-trait.workspace = true 23 + clap = { workspace = true, optional = true } 24 + regex.workspace = true 25 + reqwest = { workspace = true, optional = true } 26 + serde_json = { workspace = true, optional = true } 27 + tokio = { workspace = true, optional = true } 28 + 29 + [dev-dependencies] 30 + tokio = { workspace = true, features = ["macros", "rt"] } 31 + 32 + [features] 33 + default = ["hickory-dns"] 34 + hickory-dns = ["atproto-identity/hickory-dns"] 35 + clap = ["dep:clap"] 36 + cli = ["dep:clap", "dep:serde_json", "dep:tokio", "dep:reqwest"] 37 + 38 + [[bin]] 39 + name = "atproto-extras-parse-facets" 40 + required-features = ["clap", "cli", "hickory-dns"] 41 + 42 + [lints] 43 + workspace = true
+128
crates/atproto-extras/README.md
··· 1 + # atproto-extras 2 + 3 + Extra utilities for AT Protocol applications, including rich text facet parsing. 4 + 5 + ## Features 6 + 7 + - **Facet Parsing**: Extract mentions (`@handle`), URLs, and hashtags (`#tag`) from plain text with correct UTF-8 byte offset calculation 8 + - **Identity Integration**: Resolve mention handles to DIDs during parsing 9 + 10 + ## Installation 11 + 12 + Add to your `Cargo.toml`: 13 + 14 + ```toml 15 + [dependencies] 16 + atproto-extras = "0.13" 17 + ``` 18 + 19 + ## Usage 20 + 21 + ### Parsing Text for Facets 22 + 23 + ```rust 24 + use atproto_extras::{parse_urls, parse_tags}; 25 + use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature; 26 + 27 + let text = "Check out https://example.com #rust"; 28 + 29 + // Parse URLs and tags - returns Vec<Facet> directly 30 + let url_facets = parse_urls(text); 31 + let tag_facets = parse_tags(text); 32 + 33 + // Each facet includes byte positions and typed features 34 + for facet in url_facets { 35 + if let Some(FacetFeature::Link(link)) = facet.features.first() { 36 + println!("URL at bytes {}..{}: {}", 37 + facet.index.byte_start, facet.index.byte_end, link.uri); 38 + } 39 + } 40 + 41 + for facet in tag_facets { 42 + if let Some(FacetFeature::Tag(tag)) = facet.features.first() { 43 + println!("Tag at bytes {}..{}: #{}", 44 + facet.index.byte_start, facet.index.byte_end, tag.tag); 45 + } 46 + } 47 + ``` 48 + 49 + ### Parsing Mentions 50 + 51 + Mention parsing requires an `IdentityResolver` to convert handles to DIDs: 52 + 53 + ```rust 54 + use atproto_extras::{parse_mentions, FacetLimits}; 55 + use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature; 56 + 57 + let text = "Hello @alice.bsky.social!"; 58 + let limits = FacetLimits::default(); 59 + 60 + // Requires an async context and IdentityResolver 61 + let facets = parse_mentions(text, &resolver, &limits).await; 62 + 63 + for facet in facets { 64 + if let Some(FacetFeature::Mention(mention)) = facet.features.first() { 65 + println!("Mention at bytes {}..{} resolved to {}", 66 + facet.index.byte_start, facet.index.byte_end, mention.did); 67 + } 68 + } 69 + ``` 70 + 71 + Mentions that cannot be resolved to a valid DID are automatically skipped. Mentions appearing within URLs are also excluded. 72 + 73 + ### Creating AT Protocol Facets 74 + 75 + ```rust 76 + use atproto_extras::{parse_facets_from_text, FacetLimits}; 77 + 78 + let text = "Hello @alice.bsky.social! Check https://rust-lang.org #rust"; 79 + let limits = FacetLimits::default(); 80 + 81 + // Requires an async context and IdentityResolver 82 + let facets = parse_facets_from_text(text, &resolver, &limits).await; 83 + 84 + if let Some(facets) = facets { 85 + for facet in &facets { 86 + println!("Facet at {}..{}", facet.index.byte_start, facet.index.byte_end); 87 + } 88 + } 89 + ``` 90 + 91 + ## Byte Offset Handling 92 + 93 + AT Protocol facets use UTF-8 byte offsets, not character indices. This is critical for correct handling of multi-byte characters like emojis or non-ASCII text. 94 + 95 + ```rust 96 + use atproto_extras::parse_urls; 97 + 98 + // Text with emojis (multi-byte UTF-8 characters) 99 + let text = "โœจ Check https://example.com โœจ"; 100 + 101 + let facets = parse_urls(text); 102 + // Byte positions correctly account for the 4-byte emoji 103 + assert_eq!(facets[0].index.byte_start, 11); // After "โœจ Check " (4 + 1 + 6 = 11 bytes) 104 + ``` 105 + 106 + ## Facet Limits 107 + 108 + Use `FacetLimits` to control the maximum number of facets processed: 109 + 110 + ```rust 111 + use atproto_extras::FacetLimits; 112 + 113 + // Default limits 114 + let limits = FacetLimits::default(); 115 + // mentions_max: 5, tags_max: 5, links_max: 5, max: 10 116 + 117 + // Custom limits 118 + let custom = FacetLimits { 119 + mentions_max: 10, 120 + tags_max: 10, 121 + links_max: 10, 122 + max: 20, 123 + }; 124 + ``` 125 + 126 + ## License 127 + 128 + MIT
+176
crates/atproto-extras/src/bin/atproto-extras-parse-facets.rs
··· 1 + //! Command-line tool for generating AT Protocol facet arrays from text. 2 + //! 3 + //! This tool parses a string and outputs the facet array in JSON format. 4 + //! Facets include mentions (@handle), URLs (https://...), and hashtags (#tag). 5 + //! 6 + //! By default, mentions are detected but output with placeholder DIDs. Use 7 + //! `--resolve-mentions` to resolve handles to actual DIDs (requires network access). 8 + //! 9 + //! # Usage 10 + //! 11 + //! ```bash 12 + //! # Parse facets without resolving mentions 13 + //! cargo run --features clap,serde_json,tokio,hickory-dns --bin atproto-extras-parse-facets -- "Check out https://example.com and #rust" 14 + //! 15 + //! # Resolve mentions to DIDs 16 + //! cargo run --features clap,serde_json,tokio,hickory-dns --bin atproto-extras-parse-facets -- --resolve-mentions "Hello @bsky.app!" 17 + //! ``` 18 + 19 + use atproto_extras::{FacetLimits, parse_mentions, parse_tags, parse_urls}; 20 + use atproto_identity::resolve::{HickoryDnsResolver, InnerIdentityResolver}; 21 + use atproto_record::lexicon::app::bsky::richtext::facet::{ 22 + ByteSlice, Facet, FacetFeature, Mention, 23 + }; 24 + use clap::Parser; 25 + use regex::bytes::Regex; 26 + use std::sync::Arc; 27 + 28 + /// Parse text and output AT Protocol facets as JSON. 29 + #[derive(Parser)] 30 + #[command( 31 + name = "atproto-extras-parse-facets", 32 + version, 33 + about = "Parse text and output AT Protocol facets as JSON", 34 + long_about = "This tool parses a string for mentions, URLs, and hashtags,\n\ 35 + then outputs the corresponding AT Protocol facet array in JSON format.\n\n\ 36 + By default, mentions are detected but output with placeholder DIDs.\n\ 37 + Use --resolve-mentions to resolve handles to actual DIDs (requires network)." 38 + )] 39 + struct Args { 40 + /// The text to parse for facets 41 + text: String, 42 + 43 + /// Resolve mention handles to DIDs (requires network access) 44 + #[arg(long)] 45 + resolve_mentions: bool, 46 + 47 + /// Show debug information on stderr 48 + #[arg(long, short = 'd')] 49 + debug: bool, 50 + } 51 + 52 + /// Parse mention spans from text without resolution (returns placeholder DIDs). 53 + fn parse_mention_spans(text: &str) -> Vec<Facet> { 54 + let mut facets = Vec::new(); 55 + 56 + // Get URL ranges to exclude mentions within URLs 57 + let url_facets = parse_urls(text); 58 + 59 + // Same regex pattern as parse_mentions 60 + let mention_regex = Regex::new( 61 + r"(?:^|[^\w])(@([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)", 62 + ) 63 + .expect("Invalid mention regex"); 64 + 65 + let text_bytes = text.as_bytes(); 66 + 67 + for capture in mention_regex.captures_iter(text_bytes) { 68 + if let Some(mention_match) = capture.get(1) { 69 + let start = mention_match.start(); 70 + let end = mention_match.end(); 71 + 72 + // Check if this mention overlaps with any URL 73 + let overlaps_url = url_facets.iter().any(|facet| { 74 + (start >= facet.index.byte_start && start < facet.index.byte_end) 75 + || (end > facet.index.byte_start && end <= facet.index.byte_end) 76 + }); 77 + 78 + if !overlaps_url { 79 + let handle = std::str::from_utf8(&mention_match.as_bytes()[1..]) 80 + .unwrap_or_default() 81 + .to_string(); 82 + 83 + facets.push(Facet { 84 + index: ByteSlice { 85 + byte_start: start, 86 + byte_end: end, 87 + }, 88 + features: vec![FacetFeature::Mention(Mention { 89 + did: format!("did:plc:<unresolved:{}>", handle), 90 + })], 91 + }); 92 + } 93 + } 94 + } 95 + 96 + facets 97 + } 98 + 99 + #[tokio::main] 100 + async fn main() { 101 + let args = Args::parse(); 102 + let text = &args.text; 103 + let mut facets: Vec<Facet> = Vec::new(); 104 + let limits = FacetLimits::default(); 105 + 106 + // Parse mentions (either resolved or with placeholders) 107 + if args.resolve_mentions { 108 + let http_client = reqwest::Client::new(); 109 + let dns_resolver = HickoryDnsResolver::create_resolver(&[]); 110 + let resolver = InnerIdentityResolver { 111 + http_client, 112 + dns_resolver: Arc::new(dns_resolver), 113 + plc_hostname: "plc.directory".to_string(), 114 + }; 115 + let mention_facets = parse_mentions(text, &resolver, &limits).await; 116 + facets.extend(mention_facets); 117 + } else { 118 + let mention_facets = parse_mention_spans(text); 119 + facets.extend(mention_facets); 120 + } 121 + 122 + // Parse URLs 123 + let url_facets = parse_urls(text); 124 + facets.extend(url_facets); 125 + 126 + // Parse hashtags 127 + let tag_facets = parse_tags(text); 128 + facets.extend(tag_facets); 129 + 130 + // Sort facets by byte_start for consistent output 131 + facets.sort_by_key(|f| f.index.byte_start); 132 + 133 + // Output as JSON 134 + if facets.is_empty() { 135 + println!("null"); 136 + } else { 137 + match serde_json::to_string_pretty(&facets) { 138 + Ok(json) => println!("{}", json), 139 + Err(e) => { 140 + eprintln!( 141 + "error-atproto-extras-parse-facets-1 Error serializing facets: {}", 142 + e 143 + ); 144 + std::process::exit(1); 145 + } 146 + } 147 + } 148 + 149 + // Show debug info if requested 150 + if args.debug { 151 + eprintln!(); 152 + eprintln!("--- Debug Info ---"); 153 + eprintln!("Input text: {:?}", text); 154 + eprintln!("Text length: {} bytes", text.len()); 155 + eprintln!("Facets found: {}", facets.len()); 156 + eprintln!("Mentions resolved: {}", args.resolve_mentions); 157 + 158 + // Show byte slice verification 159 + let text_bytes = text.as_bytes(); 160 + for (i, facet) in facets.iter().enumerate() { 161 + let start = facet.index.byte_start; 162 + let end = facet.index.byte_end; 163 + let slice_text = 164 + std::str::from_utf8(&text_bytes[start..end]).unwrap_or("<invalid utf8>"); 165 + let feature_type = match &facet.features[0] { 166 + FacetFeature::Mention(_) => "mention", 167 + FacetFeature::Link(_) => "link", 168 + FacetFeature::Tag(_) => "tag", 169 + }; 170 + eprintln!( 171 + " [{}] {} @ bytes {}..{}: {:?}", 172 + i, feature_type, start, end, slice_text 173 + ); 174 + } 175 + } 176 + }
+942
crates/atproto-extras/src/facets.rs
··· 1 + //! Rich text facet parsing for AT Protocol. 2 + //! 3 + //! This module provides functionality for extracting semantic annotations (facets) 4 + //! from plain text. Facets include mentions, links (URLs), and hashtags. 5 + //! 6 + //! # Overview 7 + //! 8 + //! AT Protocol rich text uses "facets" to annotate specific byte ranges within text with 9 + //! semantic meaning. This module handles: 10 + //! 11 + //! - **Parsing**: Extract mentions, URLs, and hashtags from plain text 12 + //! - **Facet Creation**: Build proper AT Protocol facet structures with resolved DIDs 13 + //! 14 + //! # Byte Offset Calculation 15 + //! 16 + //! This implementation correctly uses UTF-8 byte offsets as required by AT Protocol. 17 + //! The facets use "inclusive start and exclusive end" byte ranges. All parsing is done 18 + //! using `regex::bytes::Regex` which operates on byte slices and returns byte positions, 19 + //! ensuring correct handling of multi-byte UTF-8 characters (emojis, CJK, accented chars). 20 + //! 21 + //! # Example 22 + //! 23 + //! ```ignore 24 + //! use atproto_extras::facets::{parse_urls, parse_tags, FacetLimits}; 25 + //! use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature; 26 + //! 27 + //! let text = "Check out https://example.com #rust"; 28 + //! 29 + //! // Parse URLs and tags as Facet objects 30 + //! let url_facets = parse_urls(text); 31 + //! let tag_facets = parse_tags(text); 32 + //! 33 + //! // Access facet data directly 34 + //! for facet in url_facets { 35 + //! if let Some(FacetFeature::Link(link)) = facet.features.first() { 36 + //! println!("URL at bytes {}..{}: {}", 37 + //! facet.index.byte_start, facet.index.byte_end, link.uri); 38 + //! } 39 + //! } 40 + //! ``` 41 + 42 + use atproto_identity::resolve::IdentityResolver; 43 + use atproto_record::lexicon::app::bsky::richtext::facet::{ 44 + ByteSlice, Facet, FacetFeature, Link, Mention, Tag, 45 + }; 46 + use regex::bytes::Regex; 47 + 48 + /// Configuration for facet parsing limits. 49 + /// 50 + /// These limits protect against abuse by capping the number of facets 51 + /// that will be processed. This is important for both performance and 52 + /// security when handling user-generated content. 53 + /// 54 + /// # Example 55 + /// 56 + /// ``` 57 + /// use atproto_extras::FacetLimits; 58 + /// 59 + /// // Use defaults 60 + /// let limits = FacetLimits::default(); 61 + /// 62 + /// // Or customize 63 + /// let custom = FacetLimits { 64 + /// mentions_max: 10, 65 + /// tags_max: 10, 66 + /// links_max: 10, 67 + /// max: 20, 68 + /// }; 69 + /// ``` 70 + #[derive(Debug, Clone, Copy)] 71 + pub struct FacetLimits { 72 + /// Maximum number of mention facets to process (default: 5) 73 + pub mentions_max: usize, 74 + /// Maximum number of tag facets to process (default: 5) 75 + pub tags_max: usize, 76 + /// Maximum number of link facets to process (default: 5) 77 + pub links_max: usize, 78 + /// Maximum total number of facets to process (default: 10) 79 + pub max: usize, 80 + } 81 + 82 + impl Default for FacetLimits { 83 + fn default() -> Self { 84 + Self { 85 + mentions_max: 5, 86 + tags_max: 5, 87 + links_max: 5, 88 + max: 10, 89 + } 90 + } 91 + } 92 + 93 + /// Parse mentions from text and return them as Facet objects with resolved DIDs. 94 + /// 95 + /// This function extracts AT Protocol handle mentions (e.g., `@alice.bsky.social`) 96 + /// from text, resolves each handle to a DID using the provided identity resolver, 97 + /// and returns AT Protocol Facet objects with Mention features. 98 + /// 99 + /// Mentions that cannot be resolved to a valid DID are skipped. Mentions that 100 + /// appear within URLs are also excluded to avoid false positives. 101 + /// 102 + /// # Arguments 103 + /// 104 + /// * `text` - The text to parse for mentions 105 + /// * `identity_resolver` - Resolver for converting handles to DIDs 106 + /// * `limits` - Configuration for maximum mentions to process 107 + /// 108 + /// # Returns 109 + /// 110 + /// A vector of Facet objects for successfully resolved mentions. 111 + /// 112 + /// # Example 113 + /// 114 + /// ```ignore 115 + /// use atproto_extras::{parse_mentions, FacetLimits}; 116 + /// use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature; 117 + /// 118 + /// let text = "Hello @alice.bsky.social!"; 119 + /// let limits = FacetLimits::default(); 120 + /// 121 + /// // Requires an async context and identity resolver 122 + /// let facets = parse_mentions(text, &resolver, &limits).await; 123 + /// 124 + /// for facet in facets { 125 + /// if let Some(FacetFeature::Mention(mention)) = facet.features.first() { 126 + /// println!("Mention {} resolved to {}", 127 + /// &text[facet.index.byte_start..facet.index.byte_end], 128 + /// mention.did); 129 + /// } 130 + /// } 131 + /// ``` 132 + pub async fn parse_mentions( 133 + text: &str, 134 + identity_resolver: &dyn IdentityResolver, 135 + limits: &FacetLimits, 136 + ) -> Vec<Facet> { 137 + let mut facets = Vec::new(); 138 + 139 + // First, parse all URLs to exclude mention matches within them 140 + let url_facets = parse_urls(text); 141 + 142 + // Regex based on: https://atproto.com/specs/handle#handle-identifier-syntax 143 + // Pattern: [$|\W](@([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?) 144 + let mention_regex = Regex::new( 145 + r"(?:^|[^\w])(@([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)", 146 + ) 147 + .unwrap(); 148 + 149 + let text_bytes = text.as_bytes(); 150 + let mut mention_count = 0; 151 + 152 + for capture in mention_regex.captures_iter(text_bytes) { 153 + if mention_count >= limits.mentions_max { 154 + break; 155 + } 156 + 157 + if let Some(mention_match) = capture.get(1) { 158 + let start = mention_match.start(); 159 + let end = mention_match.end(); 160 + 161 + // Check if this mention overlaps with any URL 162 + let overlaps_url = url_facets.iter().any(|facet| { 163 + // Check if mention is within or overlaps the URL span 164 + (start >= facet.index.byte_start && start < facet.index.byte_end) 165 + || (end > facet.index.byte_start && end <= facet.index.byte_end) 166 + }); 167 + 168 + // Only process the mention if it doesn't overlap with a URL 169 + if !overlaps_url { 170 + let handle = std::str::from_utf8(&mention_match.as_bytes()[1..]) 171 + .unwrap_or_default() 172 + .to_string(); 173 + 174 + // Try to resolve the handle to a DID 175 + // First try with at:// prefix, then without 176 + let at_uri = format!("at://{}", handle); 177 + let did_result = match identity_resolver.resolve(&at_uri).await { 178 + Ok(doc) => Ok(doc), 179 + Err(_) => identity_resolver.resolve(&handle).await, 180 + }; 181 + 182 + // Only add the mention facet if we successfully resolved the DID 183 + if let Ok(did_doc) = did_result { 184 + facets.push(Facet { 185 + index: ByteSlice { 186 + byte_start: start, 187 + byte_end: end, 188 + }, 189 + features: vec![FacetFeature::Mention(Mention { 190 + did: did_doc.id.to_string(), 191 + })], 192 + }); 193 + mention_count += 1; 194 + } 195 + } 196 + } 197 + } 198 + 199 + facets 200 + } 201 + 202 + /// Parse URLs from text and return them as Facet objects. 203 + /// 204 + /// This function extracts HTTP and HTTPS URLs from text with correct 205 + /// byte position tracking for UTF-8 text, returning AT Protocol Facet objects 206 + /// with Link features. 207 + /// 208 + /// # Supported URL Patterns 209 + /// 210 + /// - HTTP URLs: `http://example.com` 211 + /// - HTTPS URLs: `https://example.com` 212 + /// - URLs with paths, query strings, and fragments 213 + /// - URLs with subdomains: `https://www.example.com` 214 + /// 215 + /// # Example 216 + /// 217 + /// ``` 218 + /// use atproto_extras::parse_urls; 219 + /// use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature; 220 + /// 221 + /// let text = "Visit https://example.com/path?query=1 for more info"; 222 + /// let facets = parse_urls(text); 223 + /// 224 + /// assert_eq!(facets.len(), 1); 225 + /// assert_eq!(facets[0].index.byte_start, 6); 226 + /// assert_eq!(facets[0].index.byte_end, 38); 227 + /// if let Some(FacetFeature::Link(link)) = facets[0].features.first() { 228 + /// assert_eq!(link.uri, "https://example.com/path?query=1"); 229 + /// } 230 + /// ``` 231 + /// 232 + /// # Multi-byte Character Handling 233 + /// 234 + /// Byte positions are correctly calculated even with emojis and other 235 + /// multi-byte UTF-8 characters: 236 + /// 237 + /// ``` 238 + /// use atproto_extras::parse_urls; 239 + /// use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature; 240 + /// 241 + /// let text = "Check out https://example.com now!"; 242 + /// let facets = parse_urls(text); 243 + /// let text_bytes = text.as_bytes(); 244 + /// 245 + /// // The byte slice matches the URL 246 + /// let url_bytes = &text_bytes[facets[0].index.byte_start..facets[0].index.byte_end]; 247 + /// assert_eq!(std::str::from_utf8(url_bytes).unwrap(), "https://example.com"); 248 + /// ``` 249 + pub fn parse_urls(text: &str) -> Vec<Facet> { 250 + let mut facets = Vec::new(); 251 + 252 + // Partial/naive URL regex based on: https://stackoverflow.com/a/3809435 253 + // Pattern: [$|\W](https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]+\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*[-a-zA-Z0-9@%_\+~#//=])?) 254 + // Modified to use + instead of {1,6} to support longer TLDs and multi-level subdomains 255 + let url_regex = Regex::new( 256 + r"(?:^|[^\w])(https?://(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]+\b(?:[-a-zA-Z0-9()@:%_\+.~#?&//=]*[-a-zA-Z0-9@%_\+~#//=])?)" 257 + ).unwrap(); 258 + 259 + let text_bytes = text.as_bytes(); 260 + for capture in url_regex.captures_iter(text_bytes) { 261 + if let Some(url_match) = capture.get(1) { 262 + let url = std::str::from_utf8(url_match.as_bytes()) 263 + .unwrap_or_default() 264 + .to_string(); 265 + 266 + facets.push(Facet { 267 + index: ByteSlice { 268 + byte_start: url_match.start(), 269 + byte_end: url_match.end(), 270 + }, 271 + features: vec![FacetFeature::Link(Link { uri: url })], 272 + }); 273 + } 274 + } 275 + 276 + facets 277 + } 278 + 279 + /// Parse hashtags from text and return them as Facet objects. 280 + /// 281 + /// This function extracts hashtags (e.g., `#rust`, `#ATProto`) from text, 282 + /// returning AT Protocol Facet objects with Tag features. 283 + /// It supports both standard `#` and full-width `๏ผƒ` (U+FF03) hash symbols. 284 + /// 285 + /// # Tag Syntax 286 + /// 287 + /// - Tags must start with `#` or `๏ผƒ` (full-width) 288 + /// - Tag content follows word character rules (`\w`) 289 + /// - Purely numeric tags (e.g., `#123`) are excluded 290 + /// 291 + /// # Example 292 + /// 293 + /// ``` 294 + /// use atproto_extras::parse_tags; 295 + /// use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature; 296 + /// 297 + /// let text = "Learning #rust and #golang today! #100DaysOfCode"; 298 + /// let facets = parse_tags(text); 299 + /// 300 + /// assert_eq!(facets.len(), 3); 301 + /// if let Some(FacetFeature::Tag(tag)) = facets[0].features.first() { 302 + /// assert_eq!(tag.tag, "rust"); 303 + /// } 304 + /// if let Some(FacetFeature::Tag(tag)) = facets[1].features.first() { 305 + /// assert_eq!(tag.tag, "golang"); 306 + /// } 307 + /// if let Some(FacetFeature::Tag(tag)) = facets[2].features.first() { 308 + /// assert_eq!(tag.tag, "100DaysOfCode"); 309 + /// } 310 + /// ``` 311 + /// 312 + /// # Numeric Tags 313 + /// 314 + /// Purely numeric tags are excluded: 315 + /// 316 + /// ``` 317 + /// use atproto_extras::parse_tags; 318 + /// 319 + /// let text = "Item #42 is special"; 320 + /// let facets = parse_tags(text); 321 + /// 322 + /// // #42 is not extracted because it's purely numeric 323 + /// assert_eq!(facets.len(), 0); 324 + /// ``` 325 + pub fn parse_tags(text: &str) -> Vec<Facet> { 326 + let mut facets = Vec::new(); 327 + 328 + // Regex based on: https://github.com/bluesky-social/atproto/blob/d91988fe79030b61b556dd6f16a46f0c3b9d0b44/packages/api/src/rich-text/util.ts 329 + // Simplified for Rust - matches hashtags at word boundaries 330 + // Pattern matches: start of string or non-word char, then # or ๏ผƒ, then tag content 331 + let tag_regex = Regex::new(r"(?:^|[^\w])([#\xEF\xBC\x83])([\w]+(?:[\w]*)*)").unwrap(); 332 + 333 + let text_bytes = text.as_bytes(); 334 + 335 + // Work with bytes for proper position tracking 336 + for capture in tag_regex.captures_iter(text_bytes) { 337 + if let (Some(full_match), Some(hash_match), Some(tag_match)) = 338 + (capture.get(0), capture.get(1), capture.get(2)) 339 + { 340 + // Calculate the absolute byte position of the hash symbol 341 + // The full match includes the preceding character (if any) 342 + // so we need to adjust for that 343 + let match_start = full_match.start(); 344 + let hash_offset = hash_match.start() - full_match.start(); 345 + let start = match_start + hash_offset; 346 + let end = match_start + hash_offset + hash_match.len() + tag_match.len(); 347 + 348 + // Extract just the tag text (without the hash symbol) 349 + let tag = std::str::from_utf8(tag_match.as_bytes()).unwrap_or_default(); 350 + 351 + // Only include tags that are not purely numeric 352 + if !tag.chars().all(|c| c.is_ascii_digit()) { 353 + facets.push(Facet { 354 + index: ByteSlice { 355 + byte_start: start, 356 + byte_end: end, 357 + }, 358 + features: vec![FacetFeature::Tag(Tag { 359 + tag: tag.to_string(), 360 + })], 361 + }); 362 + } 363 + } 364 + } 365 + 366 + facets 367 + } 368 + 369 + /// Parse facets from text and return a vector of Facet objects. 370 + /// 371 + /// This function extracts mentions, URLs, and hashtags from the provided text 372 + /// and creates AT Protocol facets with proper byte indices. 373 + /// 374 + /// Mentions are resolved to actual DIDs using the provided identity resolver. 375 + /// If a handle cannot be resolved to a DID, the mention facet is skipped. 376 + /// 377 + /// # Arguments 378 + /// 379 + /// * `text` - The text to extract facets from 380 + /// * `identity_resolver` - Resolver for converting handles to DIDs 381 + /// * `limits` - Configuration for maximum facets per type and total 382 + /// 383 + /// # Returns 384 + /// 385 + /// Optional vector of facets. Returns `None` if no facets were found. 386 + /// 387 + /// # Example 388 + /// 389 + /// ```ignore 390 + /// use atproto_extras::{parse_facets_from_text, FacetLimits}; 391 + /// 392 + /// let text = "Hello @alice.bsky.social! Check #rust at https://rust-lang.org"; 393 + /// let limits = FacetLimits::default(); 394 + /// 395 + /// // Requires an async context and identity resolver 396 + /// let facets = parse_facets_from_text(text, &resolver, &limits).await; 397 + /// 398 + /// if let Some(facets) = facets { 399 + /// for facet in &facets { 400 + /// println!("Facet at {}..{}", facet.index.byte_start, facet.index.byte_end); 401 + /// } 402 + /// } 403 + /// ``` 404 + /// 405 + /// # Mention Resolution 406 + /// 407 + /// Mentions are only included if the handle resolves to a valid DID: 408 + /// 409 + /// ```ignore 410 + /// let text = "@valid.handle.com and @invalid.handle.xyz"; 411 + /// let facets = parse_facets_from_text(text, &resolver, &limits).await; 412 + /// 413 + /// // Only @valid.handle.com appears as a facet if @invalid.handle.xyz 414 + /// // cannot be resolved to a DID 415 + /// ``` 416 + pub async fn parse_facets_from_text( 417 + text: &str, 418 + identity_resolver: &dyn IdentityResolver, 419 + limits: &FacetLimits, 420 + ) -> Option<Vec<Facet>> { 421 + let mut facets = Vec::new(); 422 + 423 + // Parse mentions (already limited by mentions_max in parse_mentions) 424 + let mention_facets = parse_mentions(text, identity_resolver, limits).await; 425 + facets.extend(mention_facets); 426 + 427 + // Parse URLs (limited by links_max) 428 + let url_facets = parse_urls(text); 429 + for (idx, facet) in url_facets.into_iter().enumerate() { 430 + if idx >= limits.links_max { 431 + break; 432 + } 433 + facets.push(facet); 434 + } 435 + 436 + // Parse hashtags (limited by tags_max) 437 + let tag_facets = parse_tags(text); 438 + for (idx, facet) in tag_facets.into_iter().enumerate() { 439 + if idx >= limits.tags_max { 440 + break; 441 + } 442 + facets.push(facet); 443 + } 444 + 445 + // Apply global facet limit (truncate if exceeds max) 446 + if facets.len() > limits.max { 447 + facets.truncate(limits.max); 448 + } 449 + 450 + // Only return facets if we found any 451 + if !facets.is_empty() { 452 + Some(facets) 453 + } else { 454 + None 455 + } 456 + } 457 + 458 + #[cfg(test)] 459 + mod tests { 460 + use async_trait::async_trait; 461 + use atproto_identity::model::Document; 462 + use std::collections::HashMap; 463 + 464 + use super::*; 465 + 466 + /// Mock identity resolver for testing 467 + struct MockIdentityResolver { 468 + handles_to_dids: HashMap<String, String>, 469 + } 470 + 471 + impl MockIdentityResolver { 472 + fn new() -> Self { 473 + let mut handles_to_dids = HashMap::new(); 474 + handles_to_dids.insert( 475 + "alice.bsky.social".to_string(), 476 + "did:plc:alice123".to_string(), 477 + ); 478 + handles_to_dids.insert( 479 + "at://alice.bsky.social".to_string(), 480 + "did:plc:alice123".to_string(), 481 + ); 482 + Self { handles_to_dids } 483 + } 484 + 485 + fn add_identity(&mut self, handle: &str, did: &str) { 486 + self.handles_to_dids 487 + .insert(handle.to_string(), did.to_string()); 488 + self.handles_to_dids 489 + .insert(format!("at://{}", handle), did.to_string()); 490 + } 491 + } 492 + 493 + #[async_trait] 494 + impl IdentityResolver for MockIdentityResolver { 495 + async fn resolve(&self, handle: &str) -> anyhow::Result<Document> { 496 + let handle_key = handle.to_string(); 497 + 498 + if let Some(did) = self.handles_to_dids.get(&handle_key) { 499 + Ok(Document { 500 + context: vec![], 501 + id: did.clone(), 502 + also_known_as: vec![format!("at://{}", handle_key.trim_start_matches("at://"))], 503 + verification_method: vec![], 504 + service: vec![], 505 + extra: HashMap::new(), 506 + }) 507 + } else { 508 + Err(anyhow::anyhow!("Handle not found")) 509 + } 510 + } 511 + } 512 + 513 + #[tokio::test] 514 + async fn test_parse_facets_from_text_comprehensive() { 515 + let mut resolver = MockIdentityResolver::new(); 516 + resolver.add_identity("bob.test.com", "did:plc:bob456"); 517 + 518 + let limits = FacetLimits::default(); 519 + let text = "Join @alice.bsky.social and @bob.test.com at https://example.com #rust #golang"; 520 + let facets = parse_facets_from_text(text, &resolver, &limits).await; 521 + 522 + assert!(facets.is_some()); 523 + let facets = facets.unwrap(); 524 + assert_eq!(facets.len(), 5); // 2 mentions, 1 URL, 2 hashtags 525 + 526 + // Check first mention 527 + assert_eq!(facets[0].index.byte_start, 5); 528 + assert_eq!(facets[0].index.byte_end, 23); 529 + if let FacetFeature::Mention(ref mention) = facets[0].features[0] { 530 + assert_eq!(mention.did, "did:plc:alice123"); 531 + } else { 532 + panic!("Expected Mention feature"); 533 + } 534 + 535 + // Check second mention 536 + assert_eq!(facets[1].index.byte_start, 28); 537 + assert_eq!(facets[1].index.byte_end, 41); 538 + if let FacetFeature::Mention(mention) = &facets[1].features[0] { 539 + assert_eq!(mention.did, "did:plc:bob456"); 540 + } else { 541 + panic!("Expected Mention feature"); 542 + } 543 + 544 + // Check URL 545 + assert_eq!(facets[2].index.byte_start, 45); 546 + assert_eq!(facets[2].index.byte_end, 64); 547 + if let FacetFeature::Link(link) = &facets[2].features[0] { 548 + assert_eq!(link.uri, "https://example.com"); 549 + } else { 550 + panic!("Expected Link feature"); 551 + } 552 + 553 + // Check first hashtag 554 + assert_eq!(facets[3].index.byte_start, 65); 555 + assert_eq!(facets[3].index.byte_end, 70); 556 + if let FacetFeature::Tag(tag) = &facets[3].features[0] { 557 + assert_eq!(tag.tag, "rust"); 558 + } else { 559 + panic!("Expected Tag feature"); 560 + } 561 + 562 + // Check second hashtag 563 + assert_eq!(facets[4].index.byte_start, 71); 564 + assert_eq!(facets[4].index.byte_end, 78); 565 + if let FacetFeature::Tag(tag) = &facets[4].features[0] { 566 + assert_eq!(tag.tag, "golang"); 567 + } else { 568 + panic!("Expected Tag feature"); 569 + } 570 + } 571 + 572 + #[tokio::test] 573 + async fn test_parse_facets_from_text_with_unresolvable_mention() { 574 + let resolver = MockIdentityResolver::new(); 575 + let limits = FacetLimits::default(); 576 + 577 + // Only alice.bsky.social is in the resolver, not unknown.handle.com 578 + let text = "Contact @unknown.handle.com for details #rust"; 579 + let facets = parse_facets_from_text(text, &resolver, &limits).await; 580 + 581 + assert!(facets.is_some()); 582 + let facets = facets.unwrap(); 583 + // Should only have 1 facet (the hashtag) since the mention couldn't be resolved 584 + assert_eq!(facets.len(), 1); 585 + 586 + // Check that it's the hashtag facet 587 + if let FacetFeature::Tag(tag) = &facets[0].features[0] { 588 + assert_eq!(tag.tag, "rust"); 589 + } else { 590 + panic!("Expected Tag feature"); 591 + } 592 + } 593 + 594 + #[tokio::test] 595 + async fn test_parse_facets_from_text_empty() { 596 + let resolver = MockIdentityResolver::new(); 597 + let limits = FacetLimits::default(); 598 + let text = "No mentions, URLs, or hashtags here"; 599 + let facets = parse_facets_from_text(text, &resolver, &limits).await; 600 + assert!(facets.is_none()); 601 + } 602 + 603 + #[tokio::test] 604 + async fn test_parse_facets_from_text_url_with_at_mention() { 605 + let resolver = MockIdentityResolver::new(); 606 + let limits = FacetLimits::default(); 607 + 608 + // URLs with @ should not create mention facets 609 + let text = "Tangled https://tangled.org/@smokesignal.events"; 610 + let facets = parse_facets_from_text(text, &resolver, &limits).await; 611 + 612 + assert!(facets.is_some()); 613 + let facets = facets.unwrap(); 614 + 615 + // Should have exactly 1 facet (the URL), not 2 (URL + mention) 616 + assert_eq!( 617 + facets.len(), 618 + 1, 619 + "Expected 1 facet (URL only), got {}", 620 + facets.len() 621 + ); 622 + 623 + // Verify it's a link facet, not a mention 624 + if let FacetFeature::Link(link) = &facets[0].features[0] { 625 + assert_eq!(link.uri, "https://tangled.org/@smokesignal.events"); 626 + } else { 627 + panic!("Expected Link feature, got Mention or Tag instead"); 628 + } 629 + } 630 + 631 + #[tokio::test] 632 + async fn test_parse_facets_with_mention_limit() { 633 + let mut resolver = MockIdentityResolver::new(); 634 + resolver.add_identity("bob.test.com", "did:plc:bob456"); 635 + resolver.add_identity("charlie.test.com", "did:plc:charlie789"); 636 + 637 + // Limit to 2 mentions 638 + let limits = FacetLimits { 639 + mentions_max: 2, 640 + tags_max: 5, 641 + links_max: 5, 642 + max: 10, 643 + }; 644 + 645 + let text = "Join @alice.bsky.social @bob.test.com @charlie.test.com"; 646 + let facets = parse_facets_from_text(text, &resolver, &limits).await; 647 + 648 + assert!(facets.is_some()); 649 + let facets = facets.unwrap(); 650 + // Should only have 2 mentions (alice and bob), charlie should be skipped 651 + assert_eq!(facets.len(), 2); 652 + 653 + // Verify they're both mentions 654 + for facet in &facets { 655 + assert!(matches!(facet.features[0], FacetFeature::Mention(_))); 656 + } 657 + } 658 + 659 + #[tokio::test] 660 + async fn test_parse_facets_with_global_limit() { 661 + let mut resolver = MockIdentityResolver::new(); 662 + resolver.add_identity("bob.test.com", "did:plc:bob456"); 663 + 664 + // Very restrictive global limit 665 + let limits = FacetLimits { 666 + mentions_max: 5, 667 + tags_max: 5, 668 + links_max: 5, 669 + max: 3, // Only allow 3 total facets 670 + }; 671 + 672 + let text = 673 + "Join @alice.bsky.social @bob.test.com at https://example.com #rust #golang #python"; 674 + let facets = parse_facets_from_text(text, &resolver, &limits).await; 675 + 676 + assert!(facets.is_some()); 677 + let facets = facets.unwrap(); 678 + // Should be truncated to 3 facets total 679 + assert_eq!(facets.len(), 3); 680 + } 681 + 682 + #[test] 683 + fn test_parse_urls_multiple_links() { 684 + let text = "IETF124 is happening in Montreal, Nov 1st to 7th https://www.ietf.org/meeting/124/\n\nWe're confirmed for two days of ATProto community sessions on Monday, Nov 3rd & Tuesday, Mov 4th at ECTO Co-Op. Many of us will also be participating in the free-to-attend IETF hackathon on Sunday, Nov 2nd.\n\nLatest updates and attendees in the forum https://discourse.atprotocol.community/t/update-on-timing-and-plan-for-montreal/164"; 685 + 686 + let facets = parse_urls(text); 687 + 688 + // Should find both URLs 689 + assert_eq!( 690 + facets.len(), 691 + 2, 692 + "Expected 2 URLs but found {}", 693 + facets.len() 694 + ); 695 + 696 + // Check first URL 697 + if let Some(FacetFeature::Link(link)) = facets[0].features.first() { 698 + assert_eq!(link.uri, "https://www.ietf.org/meeting/124/"); 699 + } else { 700 + panic!("Expected Link feature"); 701 + } 702 + 703 + // Check second URL 704 + if let Some(FacetFeature::Link(link)) = facets[1].features.first() { 705 + assert_eq!( 706 + link.uri, 707 + "https://discourse.atprotocol.community/t/update-on-timing-and-plan-for-montreal/164" 708 + ); 709 + } else { 710 + panic!("Expected Link feature"); 711 + } 712 + } 713 + 714 + #[test] 715 + fn test_parse_urls_with_html_entity() { 716 + // Test with the HTML entity &amp; in the text 717 + let text = "IETF124 is happening in Montreal, Nov 1st to 7th https://www.ietf.org/meeting/124/\n\nWe're confirmed for two days of ATProto community sessions on Monday, Nov 3rd &amp; Tuesday, Mov 4th at ECTO Co-Op. Many of us will also be participating in the free-to-attend IETF hackathon on Sunday, Nov 2nd.\n\nLatest updates and attendees in the forum https://discourse.atprotocol.community/t/update-on-timing-and-plan-for-montreal/164"; 718 + 719 + let facets = parse_urls(text); 720 + 721 + // Should find both URLs 722 + assert_eq!( 723 + facets.len(), 724 + 2, 725 + "Expected 2 URLs but found {}", 726 + facets.len() 727 + ); 728 + 729 + // Check first URL 730 + if let Some(FacetFeature::Link(link)) = facets[0].features.first() { 731 + assert_eq!(link.uri, "https://www.ietf.org/meeting/124/"); 732 + } else { 733 + panic!("Expected Link feature"); 734 + } 735 + 736 + // Check second URL 737 + if let Some(FacetFeature::Link(link)) = facets[1].features.first() { 738 + assert_eq!( 739 + link.uri, 740 + "https://discourse.atprotocol.community/t/update-on-timing-and-plan-for-montreal/164" 741 + ); 742 + } else { 743 + panic!("Expected Link feature"); 744 + } 745 + } 746 + 747 + #[test] 748 + fn test_byte_offset_with_html_entities() { 749 + // This test demonstrates that HTML entity escaping shifts byte positions. 750 + // The byte positions shift: 751 + // In original: '&' is at byte 8 (1 byte) 752 + // In escaped: '&amp;' starts at byte 8 (5 bytes) 753 + // This causes facet byte offsets to be misaligned if text is escaped before rendering. 754 + 755 + // If we have a URL after the ampersand in the original: 756 + let original_with_url = "Nov 3rd & Tuesday https://example.com"; 757 + let escaped_with_url = "Nov 3rd &amp; Tuesday https://example.com"; 758 + 759 + // Parse URLs from both versions 760 + let original_facets = parse_urls(original_with_url); 761 + let escaped_facets = parse_urls(escaped_with_url); 762 + 763 + // Both should find the URL, but at different byte positions 764 + assert_eq!(original_facets.len(), 1); 765 + assert_eq!(escaped_facets.len(), 1); 766 + 767 + // The byte positions will be different 768 + assert_eq!(original_facets[0].index.byte_start, 18); // After "Nov 3rd & Tuesday " 769 + assert_eq!(escaped_facets[0].index.byte_start, 22); // After "Nov 3rd &amp; Tuesday " (4 extra bytes for &amp;) 770 + } 771 + 772 + #[test] 773 + fn test_parse_urls_from_atproto_record_text() { 774 + // Test parsing URLs from real AT Protocol record description text. 775 + // This demonstrates the correct byte positions that should be used for facets. 776 + let text = "Dev, Power Users, and Generally inquisitive folks get a completely unprofessionally amateur interview. Just a yap sesh where chat is part of the call!\n\nโœจthe danielโœจ & I will be on a Zoom call and I will stream out to https://stream.place/psingletary.com\n\nSubscribe to the publications! https://atprotocalls.leaflet.pub/"; 777 + 778 + let facets = parse_urls(text); 779 + 780 + assert_eq!(facets.len(), 2, "Should find 2 URLs"); 781 + 782 + // First URL: https://stream.place/psingletary.com 783 + assert_eq!(facets[0].index.byte_start, 221); 784 + assert_eq!(facets[0].index.byte_end, 257); 785 + if let Some(FacetFeature::Link(link)) = facets[0].features.first() { 786 + assert_eq!(link.uri, "https://stream.place/psingletary.com"); 787 + } 788 + 789 + // Second URL: https://atprotocalls.leaflet.pub/ 790 + assert_eq!(facets[1].index.byte_start, 290); 791 + assert_eq!(facets[1].index.byte_end, 323); 792 + if let Some(FacetFeature::Link(link)) = facets[1].features.first() { 793 + assert_eq!(link.uri, "https://atprotocalls.leaflet.pub/"); 794 + } 795 + 796 + // Verify the byte slices match the expected text 797 + let text_bytes = text.as_bytes(); 798 + assert_eq!( 799 + std::str::from_utf8(&text_bytes[221..257]).unwrap(), 800 + "https://stream.place/psingletary.com" 801 + ); 802 + assert_eq!( 803 + std::str::from_utf8(&text_bytes[290..323]).unwrap(), 804 + "https://atprotocalls.leaflet.pub/" 805 + ); 806 + } 807 + 808 + #[tokio::test] 809 + async fn test_parse_mentions_basic() { 810 + let resolver = MockIdentityResolver::new(); 811 + let limits = FacetLimits::default(); 812 + let text = "Hello @alice.bsky.social!"; 813 + let facets = parse_mentions(text, &resolver, &limits).await; 814 + 815 + assert_eq!(facets.len(), 1); 816 + assert_eq!(facets[0].index.byte_start, 6); 817 + assert_eq!(facets[0].index.byte_end, 24); 818 + if let Some(FacetFeature::Mention(mention)) = facets[0].features.first() { 819 + assert_eq!(mention.did, "did:plc:alice123"); 820 + } else { 821 + panic!("Expected Mention feature"); 822 + } 823 + } 824 + 825 + #[tokio::test] 826 + async fn test_parse_mentions_multiple() { 827 + let mut resolver = MockIdentityResolver::new(); 828 + resolver.add_identity("bob.example.com", "did:plc:bob456"); 829 + let limits = FacetLimits::default(); 830 + let text = "CC @alice.bsky.social and @bob.example.com"; 831 + let facets = parse_mentions(text, &resolver, &limits).await; 832 + 833 + assert_eq!(facets.len(), 2); 834 + if let Some(FacetFeature::Mention(mention)) = facets[0].features.first() { 835 + assert_eq!(mention.did, "did:plc:alice123"); 836 + } 837 + if let Some(FacetFeature::Mention(mention)) = facets[1].features.first() { 838 + assert_eq!(mention.did, "did:plc:bob456"); 839 + } 840 + } 841 + 842 + #[tokio::test] 843 + async fn test_parse_mentions_unresolvable() { 844 + let resolver = MockIdentityResolver::new(); 845 + let limits = FacetLimits::default(); 846 + // unknown.handle.com is not in the resolver 847 + let text = "Hello @unknown.handle.com!"; 848 + let facets = parse_mentions(text, &resolver, &limits).await; 849 + 850 + // Should be empty since the handle can't be resolved 851 + assert_eq!(facets.len(), 0); 852 + } 853 + 854 + #[tokio::test] 855 + async fn test_parse_mentions_in_url_excluded() { 856 + let resolver = MockIdentityResolver::new(); 857 + let limits = FacetLimits::default(); 858 + // The @smokesignal.events is inside a URL and should not be parsed as a mention 859 + let text = "Check https://tangled.org/@smokesignal.events"; 860 + let facets = parse_mentions(text, &resolver, &limits).await; 861 + 862 + // Should be empty since the mention is inside a URL 863 + assert_eq!(facets.len(), 0); 864 + } 865 + 866 + #[test] 867 + fn test_parse_tags_basic() { 868 + let text = "Learning #rust today!"; 869 + let facets = parse_tags(text); 870 + 871 + assert_eq!(facets.len(), 1); 872 + assert_eq!(facets[0].index.byte_start, 9); 873 + assert_eq!(facets[0].index.byte_end, 14); 874 + if let Some(FacetFeature::Tag(tag)) = facets[0].features.first() { 875 + assert_eq!(tag.tag, "rust"); 876 + } else { 877 + panic!("Expected Tag feature"); 878 + } 879 + } 880 + 881 + #[test] 882 + fn test_parse_tags_multiple() { 883 + let text = "#rust #golang #python are great!"; 884 + let facets = parse_tags(text); 885 + 886 + assert_eq!(facets.len(), 3); 887 + if let Some(FacetFeature::Tag(tag)) = facets[0].features.first() { 888 + assert_eq!(tag.tag, "rust"); 889 + } 890 + if let Some(FacetFeature::Tag(tag)) = facets[1].features.first() { 891 + assert_eq!(tag.tag, "golang"); 892 + } 893 + if let Some(FacetFeature::Tag(tag)) = facets[2].features.first() { 894 + assert_eq!(tag.tag, "python"); 895 + } 896 + } 897 + 898 + #[test] 899 + fn test_parse_tags_excludes_numeric() { 900 + let text = "Item #42 is special #test123"; 901 + let facets = parse_tags(text); 902 + 903 + // #42 should be excluded (purely numeric), #test123 should be included 904 + assert_eq!(facets.len(), 1); 905 + if let Some(FacetFeature::Tag(tag)) = facets[0].features.first() { 906 + assert_eq!(tag.tag, "test123"); 907 + } 908 + } 909 + 910 + #[test] 911 + fn test_parse_urls_basic() { 912 + let text = "Visit https://example.com today!"; 913 + let facets = parse_urls(text); 914 + 915 + assert_eq!(facets.len(), 1); 916 + assert_eq!(facets[0].index.byte_start, 6); 917 + assert_eq!(facets[0].index.byte_end, 25); 918 + if let Some(FacetFeature::Link(link)) = facets[0].features.first() { 919 + assert_eq!(link.uri, "https://example.com"); 920 + } 921 + } 922 + 923 + #[test] 924 + fn test_parse_urls_with_path() { 925 + let text = "Check https://example.com/path/to/page?query=1#section"; 926 + let facets = parse_urls(text); 927 + 928 + assert_eq!(facets.len(), 1); 929 + if let Some(FacetFeature::Link(link)) = facets[0].features.first() { 930 + assert_eq!(link.uri, "https://example.com/path/to/page?query=1#section"); 931 + } 932 + } 933 + 934 + #[test] 935 + fn test_facet_limits_default() { 936 + let limits = FacetLimits::default(); 937 + assert_eq!(limits.mentions_max, 5); 938 + assert_eq!(limits.tags_max, 5); 939 + assert_eq!(limits.links_max, 5); 940 + assert_eq!(limits.max, 10); 941 + } 942 + }
+50
crates/atproto-extras/src/lib.rs
··· 1 + //! Extra utilities for AT Protocol applications. 2 + //! 3 + //! This crate provides additional utilities that complement the core AT Protocol 4 + //! identity and record crates. Currently, it focuses on rich text facet parsing. 5 + //! 6 + //! ## Features 7 + //! 8 + //! - **Facet Parsing**: Extract mentions, URLs, and hashtags from plain text 9 + //! with correct UTF-8 byte offset calculation 10 + //! - **Identity Integration**: Resolve mention handles to DIDs during parsing 11 + //! 12 + //! ## Example 13 + //! 14 + //! ```ignore 15 + //! use atproto_extras::{parse_facets_from_text, FacetLimits}; 16 + //! 17 + //! // Parse facets from text (requires an IdentityResolver) 18 + //! let text = "Hello @alice.bsky.social! Check out https://example.com #rust"; 19 + //! let limits = FacetLimits::default(); 20 + //! let facets = parse_facets_from_text(text, &resolver, &limits).await; 21 + //! ``` 22 + //! 23 + //! ## Byte Offset Calculation 24 + //! 25 + //! This implementation correctly uses UTF-8 byte offsets as required by AT Protocol. 26 + //! The facets use "inclusive start and exclusive end" byte ranges. All parsing is done 27 + //! using `regex::bytes::Regex` which operates on byte slices and returns byte positions, 28 + //! ensuring correct handling of multi-byte UTF-8 characters (emojis, CJK, accented chars). 29 + 30 + #![forbid(unsafe_code)] 31 + #![warn(missing_docs)] 32 + 33 + /// Rich text facet parsing for AT Protocol. 34 + /// 35 + /// This module provides functionality for extracting semantic annotations (facets) 36 + /// from plain text. Facets include: 37 + /// 38 + /// - **Mentions**: User handles prefixed with `@` (e.g., `@alice.bsky.social`) 39 + /// - **Links**: HTTP/HTTPS URLs 40 + /// - **Tags**: Hashtags prefixed with `#` or `๏ผƒ` (e.g., `#rust`) 41 + /// 42 + /// ## Byte Offsets 43 + /// 44 + /// All facet indices use UTF-8 byte offsets, not character indices. This is 45 + /// critical for correct handling of multi-byte characters like emojis or 46 + /// non-ASCII text. 47 + pub mod facets; 48 + 49 + /// Re-export commonly used types for convenience. 50 + pub use facets::{FacetLimits, parse_facets_from_text, parse_mentions, parse_tags, parse_urls};
+19 -1
crates/atproto-identity/src/model.rs
··· 70 70 /// The DID identifier (e.g., "did:plc:abc123"). 71 71 pub id: String, 72 72 /// Alternative identifiers like handles and domains. 73 + #[serde(default)] 73 74 pub also_known_as: Vec<String>, 74 75 /// Available services for this identity. 76 + #[serde(default)] 75 77 pub service: Vec<Service>, 76 78 77 79 /// Cryptographic verification methods. 78 - #[serde(alias = "verificationMethod")] 80 + #[serde(alias = "verificationMethod", default)] 79 81 pub verification_method: Vec<VerificationMethod>, 80 82 81 83 /// Additional document properties not explicitly defined. ··· 402 404 let document = document.unwrap(); 403 405 assert_eq!(document.id, "did:plc:cbkjy5n7bk3ax2wplmtjofq2"); 404 406 } 407 + } 408 + 409 + #[test] 410 + fn test_deserialize_service_did_document() { 411 + // DID document from api.bsky.app - a service DID without alsoKnownAs 412 + let document = serde_json::from_str::<Document>( 413 + r##"{"@context":["https://www.w3.org/ns/did/v1","https://w3id.org/security/multikey/v1"],"id":"did:web:api.bsky.app","verificationMethod":[{"id":"did:web:api.bsky.app#atproto","type":"Multikey","controller":"did:web:api.bsky.app","publicKeyMultibase":"zQ3shpRzb2NDriwCSSsce6EqGxG23kVktHZc57C3NEcuNy1jg"}],"service":[{"id":"#bsky_notif","type":"BskyNotificationService","serviceEndpoint":"https://api.bsky.app"},{"id":"#bsky_appview","type":"BskyAppView","serviceEndpoint":"https://api.bsky.app"}]}"##, 414 + ); 415 + assert!(document.is_ok(), "Failed to parse: {:?}", document.err()); 416 + 417 + let document = document.unwrap(); 418 + assert_eq!(document.id, "did:web:api.bsky.app"); 419 + assert!(document.also_known_as.is_empty()); 420 + assert_eq!(document.service.len(), 2); 421 + assert_eq!(document.service[0].id, "#bsky_notif"); 422 + assert_eq!(document.service[1].id, "#bsky_appview"); 405 423 } 406 424 }
+75 -24
crates/atproto-jetstream/src/consumer.rs
··· 2 2 //! 3 3 //! WebSocket event consumption with background processing and 4 4 //! customizable event handler dispatch. 5 + //! 6 + //! ## Memory Efficiency 7 + //! 8 + //! This module is optimized for high-throughput event processing with minimal allocations: 9 + //! 10 + //! - **Arc-based event sharing**: Events are wrapped in `Arc` and shared across all handlers, 11 + //! avoiding expensive clones of event data structures. 12 + //! - **Zero-copy handler IDs**: Handler identifiers use string slices to avoid allocations 13 + //! during registration and dispatch. 14 + //! - **Optimized query building**: WebSocket query strings are built with pre-allocated 15 + //! capacity to minimize reallocations. 16 + //! 17 + //! ## Usage 18 + //! 19 + //! Implement the `EventHandler` trait to process events: 20 + //! 21 + //! ```rust 22 + //! use atproto_jetstream::{EventHandler, JetstreamEvent}; 23 + //! use async_trait::async_trait; 24 + //! use std::sync::Arc; 25 + //! use anyhow::Result; 26 + //! 27 + //! struct MyHandler; 28 + //! 29 + //! #[async_trait] 30 + //! impl EventHandler for MyHandler { 31 + //! async fn handle_event(&self, event: Arc<JetstreamEvent>) -> Result<()> { 32 + //! // Process event without cloning 33 + //! Ok(()) 34 + //! } 35 + //! 36 + //! fn handler_id(&self) -> &str { 37 + //! "my-handler" 38 + //! } 39 + //! } 40 + //! ``` 5 41 6 42 use crate::errors::ConsumerError; 7 43 use anyhow::Result; ··· 133 169 #[async_trait] 134 170 pub trait EventHandler: Send + Sync { 135 171 /// Handle a received event 136 - async fn handle_event(&self, event: JetstreamEvent) -> Result<()>; 172 + /// 173 + /// Events are wrapped in Arc to enable efficient sharing across multiple handlers 174 + /// without cloning the entire event data structure. 175 + async fn handle_event(&self, event: Arc<JetstreamEvent>) -> Result<()>; 137 176 138 177 /// Get the handler's identifier 139 - fn handler_id(&self) -> String; 178 + /// 179 + /// Returns a string slice to avoid unnecessary allocations. 180 + fn handler_id(&self) -> &str; 140 181 } 141 182 142 183 #[cfg_attr(debug_assertions, derive(Debug))] ··· 167 208 pub struct Consumer { 168 209 config: ConsumerTaskConfig, 169 210 handlers: Arc<RwLock<HashMap<String, Arc<dyn EventHandler>>>>, 170 - event_sender: Arc<RwLock<Option<broadcast::Sender<JetstreamEvent>>>>, 211 + event_sender: Arc<RwLock<Option<broadcast::Sender<Arc<JetstreamEvent>>>>>, 171 212 } 172 213 173 214 impl Consumer { ··· 185 226 let handler_id = handler.handler_id(); 186 227 let mut handlers = self.handlers.write().await; 187 228 188 - if handlers.contains_key(&handler_id) { 229 + if handlers.contains_key(handler_id) { 189 230 return Err(ConsumerError::HandlerRegistrationFailed(format!( 190 231 "Handler with ID '{}' already registered", 191 232 handler_id ··· 193 234 .into()); 194 235 } 195 236 196 - handlers.insert(handler_id.clone(), handler); 237 + handlers.insert(handler_id.to_string(), handler); 197 238 Ok(()) 198 239 } 199 240 ··· 205 246 } 206 247 207 248 /// Get a broadcast receiver for events 208 - pub async fn get_event_receiver(&self) -> Result<broadcast::Receiver<JetstreamEvent>> { 249 + /// 250 + /// Events are wrapped in Arc to enable efficient sharing without cloning. 251 + pub async fn get_event_receiver(&self) -> Result<broadcast::Receiver<Arc<JetstreamEvent>>> { 209 252 let sender_guard = self.event_sender.read().await; 210 253 match sender_guard.as_ref() { 211 254 Some(sender) => Ok(sender.subscribe()), ··· 249 292 tracing::info!("Starting Jetstream consumer"); 250 293 251 294 // Build WebSocket URL with query parameters 252 - let mut query_params = vec![]; 295 + // Pre-allocate capacity to avoid reallocations during string building 296 + let capacity = 50 // Base parameters 297 + + self.config.collections.len() * 30 // Estimate per collection 298 + + self.config.dids.len() * 60; // Estimate per DID 299 + let mut query_string = String::with_capacity(capacity); 253 300 254 301 // Add compression parameter 255 - query_params.push(format!("compress={}", self.config.compression)); 302 + query_string.push_str("compress="); 303 + query_string.push_str(if self.config.compression { "true" } else { "false" }); 256 304 257 305 // Add requireHello parameter 258 - query_params.push(format!("requireHello={}", self.config.require_hello)); 306 + query_string.push_str("&requireHello="); 307 + query_string.push_str(if self.config.require_hello { "true" } else { "false" }); 259 308 260 309 // Add wantedCollections if specified (each collection as a separate query parameter) 261 310 if !self.config.collections.is_empty() && !self.config.require_hello { 262 311 for collection in &self.config.collections { 263 - query_params.push(format!( 264 - "wantedCollections={}", 265 - urlencoding::encode(collection) 266 - )); 312 + query_string.push_str("&wantedCollections="); 313 + query_string.push_str(&urlencoding::encode(collection)); 267 314 } 268 315 } 269 316 270 317 // Add wantedDids if specified (each DID as a separate query parameter) 271 318 if !self.config.dids.is_empty() && !self.config.require_hello { 272 319 for did in &self.config.dids { 273 - query_params.push(format!("wantedDids={}", urlencoding::encode(did))); 320 + query_string.push_str("&wantedDids="); 321 + query_string.push_str(&urlencoding::encode(did)); 274 322 } 275 323 } 276 324 277 325 // Add maxMessageSizeBytes if specified 278 326 if let Some(max_size) = self.config.max_message_size_bytes { 279 - query_params.push(format!("maxMessageSizeBytes={}", max_size)); 327 + use std::fmt::Write; 328 + write!(&mut query_string, "&maxMessageSizeBytes={}", max_size).unwrap(); 280 329 } 281 330 282 331 // Add cursor if specified 283 332 if let Some(cursor) = self.config.cursor { 284 - query_params.push(format!("cursor={}", cursor)); 333 + use std::fmt::Write; 334 + write!(&mut query_string, "&cursor={}", cursor).unwrap(); 285 335 } 286 - 287 - let query_string = query_params.join("&"); 288 336 let ws_url = Uri::from_str(&format!( 289 337 "wss://{}/subscribe?{}", 290 338 self.config.jetstream_hostname, query_string ··· 335 383 break; 336 384 }, 337 385 () = &mut sleeper => { 338 - // consumer_control_insert(&self.pool, &self.config.jetstream_hostname, time_usec).await?; 339 - 340 386 sleeper.as_mut().reset(Instant::now() + interval); 341 387 }, 342 388 item = client.next() => { ··· 404 450 } 405 451 406 452 /// Dispatch event to all registered handlers 453 + /// 454 + /// Wraps the event in Arc once and shares it across all handlers, 455 + /// avoiding expensive clones of the event data structure. 407 456 async fn dispatch_to_handlers(&self, event: JetstreamEvent) -> Result<()> { 408 457 let handlers = self.handlers.read().await; 458 + let event = Arc::new(event); 409 459 410 460 for (handler_id, handler) in handlers.iter() { 411 461 let handler_span = tracing::debug_span!("handler_dispatch", handler_id = %handler_id); 462 + let event_ref = Arc::clone(&event); 412 463 async { 413 - if let Err(err) = handler.handle_event(event.clone()).await { 464 + if let Err(err) = handler.handle_event(event_ref).await { 414 465 tracing::error!( 415 466 error = ?err, 416 467 handler_id = %handler_id, ··· 440 491 441 492 #[async_trait] 442 493 impl EventHandler for LoggingHandler { 443 - async fn handle_event(&self, _event: JetstreamEvent) -> Result<()> { 494 + async fn handle_event(&self, _event: Arc<JetstreamEvent>) -> Result<()> { 444 495 Ok(()) 445 496 } 446 497 447 - fn handler_id(&self) -> String { 448 - self.id.clone() 498 + fn handler_id(&self) -> &str { 499 + &self.id 449 500 } 450 501 } 451 502
+374 -5
crates/atproto-oauth/src/scopes.rs
··· 38 38 Atproto, 39 39 /// Transition scope for migration operations 40 40 Transition(TransitionScope), 41 + /// Include scope for referencing permission sets by NSID 42 + Include(IncludeScope), 41 43 /// OpenID Connect scope - required for OpenID Connect authentication 42 44 OpenId, 43 45 /// Profile scope - access to user profile information ··· 91 93 Generic, 92 94 /// Email transition operations 93 95 Email, 96 + } 97 + 98 + /// Include scope for referencing permission sets by NSID 99 + #[derive(Debug, Clone, PartialEq, Eq, Hash)] 100 + pub struct IncludeScope { 101 + /// The permission set NSID (e.g., "app.example.authFull") 102 + pub nsid: String, 103 + /// Optional audience DID for inherited RPC permissions 104 + pub aud: Option<String>, 94 105 } 95 106 96 107 /// Blob scope with mime type constraints ··· 310 321 "rpc", 311 322 "atproto", 312 323 "transition", 324 + "include", 313 325 "openid", 314 326 "profile", 315 327 "email", ··· 349 361 "rpc" => Self::parse_rpc(suffix), 350 362 "atproto" => Self::parse_atproto(suffix), 351 363 "transition" => Self::parse_transition(suffix), 364 + "include" => Self::parse_include(suffix), 352 365 "openid" => Self::parse_openid(suffix), 353 366 "profile" => Self::parse_profile(suffix), 354 367 "email" => Self::parse_email(suffix), ··· 573 586 Ok(Scope::Transition(scope)) 574 587 } 575 588 589 + fn parse_include(suffix: Option<&str>) -> Result<Self, ParseError> { 590 + let (nsid, params) = match suffix { 591 + Some(s) => { 592 + if let Some(pos) = s.find('?') { 593 + (&s[..pos], Some(&s[pos + 1..])) 594 + } else { 595 + (s, None) 596 + } 597 + } 598 + None => return Err(ParseError::MissingResource), 599 + }; 600 + 601 + if nsid.is_empty() { 602 + return Err(ParseError::MissingResource); 603 + } 604 + 605 + let aud = if let Some(params) = params { 606 + let parsed_params = parse_query_string(params); 607 + parsed_params 608 + .get("aud") 609 + .and_then(|v| v.first()) 610 + .map(|s| url_decode(s)) 611 + } else { 612 + None 613 + }; 614 + 615 + Ok(Scope::Include(IncludeScope { 616 + nsid: nsid.to_string(), 617 + aud, 618 + })) 619 + } 620 + 576 621 fn parse_openid(suffix: Option<&str>) -> Result<Self, ParseError> { 577 622 if suffix.is_some() { 578 623 return Err(ParseError::InvalidResource( ··· 677 722 if let Some(lxm) = scope.lxm.iter().next() { 678 723 match lxm { 679 724 RpcLexicon::All => "rpc:*".to_string(), 680 - RpcLexicon::Nsid(nsid) => format!("rpc:{}", nsid), 725 + RpcLexicon::Nsid(nsid) => format!("rpc:{}?aud=*", nsid), 726 + } 727 + } else { 728 + "rpc:*".to_string() 729 + } 730 + } else if scope.lxm.len() == 1 && scope.aud.len() == 1 { 731 + // Single lxm and single aud (aud is not All, handled above) 732 + if let (Some(lxm), Some(aud)) = 733 + (scope.lxm.iter().next(), scope.aud.iter().next()) 734 + { 735 + match (lxm, aud) { 736 + (RpcLexicon::Nsid(nsid), RpcAudience::Did(did)) => { 737 + format!("rpc:{}?aud={}", nsid, did) 738 + } 739 + (RpcLexicon::All, RpcAudience::Did(did)) => { 740 + format!("rpc:*?aud={}", did) 741 + } 742 + _ => "rpc:*".to_string(), 681 743 } 682 744 } else { 683 745 "rpc:*".to_string() ··· 713 775 TransitionScope::Generic => "transition:generic".to_string(), 714 776 TransitionScope::Email => "transition:email".to_string(), 715 777 }, 778 + Scope::Include(scope) => { 779 + if let Some(ref aud) = scope.aud { 780 + format!("include:{}?aud={}", scope.nsid, url_encode(aud)) 781 + } else { 782 + format!("include:{}", scope.nsid) 783 + } 784 + } 716 785 Scope::OpenId => "openid".to_string(), 717 786 Scope::Profile => "profile".to_string(), 718 787 Scope::Email => "email".to_string(), ··· 732 801 // Other scopes don't grant transition scopes 733 802 (_, Scope::Transition(_)) => false, 734 803 (Scope::Transition(_), _) => false, 804 + // Include scopes only grant themselves (exact match including aud) 805 + (Scope::Include(a), Scope::Include(b)) => a == b, 806 + // Other scopes don't grant include scopes 807 + (_, Scope::Include(_)) => false, 808 + (Scope::Include(_), _) => false, 735 809 // OpenID Connect scopes only grant themselves 736 810 (Scope::OpenId, Scope::OpenId) => true, 737 811 (Scope::OpenId, _) => false, ··· 873 947 params 874 948 } 875 949 950 + /// Decode a percent-encoded string 951 + fn url_decode(s: &str) -> String { 952 + let mut result = String::with_capacity(s.len()); 953 + let mut chars = s.chars().peekable(); 954 + 955 + while let Some(c) = chars.next() { 956 + if c == '%' { 957 + let hex: String = chars.by_ref().take(2).collect(); 958 + if hex.len() == 2 { 959 + if let Ok(byte) = u8::from_str_radix(&hex, 16) { 960 + result.push(byte as char); 961 + continue; 962 + } 963 + } 964 + result.push('%'); 965 + result.push_str(&hex); 966 + } else { 967 + result.push(c); 968 + } 969 + } 970 + 971 + result 972 + } 973 + 974 + /// Encode a string for use in a URL query parameter 975 + fn url_encode(s: &str) -> String { 976 + let mut result = String::with_capacity(s.len() * 3); 977 + 978 + for c in s.chars() { 979 + match c { 980 + 'A'..='Z' | 'a'..='z' | '0'..='9' | '-' | '_' | '.' | '~' | ':' => { 981 + result.push(c); 982 + } 983 + _ => { 984 + for byte in c.to_string().as_bytes() { 985 + result.push_str(&format!("%{:02X}", byte)); 986 + } 987 + } 988 + } 989 + } 990 + 991 + result 992 + } 993 + 876 994 /// Error type for scope parsing 877 995 #[derive(Debug, Clone, PartialEq, Eq)] 878 996 pub enum ParseError { ··· 1056 1174 ("repo:foo.bar", "repo:foo.bar"), 1057 1175 ("repo:foo.bar?action=create", "repo:foo.bar?action=create"), 1058 1176 ("rpc:*", "rpc:*"), 1177 + ("rpc:com.example.service", "rpc:com.example.service?aud=*"), 1178 + ( 1179 + "rpc:com.example.service?aud=did:example:123", 1180 + "rpc:com.example.service?aud=did:example:123", 1181 + ), 1059 1182 ]; 1060 1183 1061 1184 for (input, expected) in tests { ··· 1677 1800 1678 1801 // Test with complex scopes including query parameters 1679 1802 let scopes = vec![ 1680 - Scope::parse("rpc:com.example.service?aud=did:example:123&lxm=com.example.method") 1681 - .unwrap(), 1803 + Scope::parse("rpc:com.example.service?aud=did:example:123").unwrap(), 1682 1804 Scope::parse("repo:foo.bar?action=create&action=update").unwrap(), 1683 1805 Scope::parse("blob:image/*?accept=image/png&accept=image/jpeg").unwrap(), 1684 1806 ]; 1685 1807 let result = Scope::serialize_multiple(&scopes); 1686 1808 // The result should be sorted alphabetically 1687 - // Note: RPC scope with query params is serialized as "rpc?aud=...&lxm=..." 1809 + // Single lxm + single aud is serialized as "rpc:[lxm]?aud=[aud]" 1688 1810 assert!(result.starts_with("blob:")); 1689 1811 assert!(result.contains(" repo:")); 1690 - assert!(result.contains("rpc?aud=did:example:123&lxm=com.example.service")); 1812 + assert!(result.contains("rpc:com.example.service?aud=did:example:123")); 1691 1813 1692 1814 // Test with transition scopes 1693 1815 let scopes = vec![ ··· 1835 1957 assert!(!result.contains(&Scope::parse("account:email").unwrap())); 1836 1958 assert!(result.contains(&Scope::parse("account:email?action=manage").unwrap())); 1837 1959 assert!(result.contains(&Scope::parse("account:repo").unwrap())); 1960 + } 1961 + 1962 + #[test] 1963 + fn test_repo_nsid_with_wildcard_suffix() { 1964 + // Test parsing "repo:app.bsky.feed.*" - the asterisk is treated as a literal part of the NSID, 1965 + // not as a wildcard pattern. Only "repo:*" has special wildcard behavior for ALL collections. 1966 + let scope = Scope::parse("repo:app.bsky.feed.*").unwrap(); 1967 + 1968 + // Verify it parses as a specific NSID, not as a wildcard 1969 + assert_eq!( 1970 + scope, 1971 + Scope::Repo(RepoScope { 1972 + collection: RepoCollection::Nsid("app.bsky.feed.*".to_string()), 1973 + actions: { 1974 + let mut actions = BTreeSet::new(); 1975 + actions.insert(RepoAction::Create); 1976 + actions.insert(RepoAction::Update); 1977 + actions.insert(RepoAction::Delete); 1978 + actions 1979 + } 1980 + }) 1981 + ); 1982 + 1983 + // Verify normalization preserves the literal NSID 1984 + assert_eq!(scope.to_string_normalized(), "repo:app.bsky.feed.*"); 1985 + 1986 + // Test that it does NOT grant access to "app.bsky.feed.post" 1987 + // (because "app.bsky.feed.*" is a literal NSID, not a pattern) 1988 + let specific_feed = Scope::parse("repo:app.bsky.feed.post").unwrap(); 1989 + assert!(!scope.grants(&specific_feed)); 1990 + 1991 + // Test that only "repo:*" grants access to "app.bsky.feed.*" 1992 + let repo_all = Scope::parse("repo:*").unwrap(); 1993 + assert!(repo_all.grants(&scope)); 1994 + 1995 + // Test that "repo:app.bsky.feed.*" only grants itself 1996 + assert!(scope.grants(&scope)); 1997 + 1998 + // Test with actions 1999 + let scope_with_create = Scope::parse("repo:app.bsky.feed.*?action=create").unwrap(); 2000 + assert_eq!( 2001 + scope_with_create, 2002 + Scope::Repo(RepoScope { 2003 + collection: RepoCollection::Nsid("app.bsky.feed.*".to_string()), 2004 + actions: { 2005 + let mut actions = BTreeSet::new(); 2006 + actions.insert(RepoAction::Create); 2007 + actions 2008 + } 2009 + }) 2010 + ); 2011 + 2012 + // The full scope (with all actions) grants the create-only scope 2013 + assert!(scope.grants(&scope_with_create)); 2014 + // But the create-only scope does NOT grant the full scope 2015 + assert!(!scope_with_create.grants(&scope)); 2016 + 2017 + // Test parsing multiple scopes with NSID wildcards 2018 + let scopes = Scope::parse_multiple("repo:app.bsky.feed.* repo:app.bsky.graph.* repo:*").unwrap(); 2019 + assert_eq!(scopes.len(), 3); 2020 + 2021 + // Test that parse_multiple_reduced properly reduces when "repo:*" is present 2022 + let reduced = Scope::parse_multiple_reduced("repo:app.bsky.feed.* repo:app.bsky.graph.* repo:*").unwrap(); 2023 + assert_eq!(reduced.len(), 1); 2024 + assert_eq!(reduced[0], repo_all); 2025 + } 2026 + 2027 + #[test] 2028 + fn test_include_scope_parsing() { 2029 + // Test basic include scope 2030 + let scope = Scope::parse("include:app.example.authFull").unwrap(); 2031 + assert_eq!( 2032 + scope, 2033 + Scope::Include(IncludeScope { 2034 + nsid: "app.example.authFull".to_string(), 2035 + aud: None, 2036 + }) 2037 + ); 2038 + 2039 + // Test include scope with audience 2040 + let scope = Scope::parse("include:app.example.authFull?aud=did:web:api.example.com").unwrap(); 2041 + assert_eq!( 2042 + scope, 2043 + Scope::Include(IncludeScope { 2044 + nsid: "app.example.authFull".to_string(), 2045 + aud: Some("did:web:api.example.com".to_string()), 2046 + }) 2047 + ); 2048 + 2049 + // Test include scope with URL-encoded audience (with fragment) 2050 + let scope = Scope::parse("include:app.example.authFull?aud=did:web:api.example.com%23svc_chat").unwrap(); 2051 + assert_eq!( 2052 + scope, 2053 + Scope::Include(IncludeScope { 2054 + nsid: "app.example.authFull".to_string(), 2055 + aud: Some("did:web:api.example.com#svc_chat".to_string()), 2056 + }) 2057 + ); 2058 + 2059 + // Test missing NSID 2060 + assert!(matches!( 2061 + Scope::parse("include"), 2062 + Err(ParseError::MissingResource) 2063 + )); 2064 + 2065 + // Test empty NSID with query params 2066 + assert!(matches!( 2067 + Scope::parse("include:?aud=did:example:123"), 2068 + Err(ParseError::MissingResource) 2069 + )); 2070 + } 2071 + 2072 + #[test] 2073 + fn test_include_scope_normalization() { 2074 + // Test normalization without audience 2075 + let scope = Scope::parse("include:com.example.authBasic").unwrap(); 2076 + assert_eq!(scope.to_string_normalized(), "include:com.example.authBasic"); 2077 + 2078 + // Test normalization with audience (no special chars) 2079 + let scope = Scope::parse("include:com.example.authBasic?aud=did:plc:xyz123").unwrap(); 2080 + assert_eq!( 2081 + scope.to_string_normalized(), 2082 + "include:com.example.authBasic?aud=did:plc:xyz123" 2083 + ); 2084 + 2085 + // Test normalization with URL encoding (fragment needs encoding) 2086 + let scope = Scope::parse("include:app.example.authFull?aud=did:web:api.example.com%23svc_chat").unwrap(); 2087 + let normalized = scope.to_string_normalized(); 2088 + assert_eq!( 2089 + normalized, 2090 + "include:app.example.authFull?aud=did:web:api.example.com%23svc_chat" 2091 + ); 2092 + } 2093 + 2094 + #[test] 2095 + fn test_include_scope_grants() { 2096 + let include1 = Scope::parse("include:app.example.authFull").unwrap(); 2097 + let include2 = Scope::parse("include:app.example.authBasic").unwrap(); 2098 + let include1_with_aud = Scope::parse("include:app.example.authFull?aud=did:plc:xyz").unwrap(); 2099 + let account = Scope::parse("account:email").unwrap(); 2100 + 2101 + // Include scopes only grant themselves (exact match) 2102 + assert!(include1.grants(&include1)); 2103 + assert!(!include1.grants(&include2)); 2104 + assert!(!include1.grants(&include1_with_aud)); // Different because aud differs 2105 + assert!(include1_with_aud.grants(&include1_with_aud)); 2106 + 2107 + // Include scopes don't grant other scope types 2108 + assert!(!include1.grants(&account)); 2109 + assert!(!account.grants(&include1)); 2110 + 2111 + // Include scopes don't grant atproto or transition 2112 + let atproto = Scope::parse("atproto").unwrap(); 2113 + let transition = Scope::parse("transition:generic").unwrap(); 2114 + assert!(!include1.grants(&atproto)); 2115 + assert!(!include1.grants(&transition)); 2116 + assert!(!atproto.grants(&include1)); 2117 + assert!(!transition.grants(&include1)); 2118 + } 2119 + 2120 + #[test] 2121 + fn test_parse_multiple_with_include() { 2122 + let scopes = Scope::parse_multiple("atproto include:app.example.auth repo:*").unwrap(); 2123 + assert_eq!(scopes.len(), 3); 2124 + assert_eq!(scopes[0], Scope::Atproto); 2125 + assert!(matches!(scopes[1], Scope::Include(_))); 2126 + assert!(matches!(scopes[2], Scope::Repo(_))); 2127 + 2128 + // Test with URL-encoded audience 2129 + let scopes = Scope::parse_multiple( 2130 + "include:app.example.auth?aud=did:web:api.example.com%23svc account:email" 2131 + ).unwrap(); 2132 + assert_eq!(scopes.len(), 2); 2133 + if let Scope::Include(inc) = &scopes[0] { 2134 + assert_eq!(inc.nsid, "app.example.auth"); 2135 + assert_eq!(inc.aud, Some("did:web:api.example.com#svc".to_string())); 2136 + } else { 2137 + panic!("Expected Include scope"); 2138 + } 2139 + } 2140 + 2141 + #[test] 2142 + fn test_parse_multiple_reduced_with_include() { 2143 + // Include scopes don't reduce each other (each is distinct) 2144 + let scopes = Scope::parse_multiple_reduced( 2145 + "include:app.example.auth include:app.example.other include:app.example.auth" 2146 + ).unwrap(); 2147 + assert_eq!(scopes.len(), 2); // Duplicates are removed 2148 + assert!(scopes.contains(&Scope::Include(IncludeScope { 2149 + nsid: "app.example.auth".to_string(), 2150 + aud: None, 2151 + }))); 2152 + assert!(scopes.contains(&Scope::Include(IncludeScope { 2153 + nsid: "app.example.other".to_string(), 2154 + aud: None, 2155 + }))); 2156 + 2157 + // Include scopes with different audiences are not duplicates 2158 + let scopes = Scope::parse_multiple_reduced( 2159 + "include:app.example.auth include:app.example.auth?aud=did:plc:xyz" 2160 + ).unwrap(); 2161 + assert_eq!(scopes.len(), 2); 2162 + } 2163 + 2164 + #[test] 2165 + fn test_serialize_multiple_with_include() { 2166 + let scopes = vec![ 2167 + Scope::parse("repo:*").unwrap(), 2168 + Scope::parse("include:app.example.authFull").unwrap(), 2169 + Scope::Atproto, 2170 + ]; 2171 + let result = Scope::serialize_multiple(&scopes); 2172 + assert_eq!(result, "atproto include:app.example.authFull repo:*"); 2173 + 2174 + // Test with URL-encoded audience 2175 + let scopes = vec![ 2176 + Scope::Include(IncludeScope { 2177 + nsid: "app.example.auth".to_string(), 2178 + aud: Some("did:web:api.example.com#svc".to_string()), 2179 + }), 2180 + ]; 2181 + let result = Scope::serialize_multiple(&scopes); 2182 + assert_eq!(result, "include:app.example.auth?aud=did:web:api.example.com%23svc"); 2183 + } 2184 + 2185 + #[test] 2186 + fn test_remove_scope_with_include() { 2187 + let scopes = vec![ 2188 + Scope::Atproto, 2189 + Scope::parse("include:app.example.auth").unwrap(), 2190 + Scope::parse("account:email").unwrap(), 2191 + ]; 2192 + let to_remove = Scope::parse("include:app.example.auth").unwrap(); 2193 + let result = Scope::remove_scope(&scopes, &to_remove); 2194 + assert_eq!(result.len(), 2); 2195 + assert!(!result.contains(&to_remove)); 2196 + assert!(result.contains(&Scope::Atproto)); 2197 + } 2198 + 2199 + #[test] 2200 + fn test_include_scope_roundtrip() { 2201 + // Test that parse and serialize are inverses 2202 + let original = "include:com.example.authBasicFeatures?aud=did:web:api.example.com%23svc_appview"; 2203 + let scope = Scope::parse(original).unwrap(); 2204 + let serialized = scope.to_string_normalized(); 2205 + let reparsed = Scope::parse(&serialized).unwrap(); 2206 + assert_eq!(scope, reparsed); 1838 2207 } 1839 2208 }
+53
crates/atproto-tap/Cargo.toml
··· 1 + [package] 2 + name = "atproto-tap" 3 + version = "0.13.0" 4 + description = "AT Protocol TAP (Trusted Attestation Protocol) service consumer" 5 + readme = "README.md" 6 + homepage = "https://tangled.sh/@smokesignal.events/atproto-identity-rs" 7 + documentation = "https://docs.rs/atproto-tap" 8 + 9 + edition.workspace = true 10 + rust-version.workspace = true 11 + authors.workspace = true 12 + repository.workspace = true 13 + license.workspace = true 14 + keywords.workspace = true 15 + categories.workspace = true 16 + 17 + [dependencies] 18 + tokio = { workspace = true, features = ["sync", "time"] } 19 + tokio-stream = "0.1" 20 + tokio-websockets = { workspace = true } 21 + futures = { workspace = true } 22 + reqwest = { workspace = true } 23 + serde = { workspace = true } 24 + serde_json = { workspace = true } 25 + thiserror = { workspace = true } 26 + tracing = { workspace = true } 27 + http = { workspace = true } 28 + base64 = { workspace = true } 29 + atproto-identity.workspace = true 30 + atproto-client = { workspace = true, optional = true } 31 + 32 + # Memory efficiency 33 + compact_str = { version = "0.8", features = ["serde"] } 34 + itoa = "1.0" 35 + 36 + # Optional for CLI 37 + clap = { workspace = true, optional = true } 38 + tracing-subscriber = { version = "0.3", features = ["env-filter"], optional = true } 39 + 40 + [features] 41 + default = [] 42 + clap = ["dep:clap", "dep:tracing-subscriber", "dep:atproto-client", "tokio/rt-multi-thread", "tokio/macros", "tokio/signal"] 43 + 44 + [[bin]] 45 + name = "atproto-tap-client" 46 + required-features = ["clap"] 47 + 48 + [[bin]] 49 + name = "atproto-tap-extras" 50 + required-features = ["clap"] 51 + 52 + [lints] 53 + workspace = true
+351
crates/atproto-tap/src/bin/atproto-tap-client.rs
··· 1 + //! Command-line client for TAP services. 2 + //! 3 + //! This tool provides commands for consuming TAP events and managing tracked repositories. 4 + //! 5 + //! # Usage 6 + //! 7 + //! ```bash 8 + //! # Stream events from a TAP service 9 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 read 10 + //! 11 + //! # Stream with authentication and filters 12 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 -p secret read --live-only 13 + //! 14 + //! # Add repositories to track 15 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 -p secret repos add did:plc:xyz did:plc:abc 16 + //! 17 + //! # Remove repositories from tracking 18 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 -p secret repos remove did:plc:xyz 19 + //! 20 + //! # Resolve a DID to its DID document 21 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 resolve did:plc:xyz 22 + //! 23 + //! # Resolve a DID and only output the handle 24 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 resolve did:plc:xyz --handle-only 25 + //! 26 + //! # Get repository tracking info 27 + //! cargo run --features cli --bin atproto-tap-client -- localhost:2480 info did:plc:xyz 28 + //! ``` 29 + 30 + use atproto_tap::{TapClient, TapConfig, TapEvent, connect}; 31 + use clap::{Parser, Subcommand}; 32 + use std::time::Duration; 33 + use tokio_stream::StreamExt; 34 + 35 + /// TAP service client for consuming events and managing repositories. 36 + #[derive(Parser)] 37 + #[command( 38 + name = "atproto-tap-client", 39 + version, 40 + about = "TAP service client for AT Protocol", 41 + long_about = "Connect to a TAP service to stream repository/identity events or manage tracked repositories.\n\n\ 42 + Events are printed to stdout as JSON, one per line.\n\ 43 + Use Ctrl+C to gracefully stop the consumer." 44 + )] 45 + struct Args { 46 + /// TAP service hostname (e.g., localhost:2480) 47 + hostname: String, 48 + 49 + /// Admin password for authentication 50 + #[arg(short, long, global = true)] 51 + password: Option<String>, 52 + 53 + #[command(subcommand)] 54 + command: Command, 55 + } 56 + 57 + #[derive(Subcommand)] 58 + enum Command { 59 + /// Connect to TAP and stream events as JSON 60 + Read { 61 + /// Disable acknowledgments 62 + #[arg(long)] 63 + no_acks: bool, 64 + 65 + /// Maximum reconnection attempts (0 = unlimited) 66 + #[arg(long, default_value = "0")] 67 + max_reconnects: u32, 68 + 69 + /// Print debug information to stderr 70 + #[arg(short, long)] 71 + debug: bool, 72 + 73 + /// Filter to specific collections (comma-separated) 74 + #[arg(long)] 75 + collections: Option<String>, 76 + 77 + /// Only show live events (skip backfill) 78 + #[arg(long)] 79 + live_only: bool, 80 + }, 81 + 82 + /// Manage tracked repositories 83 + Repos { 84 + #[command(subcommand)] 85 + action: ReposAction, 86 + }, 87 + 88 + /// Resolve a DID to its DID document 89 + Resolve { 90 + /// DID to resolve (e.g., did:plc:xyz123) 91 + did: String, 92 + 93 + /// Only output the handle (instead of full DID document) 94 + #[arg(long)] 95 + handle_only: bool, 96 + }, 97 + 98 + /// Get tracking info for a repository 99 + Info { 100 + /// DID to get info for (e.g., did:plc:xyz123) 101 + did: String, 102 + }, 103 + } 104 + 105 + #[derive(Subcommand)] 106 + enum ReposAction { 107 + /// Add repositories to track 108 + Add { 109 + /// DIDs to add (e.g., did:plc:xyz123) 110 + #[arg(required = true)] 111 + dids: Vec<String>, 112 + }, 113 + 114 + /// Remove repositories from tracking 115 + Remove { 116 + /// DIDs to remove 117 + #[arg(required = true)] 118 + dids: Vec<String>, 119 + }, 120 + } 121 + 122 + #[tokio::main] 123 + async fn main() { 124 + let args = Args::parse(); 125 + 126 + match args.command { 127 + Command::Read { 128 + no_acks, 129 + max_reconnects, 130 + debug, 131 + collections, 132 + live_only, 133 + } => { 134 + run_read( 135 + &args.hostname, 136 + args.password, 137 + no_acks, 138 + max_reconnects, 139 + debug, 140 + collections, 141 + live_only, 142 + ) 143 + .await; 144 + } 145 + Command::Repos { action } => { 146 + run_repos(&args.hostname, args.password, action).await; 147 + } 148 + Command::Resolve { did, handle_only } => { 149 + run_resolve(&args.hostname, args.password, &did, handle_only).await; 150 + } 151 + Command::Info { did } => { 152 + run_info(&args.hostname, args.password, &did).await; 153 + } 154 + } 155 + } 156 + 157 + async fn run_read( 158 + hostname: &str, 159 + password: Option<String>, 160 + no_acks: bool, 161 + max_reconnects: u32, 162 + debug: bool, 163 + collections: Option<String>, 164 + live_only: bool, 165 + ) { 166 + // Initialize tracing if debug mode 167 + if debug { 168 + tracing_subscriber::fmt() 169 + .with_env_filter("atproto_tap=debug") 170 + .with_writer(std::io::stderr) 171 + .init(); 172 + } 173 + 174 + // Build configuration 175 + let mut config_builder = TapConfig::builder() 176 + .hostname(hostname) 177 + .send_acks(!no_acks); 178 + 179 + if let Some(password) = password { 180 + config_builder = config_builder.admin_password(password); 181 + } 182 + 183 + if max_reconnects > 0 { 184 + config_builder = config_builder.max_reconnect_attempts(Some(max_reconnects)); 185 + } 186 + 187 + // Set reasonable defaults for CLI usage 188 + config_builder = config_builder 189 + .initial_reconnect_delay(Duration::from_secs(1)) 190 + .max_reconnect_delay(Duration::from_secs(30)); 191 + 192 + let config = config_builder.build(); 193 + 194 + eprintln!("Connecting to TAP service at {}...", hostname); 195 + 196 + let mut stream = connect(config); 197 + 198 + // Parse collection filters 199 + let collection_filters: Vec<String> = collections 200 + .map(|c| c.split(',').map(|s| s.trim().to_string()).collect()) 201 + .unwrap_or_default(); 202 + 203 + // Handle Ctrl+C 204 + let ctrl_c = tokio::signal::ctrl_c(); 205 + tokio::pin!(ctrl_c); 206 + 207 + loop { 208 + tokio::select! { 209 + Some(result) = stream.next() => { 210 + match result { 211 + Ok(event) => { 212 + // Apply filters 213 + let should_print = match event.as_ref() { 214 + TapEvent::Record { record, .. } => { 215 + // Filter by live flag 216 + if live_only && !record.live { 217 + false 218 + } 219 + // Filter by collection 220 + else if !collection_filters.is_empty() { 221 + collection_filters.iter().any(|c| record.collection.as_ref() == c) 222 + } else { 223 + true 224 + } 225 + } 226 + TapEvent::Identity { .. } => !live_only, // Always show identity unless live_only 227 + }; 228 + 229 + if should_print { 230 + // Print as JSON to stdout 231 + match serde_json::to_string(event.as_ref()) { 232 + Ok(json) => println!("{}", json), 233 + Err(e) => { 234 + eprintln!("Failed to serialize event: {}", e); 235 + } 236 + } 237 + } 238 + } 239 + Err(e) => { 240 + eprintln!("Error: {}", e); 241 + 242 + // Exit on fatal errors 243 + if e.is_fatal() { 244 + eprintln!("Fatal error, exiting"); 245 + std::process::exit(1); 246 + } 247 + } 248 + } 249 + } 250 + _ = &mut ctrl_c => { 251 + eprintln!("\nReceived Ctrl+C, shutting down..."); 252 + stream.close().await; 253 + break; 254 + } 255 + } 256 + } 257 + 258 + eprintln!("Client stopped"); 259 + } 260 + 261 + async fn run_repos(hostname: &str, password: Option<String>, action: ReposAction) { 262 + let client = TapClient::new(hostname, password); 263 + 264 + match action { 265 + ReposAction::Add { dids } => { 266 + let did_refs: Vec<&str> = dids.iter().map(|s| s.as_str()).collect(); 267 + 268 + match client.add_repos(&did_refs).await { 269 + Ok(()) => { 270 + eprintln!("Added {} repository(ies) to tracking", dids.len()); 271 + for did in &dids { 272 + println!("{}", did); 273 + } 274 + } 275 + Err(e) => { 276 + eprintln!("Failed to add repositories: {}", e); 277 + std::process::exit(1); 278 + } 279 + } 280 + } 281 + ReposAction::Remove { dids } => { 282 + let did_refs: Vec<&str> = dids.iter().map(|s| s.as_str()).collect(); 283 + 284 + match client.remove_repos(&did_refs).await { 285 + Ok(()) => { 286 + eprintln!("Removed {} repository(ies) from tracking", dids.len()); 287 + for did in &dids { 288 + println!("{}", did); 289 + } 290 + } 291 + Err(e) => { 292 + eprintln!("Failed to remove repositories: {}", e); 293 + std::process::exit(1); 294 + } 295 + } 296 + } 297 + } 298 + } 299 + 300 + async fn run_resolve(hostname: &str, password: Option<String>, did: &str, handle_only: bool) { 301 + let client = TapClient::new(hostname, password); 302 + 303 + match client.resolve(did).await { 304 + Ok(doc) => { 305 + if handle_only { 306 + // Use the handles() method from atproto_identity::model::Document 307 + match doc.handles() { 308 + Some(handle) => println!("{}", handle), 309 + None => { 310 + eprintln!("No handle found in DID document"); 311 + std::process::exit(1); 312 + } 313 + } 314 + } else { 315 + // Print full DID document as JSON 316 + match serde_json::to_string_pretty(&doc) { 317 + Ok(json) => println!("{}", json), 318 + Err(e) => { 319 + eprintln!("Failed to serialize DID document: {}", e); 320 + std::process::exit(1); 321 + } 322 + } 323 + } 324 + } 325 + Err(e) => { 326 + eprintln!("Failed to resolve DID: {}", e); 327 + std::process::exit(1); 328 + } 329 + } 330 + } 331 + 332 + async fn run_info(hostname: &str, password: Option<String>, did: &str) { 333 + let client = TapClient::new(hostname, password); 334 + 335 + match client.info(did).await { 336 + Ok(info) => { 337 + // Print as JSON for easy parsing 338 + match serde_json::to_string_pretty(&info) { 339 + Ok(json) => println!("{}", json), 340 + Err(e) => { 341 + eprintln!("Failed to serialize info: {}", e); 342 + std::process::exit(1); 343 + } 344 + } 345 + } 346 + Err(e) => { 347 + eprintln!("Failed to get repository info: {}", e); 348 + std::process::exit(1); 349 + } 350 + } 351 + }
+214
crates/atproto-tap/src/bin/atproto-tap-extras.rs
··· 1 + //! Additional TAP client utilities for AT Protocol. 2 + //! 3 + //! This tool provides extra commands for managing TAP tracked repositories 4 + //! based on social graph data. 5 + //! 6 + //! # Usage 7 + //! 8 + //! ```bash 9 + //! # Add all accounts followed by a DID to TAP tracking 10 + //! cargo run --features cli --bin atproto-tap-extras -- localhost:2480 repos-add-followers did:plc:xyz 11 + //! 12 + //! # With authentication 13 + //! cargo run --features cli --bin atproto-tap-extras -- localhost:2480 -p secret repos-add-followers did:plc:xyz 14 + //! ``` 15 + 16 + use atproto_client::client::Auth; 17 + use atproto_client::com::atproto::repo::{ListRecordsParams, list_records}; 18 + use atproto_identity::plc::query as plc_query; 19 + use atproto_tap::TapClient; 20 + use clap::{Parser, Subcommand}; 21 + use serde::Deserialize; 22 + 23 + /// TAP extras utility for managing tracked repositories. 24 + #[derive(Parser)] 25 + #[command( 26 + name = "atproto-tap-extras", 27 + version, 28 + about = "TAP extras utility for AT Protocol", 29 + long_about = "Additional utilities for managing TAP tracked repositories based on social graph data." 30 + )] 31 + struct Args { 32 + /// TAP service hostname (e.g., localhost:2480) 33 + hostname: String, 34 + 35 + /// Admin password for TAP authentication 36 + #[arg(short, long, global = true)] 37 + password: Option<String>, 38 + 39 + /// PLC directory hostname for DID resolution 40 + #[arg(long, default_value = "plc.directory", global = true)] 41 + plc_hostname: String, 42 + 43 + #[command(subcommand)] 44 + command: Command, 45 + } 46 + 47 + #[derive(Subcommand)] 48 + enum Command { 49 + /// Add accounts followed by a DID to TAP tracking. 50 + /// 51 + /// Fetches all app.bsky.graph.follow records from the specified DID's repository 52 + /// and adds the followed DIDs to TAP for tracking. 53 + ReposAddFollowers { 54 + /// DID to read followers from (e.g., did:plc:xyz123) 55 + did: String, 56 + 57 + /// Batch size for adding repos to TAP 58 + #[arg(long, default_value = "100")] 59 + batch_size: usize, 60 + 61 + /// Dry run - print DIDs without adding to TAP 62 + #[arg(long)] 63 + dry_run: bool, 64 + }, 65 + } 66 + 67 + /// Follow record structure from app.bsky.graph.follow. 68 + #[derive(Debug, Deserialize)] 69 + struct FollowRecord { 70 + /// The DID of the account being followed. 71 + subject: String, 72 + } 73 + 74 + #[tokio::main] 75 + async fn main() { 76 + let args = Args::parse(); 77 + 78 + match args.command { 79 + Command::ReposAddFollowers { 80 + did, 81 + batch_size, 82 + dry_run, 83 + } => { 84 + run_repos_add_followers( 85 + &args.hostname, 86 + args.password, 87 + &args.plc_hostname, 88 + &did, 89 + batch_size, 90 + dry_run, 91 + ) 92 + .await; 93 + } 94 + } 95 + } 96 + 97 + async fn run_repos_add_followers( 98 + tap_hostname: &str, 99 + tap_password: Option<String>, 100 + plc_hostname: &str, 101 + did: &str, 102 + batch_size: usize, 103 + dry_run: bool, 104 + ) { 105 + let http_client = reqwest::Client::new(); 106 + 107 + // Resolve the DID to get the PDS endpoint 108 + eprintln!("Resolving DID: {}", did); 109 + let document = match plc_query(&http_client, plc_hostname, did).await { 110 + Ok(doc) => doc, 111 + Err(e) => { 112 + eprintln!("Failed to resolve DID: {}", e); 113 + std::process::exit(1); 114 + } 115 + }; 116 + 117 + let pds_endpoints = document.pds_endpoints(); 118 + if pds_endpoints.is_empty() { 119 + eprintln!("No PDS endpoint found in DID document"); 120 + std::process::exit(1); 121 + } 122 + let pds_url = pds_endpoints[0]; 123 + eprintln!("Using PDS: {}", pds_url); 124 + 125 + // Collect all followed DIDs 126 + let mut followed_dids: Vec<String> = Vec::new(); 127 + let mut cursor: Option<String> = None; 128 + let collection = "app.bsky.graph.follow".to_string(); 129 + 130 + eprintln!("Fetching follow records..."); 131 + 132 + loop { 133 + let params = if let Some(c) = cursor.take() { 134 + ListRecordsParams::new().limit(100).cursor(c) 135 + } else { 136 + ListRecordsParams::new().limit(100) 137 + }; 138 + 139 + let response = match list_records::<FollowRecord>( 140 + &http_client, 141 + &Auth::None, 142 + pds_url, 143 + did.to_string(), 144 + collection.clone(), 145 + params, 146 + ) 147 + .await 148 + { 149 + Ok(resp) => resp, 150 + Err(e) => { 151 + eprintln!("Failed to list records: {}", e); 152 + std::process::exit(1); 153 + } 154 + }; 155 + 156 + for record in &response.records { 157 + followed_dids.push(record.value.subject.clone()); 158 + } 159 + 160 + eprintln!( 161 + " Fetched {} records (total: {})", 162 + response.records.len(), 163 + followed_dids.len() 164 + ); 165 + 166 + match response.cursor { 167 + Some(c) if !response.records.is_empty() => { 168 + cursor = Some(c); 169 + } 170 + _ => break, 171 + } 172 + } 173 + 174 + if followed_dids.is_empty() { 175 + eprintln!("No follow records found"); 176 + return; 177 + } 178 + 179 + eprintln!("Found {} followed accounts", followed_dids.len()); 180 + 181 + if dry_run { 182 + eprintln!("\nDry run - would add these DIDs to TAP:"); 183 + for did in &followed_dids { 184 + println!("{}", did); 185 + } 186 + return; 187 + } 188 + 189 + // Add to TAP in batches 190 + let tap_client = TapClient::new(tap_hostname, tap_password); 191 + let mut added = 0; 192 + 193 + for chunk in followed_dids.chunks(batch_size) { 194 + let did_refs: Vec<&str> = chunk.iter().map(|s| s.as_str()).collect(); 195 + 196 + match tap_client.add_repos(&did_refs).await { 197 + Ok(()) => { 198 + added += chunk.len(); 199 + eprintln!("Added {} DIDs to TAP (total: {})", chunk.len(), added); 200 + } 201 + Err(e) => { 202 + eprintln!("Failed to add repos to TAP: {}", e); 203 + std::process::exit(1); 204 + } 205 + } 206 + } 207 + 208 + eprintln!("Successfully added {} DIDs to TAP", added); 209 + 210 + // Print all added DIDs 211 + for did in &followed_dids { 212 + println!("{}", did); 213 + } 214 + }
+371
crates/atproto-tap/src/client.rs
··· 1 + //! HTTP client for TAP management API. 2 + //! 3 + //! This module provides [`TapClient`] for interacting with the TAP service's 4 + //! HTTP management endpoints, including adding/removing tracked repositories. 5 + 6 + use crate::errors::TapError; 7 + use atproto_identity::model::Document; 8 + use base64::Engine; 9 + use base64::engine::general_purpose::STANDARD as BASE64; 10 + use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue}; 11 + use serde::{Deserialize, Serialize}; 12 + 13 + /// HTTP client for TAP management API. 14 + /// 15 + /// Provides methods for managing which repositories the TAP service tracks, 16 + /// checking service health, and querying repository status. 17 + /// 18 + /// # Example 19 + /// 20 + /// ```ignore 21 + /// use atproto_tap::TapClient; 22 + /// 23 + /// let client = TapClient::new("localhost:2480", Some("admin_password".to_string())); 24 + /// 25 + /// // Add repositories to track 26 + /// client.add_repos(&["did:plc:xyz123", "did:plc:abc456"]).await?; 27 + /// 28 + /// // Check health 29 + /// if client.health().await? { 30 + /// println!("TAP service is healthy"); 31 + /// } 32 + /// ``` 33 + #[derive(Debug, Clone)] 34 + pub struct TapClient { 35 + http_client: reqwest::Client, 36 + base_url: String, 37 + auth_header: Option<HeaderValue>, 38 + } 39 + 40 + impl TapClient { 41 + /// Create a new TAP management client. 42 + /// 43 + /// # Arguments 44 + /// 45 + /// * `hostname` - TAP service hostname (e.g., "localhost:2480") 46 + /// * `admin_password` - Optional admin password for authentication 47 + pub fn new(hostname: &str, admin_password: Option<String>) -> Self { 48 + let auth_header = admin_password.map(|password| { 49 + let credentials = format!("admin:{}", password); 50 + let encoded = BASE64.encode(credentials.as_bytes()); 51 + HeaderValue::from_str(&format!("Basic {}", encoded)) 52 + .expect("Invalid auth header value") 53 + }); 54 + 55 + Self { 56 + http_client: reqwest::Client::new(), 57 + base_url: format!("http://{}", hostname), 58 + auth_header, 59 + } 60 + } 61 + 62 + /// Create default headers for requests. 63 + fn default_headers(&self) -> HeaderMap { 64 + let mut headers = HeaderMap::new(); 65 + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); 66 + if let Some(auth) = &self.auth_header { 67 + headers.insert(AUTHORIZATION, auth.clone()); 68 + } 69 + headers 70 + } 71 + 72 + /// Add repositories to track. 73 + /// 74 + /// Sends a POST request to `/repos/add` with the list of DIDs. 75 + /// 76 + /// # Arguments 77 + /// 78 + /// * `dids` - Slice of DID strings to track 79 + /// 80 + /// # Example 81 + /// 82 + /// ```ignore 83 + /// client.add_repos(&[ 84 + /// "did:plc:z72i7hdynmk6r22z27h6tvur", 85 + /// "did:plc:ewvi7nxzyoun6zhxrhs64oiz", 86 + /// ]).await?; 87 + /// ``` 88 + pub async fn add_repos(&self, dids: &[&str]) -> Result<(), TapError> { 89 + let url = format!("{}/repos/add", self.base_url); 90 + let body = AddReposRequest { 91 + dids: dids.iter().map(|s| s.to_string()).collect(), 92 + }; 93 + 94 + let response = self 95 + .http_client 96 + .post(&url) 97 + .headers(self.default_headers()) 98 + .json(&body) 99 + .send() 100 + .await?; 101 + 102 + if response.status().is_success() { 103 + tracing::debug!(count = dids.len(), "Added repositories to TAP"); 104 + Ok(()) 105 + } else { 106 + let status = response.status().as_u16(); 107 + let message = response.text().await.unwrap_or_default(); 108 + Err(TapError::HttpResponseError { status, message }) 109 + } 110 + } 111 + 112 + /// Remove repositories from tracking. 113 + /// 114 + /// Sends a POST request to `/repos/remove` with the list of DIDs. 115 + /// 116 + /// # Arguments 117 + /// 118 + /// * `dids` - Slice of DID strings to stop tracking 119 + pub async fn remove_repos(&self, dids: &[&str]) -> Result<(), TapError> { 120 + let url = format!("{}/repos/remove", self.base_url); 121 + let body = AddReposRequest { 122 + dids: dids.iter().map(|s| s.to_string()).collect(), 123 + }; 124 + 125 + let response = self 126 + .http_client 127 + .post(&url) 128 + .headers(self.default_headers()) 129 + .json(&body) 130 + .send() 131 + .await?; 132 + 133 + if response.status().is_success() { 134 + tracing::debug!(count = dids.len(), "Removed repositories from TAP"); 135 + Ok(()) 136 + } else { 137 + let status = response.status().as_u16(); 138 + let message = response.text().await.unwrap_or_default(); 139 + Err(TapError::HttpResponseError { status, message }) 140 + } 141 + } 142 + 143 + /// Check service health. 144 + /// 145 + /// Sends a GET request to `/health`. 146 + /// 147 + /// # Returns 148 + /// 149 + /// `true` if the service is healthy, `false` otherwise. 150 + pub async fn health(&self) -> Result<bool, TapError> { 151 + let url = format!("{}/health", self.base_url); 152 + 153 + let response = self 154 + .http_client 155 + .get(&url) 156 + .headers(self.default_headers()) 157 + .send() 158 + .await?; 159 + 160 + Ok(response.status().is_success()) 161 + } 162 + 163 + /// Resolve a DID to its DID document. 164 + /// 165 + /// Sends a GET request to `/resolve/:did`. 166 + /// 167 + /// # Arguments 168 + /// 169 + /// * `did` - The DID to resolve 170 + /// 171 + /// # Returns 172 + /// 173 + /// The DID document for the identity. 174 + pub async fn resolve(&self, did: &str) -> Result<Document, TapError> { 175 + let url = format!("{}/resolve/{}", self.base_url, did); 176 + 177 + let response = self 178 + .http_client 179 + .get(&url) 180 + .headers(self.default_headers()) 181 + .send() 182 + .await?; 183 + 184 + if response.status().is_success() { 185 + let doc: Document = response.json().await?; 186 + Ok(doc) 187 + } else { 188 + let status = response.status().as_u16(); 189 + let message = response.text().await.unwrap_or_default(); 190 + Err(TapError::HttpResponseError { status, message }) 191 + } 192 + } 193 + 194 + /// Get info about a tracked repository. 195 + /// 196 + /// Sends a GET request to `/info/:did`. 197 + /// 198 + /// # Arguments 199 + /// 200 + /// * `did` - The DID to get info for 201 + /// 202 + /// # Returns 203 + /// 204 + /// Repository tracking information. 205 + pub async fn info(&self, did: &str) -> Result<RepoInfo, TapError> { 206 + let url = format!("{}/info/{}", self.base_url, did); 207 + 208 + let response = self 209 + .http_client 210 + .get(&url) 211 + .headers(self.default_headers()) 212 + .send() 213 + .await?; 214 + 215 + if response.status().is_success() { 216 + let info: RepoInfo = response.json().await?; 217 + Ok(info) 218 + } else { 219 + let status = response.status().as_u16(); 220 + let message = response.text().await.unwrap_or_default(); 221 + Err(TapError::HttpResponseError { status, message }) 222 + } 223 + } 224 + } 225 + 226 + /// Request body for adding/removing repositories. 227 + #[derive(Debug, Serialize)] 228 + struct AddReposRequest { 229 + dids: Vec<String>, 230 + } 231 + 232 + /// Repository tracking information. 233 + #[derive(Debug, Clone, Serialize, Deserialize)] 234 + pub struct RepoInfo { 235 + /// The repository DID. 236 + pub did: Box<str>, 237 + /// Current sync state. 238 + pub state: RepoState, 239 + /// The handle for the repository. 240 + #[serde(default)] 241 + pub handle: Option<Box<str>>, 242 + /// Number of records in the repository. 243 + #[serde(default)] 244 + pub records: u64, 245 + /// Current repository revision. 246 + #[serde(default)] 247 + pub rev: Option<Box<str>>, 248 + /// Number of retries for syncing. 249 + #[serde(default)] 250 + pub retries: u32, 251 + /// Error message if any. 252 + #[serde(default)] 253 + pub error: Option<Box<str>>, 254 + /// Additional fields may be present depending on TAP version. 255 + #[serde(flatten)] 256 + pub extra: serde_json::Value, 257 + } 258 + 259 + /// Repository sync state. 260 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] 261 + #[serde(rename_all = "lowercase")] 262 + pub enum RepoState { 263 + /// Repository is active and synced. 264 + Active, 265 + /// Repository is currently syncing. 266 + Syncing, 267 + /// Repository is fully synced. 268 + Synced, 269 + /// Sync failed for this repository. 270 + Failed, 271 + /// Repository is queued for sync. 272 + Queued, 273 + /// Unknown state. 274 + #[serde(other)] 275 + Unknown, 276 + } 277 + 278 + /// Deprecated alias for RepoState. 279 + #[deprecated(since = "0.13.0", note = "Use RepoState instead")] 280 + pub type RepoStatus = RepoState; 281 + 282 + impl std::fmt::Display for RepoState { 283 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 284 + match self { 285 + RepoState::Active => write!(f, "active"), 286 + RepoState::Syncing => write!(f, "syncing"), 287 + RepoState::Synced => write!(f, "synced"), 288 + RepoState::Failed => write!(f, "failed"), 289 + RepoState::Queued => write!(f, "queued"), 290 + RepoState::Unknown => write!(f, "unknown"), 291 + } 292 + } 293 + } 294 + 295 + #[cfg(test)] 296 + mod tests { 297 + use super::*; 298 + 299 + #[test] 300 + fn test_client_creation() { 301 + let client = TapClient::new("localhost:2480", None); 302 + assert_eq!(client.base_url, "http://localhost:2480"); 303 + assert!(client.auth_header.is_none()); 304 + 305 + let client = TapClient::new("localhost:2480", Some("secret".to_string())); 306 + assert!(client.auth_header.is_some()); 307 + } 308 + 309 + #[test] 310 + fn test_repo_state_display() { 311 + assert_eq!(RepoState::Active.to_string(), "active"); 312 + assert_eq!(RepoState::Syncing.to_string(), "syncing"); 313 + assert_eq!(RepoState::Synced.to_string(), "synced"); 314 + assert_eq!(RepoState::Failed.to_string(), "failed"); 315 + assert_eq!(RepoState::Queued.to_string(), "queued"); 316 + assert_eq!(RepoState::Unknown.to_string(), "unknown"); 317 + } 318 + 319 + #[test] 320 + fn test_repo_state_deserialize() { 321 + let json = r#""active""#; 322 + let state: RepoState = serde_json::from_str(json).unwrap(); 323 + assert_eq!(state, RepoState::Active); 324 + 325 + let json = r#""syncing""#; 326 + let state: RepoState = serde_json::from_str(json).unwrap(); 327 + assert_eq!(state, RepoState::Syncing); 328 + 329 + let json = r#""some_new_state""#; 330 + let state: RepoState = serde_json::from_str(json).unwrap(); 331 + assert_eq!(state, RepoState::Unknown); 332 + } 333 + 334 + #[test] 335 + fn test_repo_info_deserialize() { 336 + let json = r#"{"did":"did:plc:cbkjy5n7bk3ax2wplmtjofq2","error":"","handle":"ngerakines.me","records":21382,"retries":0,"rev":"3mam4aazabs2m","state":"active"}"#; 337 + let info: RepoInfo = serde_json::from_str(json).unwrap(); 338 + assert_eq!(&*info.did, "did:plc:cbkjy5n7bk3ax2wplmtjofq2"); 339 + assert_eq!(info.state, RepoState::Active); 340 + assert_eq!(info.handle.as_deref(), Some("ngerakines.me")); 341 + assert_eq!(info.records, 21382); 342 + assert_eq!(info.retries, 0); 343 + assert_eq!(info.rev.as_deref(), Some("3mam4aazabs2m")); 344 + // Empty string deserializes as Some("") 345 + assert_eq!(info.error.as_deref(), Some("")); 346 + } 347 + 348 + #[test] 349 + fn test_repo_info_deserialize_minimal() { 350 + // Test with only required fields 351 + let json = r#"{"did":"did:plc:test","state":"syncing"}"#; 352 + let info: RepoInfo = serde_json::from_str(json).unwrap(); 353 + assert_eq!(&*info.did, "did:plc:test"); 354 + assert_eq!(info.state, RepoState::Syncing); 355 + assert_eq!(info.handle, None); 356 + assert_eq!(info.records, 0); 357 + assert_eq!(info.retries, 0); 358 + assert_eq!(info.rev, None); 359 + assert_eq!(info.error, None); 360 + } 361 + 362 + #[test] 363 + fn test_add_repos_request_serialize() { 364 + let req = AddReposRequest { 365 + dids: vec!["did:plc:xyz".to_string(), "did:plc:abc".to_string()], 366 + }; 367 + let json = serde_json::to_string(&req).unwrap(); 368 + assert!(json.contains("dids")); 369 + assert!(json.contains("did:plc:xyz")); 370 + } 371 + }
+220
crates/atproto-tap/src/config.rs
··· 1 + //! Configuration for TAP stream connections. 2 + //! 3 + //! This module provides the [`TapConfig`] struct for configuring TAP stream 4 + //! connections, including hostname, authentication, and reconnection behavior. 5 + 6 + use std::time::Duration; 7 + 8 + /// Configuration for a TAP stream connection. 9 + /// 10 + /// Use [`TapConfig::builder()`] for ergonomic construction with defaults. 11 + /// 12 + /// # Example 13 + /// 14 + /// ``` 15 + /// use atproto_tap::TapConfig; 16 + /// use std::time::Duration; 17 + /// 18 + /// let config = TapConfig::builder() 19 + /// .hostname("localhost:2480") 20 + /// .admin_password("secret") 21 + /// .send_acks(true) 22 + /// .max_reconnect_attempts(Some(10)) 23 + /// .build(); 24 + /// ``` 25 + #[derive(Debug, Clone)] 26 + pub struct TapConfig { 27 + /// TAP service hostname (e.g., "localhost:2480"). 28 + /// 29 + /// The WebSocket URL is constructed as `ws://{hostname}/channel`. 30 + pub hostname: String, 31 + 32 + /// Optional admin password for authentication. 33 + /// 34 + /// If set, HTTP Basic Auth is used with username "admin". 35 + pub admin_password: Option<String>, 36 + 37 + /// Whether to send acknowledgments for received messages. 38 + /// 39 + /// Default: `true`. Set to `false` if the TAP service has acks disabled. 40 + pub send_acks: bool, 41 + 42 + /// User-Agent header value for WebSocket connections. 43 + pub user_agent: String, 44 + 45 + /// Maximum reconnection attempts before giving up. 46 + /// 47 + /// `None` means unlimited reconnection attempts (default). 48 + pub max_reconnect_attempts: Option<u32>, 49 + 50 + /// Initial delay before first reconnection attempt. 51 + /// 52 + /// Default: 1 second. 53 + pub initial_reconnect_delay: Duration, 54 + 55 + /// Maximum delay between reconnection attempts. 56 + /// 57 + /// Default: 60 seconds. 58 + pub max_reconnect_delay: Duration, 59 + 60 + /// Multiplier for exponential backoff between reconnections. 61 + /// 62 + /// Default: 2.0 (doubles the delay each attempt). 63 + pub reconnect_backoff_multiplier: f64, 64 + } 65 + 66 + impl Default for TapConfig { 67 + fn default() -> Self { 68 + Self { 69 + hostname: "localhost:2480".to_string(), 70 + admin_password: None, 71 + send_acks: true, 72 + user_agent: format!("atproto-tap/{}", env!("CARGO_PKG_VERSION")), 73 + max_reconnect_attempts: None, 74 + initial_reconnect_delay: Duration::from_secs(1), 75 + max_reconnect_delay: Duration::from_secs(60), 76 + reconnect_backoff_multiplier: 2.0, 77 + } 78 + } 79 + } 80 + 81 + impl TapConfig { 82 + /// Create a new configuration builder with defaults. 83 + pub fn builder() -> TapConfigBuilder { 84 + TapConfigBuilder::default() 85 + } 86 + 87 + /// Create a minimal configuration for the given hostname. 88 + pub fn new(hostname: impl Into<String>) -> Self { 89 + Self { 90 + hostname: hostname.into(), 91 + ..Default::default() 92 + } 93 + } 94 + 95 + /// Returns the WebSocket URL for the TAP channel. 96 + pub fn ws_url(&self) -> String { 97 + format!("ws://{}/channel", self.hostname) 98 + } 99 + 100 + /// Returns the HTTP base URL for the TAP management API. 101 + pub fn http_base_url(&self) -> String { 102 + format!("http://{}", self.hostname) 103 + } 104 + } 105 + 106 + /// Builder for [`TapConfig`]. 107 + #[derive(Debug, Clone, Default)] 108 + pub struct TapConfigBuilder { 109 + config: TapConfig, 110 + } 111 + 112 + impl TapConfigBuilder { 113 + /// Set the TAP service hostname. 114 + pub fn hostname(mut self, hostname: impl Into<String>) -> Self { 115 + self.config.hostname = hostname.into(); 116 + self 117 + } 118 + 119 + /// Set the admin password for authentication. 120 + pub fn admin_password(mut self, password: impl Into<String>) -> Self { 121 + self.config.admin_password = Some(password.into()); 122 + self 123 + } 124 + 125 + /// Set whether to send acknowledgments. 126 + pub fn send_acks(mut self, send_acks: bool) -> Self { 127 + self.config.send_acks = send_acks; 128 + self 129 + } 130 + 131 + /// Set the User-Agent header value. 132 + pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self { 133 + self.config.user_agent = user_agent.into(); 134 + self 135 + } 136 + 137 + /// Set the maximum reconnection attempts. 138 + /// 139 + /// `None` means unlimited attempts. 140 + pub fn max_reconnect_attempts(mut self, max: Option<u32>) -> Self { 141 + self.config.max_reconnect_attempts = max; 142 + self 143 + } 144 + 145 + /// Set the initial reconnection delay. 146 + pub fn initial_reconnect_delay(mut self, delay: Duration) -> Self { 147 + self.config.initial_reconnect_delay = delay; 148 + self 149 + } 150 + 151 + /// Set the maximum reconnection delay. 152 + pub fn max_reconnect_delay(mut self, delay: Duration) -> Self { 153 + self.config.max_reconnect_delay = delay; 154 + self 155 + } 156 + 157 + /// Set the reconnection backoff multiplier. 158 + pub fn reconnect_backoff_multiplier(mut self, multiplier: f64) -> Self { 159 + self.config.reconnect_backoff_multiplier = multiplier; 160 + self 161 + } 162 + 163 + /// Build the configuration. 164 + pub fn build(self) -> TapConfig { 165 + self.config 166 + } 167 + } 168 + 169 + #[cfg(test)] 170 + mod tests { 171 + use super::*; 172 + 173 + #[test] 174 + fn test_default_config() { 175 + let config = TapConfig::default(); 176 + assert_eq!(config.hostname, "localhost:2480"); 177 + assert!(config.admin_password.is_none()); 178 + assert!(config.send_acks); 179 + assert!(config.max_reconnect_attempts.is_none()); 180 + assert_eq!(config.initial_reconnect_delay, Duration::from_secs(1)); 181 + assert_eq!(config.max_reconnect_delay, Duration::from_secs(60)); 182 + assert!((config.reconnect_backoff_multiplier - 2.0).abs() < f64::EPSILON); 183 + } 184 + 185 + #[test] 186 + fn test_builder() { 187 + let config = TapConfig::builder() 188 + .hostname("tap.example.com:2480") 189 + .admin_password("secret123") 190 + .send_acks(false) 191 + .max_reconnect_attempts(Some(5)) 192 + .initial_reconnect_delay(Duration::from_millis(500)) 193 + .max_reconnect_delay(Duration::from_secs(30)) 194 + .reconnect_backoff_multiplier(1.5) 195 + .build(); 196 + 197 + assert_eq!(config.hostname, "tap.example.com:2480"); 198 + assert_eq!(config.admin_password, Some("secret123".to_string())); 199 + assert!(!config.send_acks); 200 + assert_eq!(config.max_reconnect_attempts, Some(5)); 201 + assert_eq!(config.initial_reconnect_delay, Duration::from_millis(500)); 202 + assert_eq!(config.max_reconnect_delay, Duration::from_secs(30)); 203 + assert!((config.reconnect_backoff_multiplier - 1.5).abs() < f64::EPSILON); 204 + } 205 + 206 + #[test] 207 + fn test_ws_url() { 208 + let config = TapConfig::new("localhost:2480"); 209 + assert_eq!(config.ws_url(), "ws://localhost:2480/channel"); 210 + 211 + let config = TapConfig::new("tap.example.com:8080"); 212 + assert_eq!(config.ws_url(), "ws://tap.example.com:8080/channel"); 213 + } 214 + 215 + #[test] 216 + fn test_http_base_url() { 217 + let config = TapConfig::new("localhost:2480"); 218 + assert_eq!(config.http_base_url(), "http://localhost:2480"); 219 + } 220 + }
+168
crates/atproto-tap/src/connection.rs
··· 1 + //! WebSocket connection management for TAP streams. 2 + //! 3 + //! This module handles the low-level WebSocket connection to a TAP service, 4 + //! including authentication and message sending/receiving. 5 + 6 + use crate::config::TapConfig; 7 + use crate::errors::TapError; 8 + use base64::Engine; 9 + use base64::engine::general_purpose::STANDARD as BASE64; 10 + use futures::{SinkExt, StreamExt}; 11 + use http::Uri; 12 + use std::str::FromStr; 13 + use tokio_websockets::{ClientBuilder, Message, WebSocketStream}; 14 + use tokio_websockets::MaybeTlsStream; 15 + use tokio::net::TcpStream; 16 + 17 + /// WebSocket connection to a TAP service. 18 + pub(crate) struct TapConnection { 19 + /// The underlying WebSocket stream. 20 + ws: WebSocketStream<MaybeTlsStream<TcpStream>>, 21 + /// Pre-allocated buffer for acknowledgment messages. 22 + ack_buffer: Vec<u8>, 23 + } 24 + 25 + impl TapConnection { 26 + /// Establish a new WebSocket connection to the TAP service. 27 + pub async fn connect(config: &TapConfig) -> Result<Self, TapError> { 28 + let uri = Uri::from_str(&config.ws_url()) 29 + .map_err(|e| TapError::InvalidUrl(e.to_string()))?; 30 + 31 + let mut builder = ClientBuilder::from_uri(uri); 32 + 33 + // Add User-Agent header 34 + builder = builder 35 + .add_header( 36 + http::header::USER_AGENT, 37 + http::HeaderValue::from_str(&config.user_agent) 38 + .map_err(|e| TapError::ConnectionFailed(format!("Invalid user agent: {}", e)))?, 39 + ) 40 + .map_err(|e| TapError::ConnectionFailed(format!("Failed to add header: {}", e)))?; 41 + 42 + // Add Basic Auth header if password is configured 43 + if let Some(password) = &config.admin_password { 44 + let credentials = format!("admin:{}", password); 45 + let encoded = BASE64.encode(credentials.as_bytes()); 46 + let auth_value = format!("Basic {}", encoded); 47 + 48 + builder = builder 49 + .add_header( 50 + http::header::AUTHORIZATION, 51 + http::HeaderValue::from_str(&auth_value) 52 + .map_err(|e| TapError::ConnectionFailed(format!("Invalid auth header: {}", e)))?, 53 + ) 54 + .map_err(|e| TapError::ConnectionFailed(format!("Failed to add auth header: {}", e)))?; 55 + } 56 + 57 + // Connect 58 + let (ws, _response) = builder 59 + .connect() 60 + .await 61 + .map_err(|e| TapError::ConnectionFailed(e.to_string()))?; 62 + 63 + tracing::debug!(hostname = %config.hostname, "Connected to TAP service"); 64 + 65 + Ok(Self { 66 + ws, 67 + ack_buffer: Vec::with_capacity(48), // {"type":"ack","id":18446744073709551615} is 40 bytes max 68 + }) 69 + } 70 + 71 + /// Receive the next message from the WebSocket. 72 + /// 73 + /// Returns `None` if the connection was closed cleanly. 74 + pub async fn recv(&mut self) -> Result<Option<String>, TapError> { 75 + match self.ws.next().await { 76 + Some(Ok(msg)) => { 77 + if msg.is_text() { 78 + msg.as_text() 79 + .map(|s| Some(s.to_string())) 80 + .ok_or_else(|| TapError::ParseError("Failed to get text from message".into())) 81 + } else if msg.is_close() { 82 + tracing::debug!("Received close frame from TAP service"); 83 + Ok(None) 84 + } else { 85 + // Ignore ping/pong and binary messages 86 + tracing::trace!("Received non-text message, ignoring"); 87 + // Recurse to get the next text message 88 + Box::pin(self.recv()).await 89 + } 90 + } 91 + Some(Err(e)) => Err(TapError::ConnectionFailed(e.to_string())), 92 + None => { 93 + tracing::debug!("WebSocket stream ended"); 94 + Ok(None) 95 + } 96 + } 97 + } 98 + 99 + /// Send an acknowledgment for the given event ID. 100 + /// 101 + /// Uses a pre-allocated buffer and itoa for allocation-free formatting. 102 + /// Format: `{"type":"ack","id":12345}` 103 + pub async fn send_ack(&mut self, id: u64) -> Result<(), TapError> { 104 + self.ack_buffer.clear(); 105 + self.ack_buffer.extend_from_slice(b"{\"type\":\"ack\",\"id\":"); 106 + let mut itoa_buf = itoa::Buffer::new(); 107 + self.ack_buffer.extend_from_slice(itoa_buf.format(id).as_bytes()); 108 + self.ack_buffer.push(b'}'); 109 + 110 + // All bytes are ASCII so this is always valid UTF-8 111 + let msg = std::str::from_utf8(&self.ack_buffer) 112 + .expect("ack buffer contains only ASCII"); 113 + 114 + self.ws 115 + .send(Message::text(msg.to_string())) 116 + .await 117 + .map_err(|e| TapError::AckFailed(e.to_string()))?; 118 + 119 + // Flush to ensure the ack is sent immediately 120 + self.ws 121 + .flush() 122 + .await 123 + .map_err(|e| TapError::AckFailed(format!("Failed to flush ack: {}", e)))?; 124 + 125 + tracing::trace!(id, "Sent ack"); 126 + Ok(()) 127 + } 128 + 129 + /// Close the WebSocket connection gracefully. 130 + pub async fn close(&mut self) -> Result<(), TapError> { 131 + self.ws 132 + .close() 133 + .await 134 + .map_err(|e| TapError::ConnectionFailed(format!("Failed to close: {}", e)))?; 135 + Ok(()) 136 + } 137 + } 138 + 139 + #[cfg(test)] 140 + mod tests { 141 + #[test] 142 + fn test_ack_buffer_format() { 143 + // Test that our manual JSON formatting is correct 144 + // Format: {"type":"ack","id":12345} 145 + let mut buffer = Vec::with_capacity(64); 146 + 147 + let id: u64 = 12345; 148 + buffer.clear(); 149 + buffer.extend_from_slice(b"{\"type\":\"ack\",\"id\":"); 150 + let mut itoa_buf = itoa::Buffer::new(); 151 + buffer.extend_from_slice(itoa_buf.format(id).as_bytes()); 152 + buffer.push(b'}'); 153 + 154 + let result = std::str::from_utf8(&buffer).unwrap(); 155 + assert_eq!(result, r#"{"type":"ack","id":12345}"#); 156 + 157 + // Test max u64 158 + let id: u64 = u64::MAX; 159 + buffer.clear(); 160 + buffer.extend_from_slice(b"{\"type\":\"ack\",\"id\":"); 161 + buffer.extend_from_slice(itoa_buf.format(id).as_bytes()); 162 + buffer.push(b'}'); 163 + 164 + let result = std::str::from_utf8(&buffer).unwrap(); 165 + assert_eq!(result, r#"{"type":"ack","id":18446744073709551615}"#); 166 + assert!(buffer.len() <= 64); // Fits in our pre-allocated buffer 167 + } 168 + }
+143
crates/atproto-tap/src/errors.rs
··· 1 + //! Error types for TAP operations. 2 + //! 3 + //! This module defines the error types returned by TAP stream and client operations. 4 + 5 + use thiserror::Error; 6 + 7 + /// Errors that can occur during TAP operations. 8 + #[derive(Debug, Error)] 9 + pub enum TapError { 10 + /// WebSocket connection failed. 11 + #[error("error-atproto-tap-connection-1 WebSocket connection failed: {0}")] 12 + ConnectionFailed(String), 13 + 14 + /// Connection was closed unexpectedly. 15 + #[error("error-atproto-tap-connection-2 Connection closed unexpectedly")] 16 + ConnectionClosed, 17 + 18 + /// Maximum reconnection attempts exceeded. 19 + #[error("error-atproto-tap-connection-3 Maximum reconnection attempts exceeded after {0} attempts")] 20 + MaxReconnectAttemptsExceeded(u32), 21 + 22 + /// Authentication failed. 23 + #[error("error-atproto-tap-auth-1 Authentication failed: {0}")] 24 + AuthenticationFailed(String), 25 + 26 + /// Failed to parse a message from the server. 27 + #[error("error-atproto-tap-parse-1 Failed to parse message: {0}")] 28 + ParseError(String), 29 + 30 + /// Failed to send an acknowledgment. 31 + #[error("error-atproto-tap-ack-1 Failed to send acknowledgment: {0}")] 32 + AckFailed(String), 33 + 34 + /// HTTP request failed. 35 + #[error("error-atproto-tap-http-1 HTTP request failed: {0}")] 36 + HttpError(String), 37 + 38 + /// HTTP response indicated an error. 39 + #[error("error-atproto-tap-http-2 HTTP error response: {status} - {message}")] 40 + HttpResponseError { 41 + /// HTTP status code. 42 + status: u16, 43 + /// Error message from response. 44 + message: String, 45 + }, 46 + 47 + /// Invalid URL. 48 + #[error("error-atproto-tap-url-1 Invalid URL: {0}")] 49 + InvalidUrl(String), 50 + 51 + /// I/O error. 52 + #[error("error-atproto-tap-io-1 I/O error: {0}")] 53 + IoError(#[from] std::io::Error), 54 + 55 + /// JSON serialization/deserialization error. 56 + #[error("error-atproto-tap-json-1 JSON error: {0}")] 57 + JsonError(#[from] serde_json::Error), 58 + 59 + /// Stream has been closed and cannot be used. 60 + #[error("error-atproto-tap-stream-1 Stream is closed")] 61 + StreamClosed, 62 + 63 + /// Operation timed out. 64 + #[error("error-atproto-tap-timeout-1 Operation timed out")] 65 + Timeout, 66 + } 67 + 68 + impl TapError { 69 + /// Returns true if this error indicates a connection issue that may be recoverable. 70 + pub fn is_connection_error(&self) -> bool { 71 + matches!( 72 + self, 73 + TapError::ConnectionFailed(_) 74 + | TapError::ConnectionClosed 75 + | TapError::IoError(_) 76 + | TapError::Timeout 77 + ) 78 + } 79 + 80 + /// Returns true if this error is a parse error that doesn't affect connection state. 81 + pub fn is_parse_error(&self) -> bool { 82 + matches!(self, TapError::ParseError(_) | TapError::JsonError(_)) 83 + } 84 + 85 + /// Returns true if this error is fatal and the stream should not attempt recovery. 86 + pub fn is_fatal(&self) -> bool { 87 + matches!( 88 + self, 89 + TapError::MaxReconnectAttemptsExceeded(_) 90 + | TapError::AuthenticationFailed(_) 91 + | TapError::StreamClosed 92 + ) 93 + } 94 + } 95 + 96 + impl From<reqwest::Error> for TapError { 97 + fn from(err: reqwest::Error) -> Self { 98 + if err.is_timeout() { 99 + TapError::Timeout 100 + } else if err.is_connect() { 101 + TapError::ConnectionFailed(err.to_string()) 102 + } else { 103 + TapError::HttpError(err.to_string()) 104 + } 105 + } 106 + } 107 + 108 + #[cfg(test)] 109 + mod tests { 110 + use super::*; 111 + 112 + #[test] 113 + fn test_error_classification() { 114 + assert!(TapError::ConnectionFailed("test".into()).is_connection_error()); 115 + assert!(TapError::ConnectionClosed.is_connection_error()); 116 + assert!(TapError::Timeout.is_connection_error()); 117 + 118 + assert!(TapError::ParseError("test".into()).is_parse_error()); 119 + assert!(TapError::JsonError(serde_json::from_str::<()>("invalid").unwrap_err()).is_parse_error()); 120 + 121 + assert!(TapError::MaxReconnectAttemptsExceeded(5).is_fatal()); 122 + assert!(TapError::AuthenticationFailed("test".into()).is_fatal()); 123 + assert!(TapError::StreamClosed.is_fatal()); 124 + 125 + // Non-fatal errors 126 + assert!(!TapError::ConnectionFailed("test".into()).is_fatal()); 127 + assert!(!TapError::ParseError("test".into()).is_fatal()); 128 + } 129 + 130 + #[test] 131 + fn test_error_display() { 132 + let err = TapError::ConnectionFailed("refused".to_string()); 133 + assert!(err.to_string().contains("error-atproto-tap-connection-1")); 134 + assert!(err.to_string().contains("refused")); 135 + 136 + let err = TapError::HttpResponseError { 137 + status: 404, 138 + message: "Not Found".to_string(), 139 + }; 140 + assert!(err.to_string().contains("404")); 141 + assert!(err.to_string().contains("Not Found")); 142 + } 143 + }
+488
crates/atproto-tap/src/events.rs
··· 1 + //! TAP event types for AT Protocol record and identity events. 2 + //! 3 + //! This module defines the event structures received from a TAP service. 4 + //! Events are optimized for memory efficiency using: 5 + //! - `CompactString` for small strings (SSO for โ‰ค24 bytes) 6 + //! - `Box<str>` for immutable strings (no capacity overhead) 7 + //! - `serde_json::Value` for record payloads (allows lazy access) 8 + 9 + use compact_str::CompactString; 10 + use serde::de::{self, Deserializer, IgnoredAny, MapAccess, Visitor}; 11 + use serde::{Deserialize, Serialize, de::DeserializeOwned}; 12 + use std::fmt; 13 + 14 + /// A TAP event received from the stream. 15 + /// 16 + /// TAP delivers two types of events: 17 + /// - `Record`: Repository record changes (create, update, delete) 18 + /// - `Identity`: Identity/handle changes for accounts 19 + #[derive(Debug, Clone, Serialize, Deserialize)] 20 + #[serde(tag = "type", rename_all = "lowercase")] 21 + pub enum TapEvent { 22 + /// A repository record event (create, update, or delete). 23 + Record { 24 + /// Sequential event identifier. 25 + id: u64, 26 + /// The record event data. 27 + record: RecordEvent, 28 + }, 29 + /// An identity change event. 30 + Identity { 31 + /// Sequential event identifier. 32 + id: u64, 33 + /// The identity event data. 34 + identity: IdentityEvent, 35 + }, 36 + } 37 + 38 + impl TapEvent { 39 + /// Returns the event ID. 40 + pub fn id(&self) -> u64 { 41 + match self { 42 + TapEvent::Record { id, .. } => *id, 43 + TapEvent::Identity { id, .. } => *id, 44 + } 45 + } 46 + } 47 + 48 + /// Extract only the event ID from a JSON string without fully parsing it. 49 + /// 50 + /// This is a fallback parser used when full `TapEvent` parsing fails (e.g., due to 51 + /// deeply nested records hitting serde_json's recursion limit). It uses `IgnoredAny` 52 + /// to efficiently skip over nested content without building data structures, allowing 53 + /// us to extract the ID for acknowledgment even when full parsing fails. 54 + /// 55 + /// # Example 56 + /// 57 + /// ``` 58 + /// use atproto_tap::extract_event_id; 59 + /// 60 + /// let json = r#"{"type":"record","id":12345,"record":{"deeply":"nested"}}"#; 61 + /// assert_eq!(extract_event_id(json), Some(12345)); 62 + /// ``` 63 + pub fn extract_event_id(json: &str) -> Option<u64> { 64 + let mut deserializer = serde_json::Deserializer::from_str(json); 65 + deserializer.disable_recursion_limit(); 66 + EventIdOnly::deserialize(&mut deserializer).ok().map(|e| e.id) 67 + } 68 + 69 + /// Internal struct for extracting only the "id" field from a TAP event. 70 + #[derive(Debug)] 71 + struct EventIdOnly { 72 + id: u64, 73 + } 74 + 75 + impl<'de> Deserialize<'de> for EventIdOnly { 76 + fn deserialize<D>(deserializer: D) -> Result<Self, D::Error> 77 + where 78 + D: Deserializer<'de>, 79 + { 80 + deserializer.deserialize_map(EventIdOnlyVisitor) 81 + } 82 + } 83 + 84 + struct EventIdOnlyVisitor; 85 + 86 + impl<'de> Visitor<'de> for EventIdOnlyVisitor { 87 + type Value = EventIdOnly; 88 + 89 + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { 90 + formatter.write_str("a map with an 'id' field") 91 + } 92 + 93 + fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error> 94 + where 95 + M: MapAccess<'de>, 96 + { 97 + let mut id: Option<u64> = None; 98 + 99 + while let Some(key) = map.next_key::<&str>()? { 100 + if key == "id" { 101 + id = Some(map.next_value()?); 102 + // Found what we need - skip the rest efficiently using IgnoredAny 103 + // which handles deeply nested structures without recursion issues 104 + while map.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {} 105 + break; 106 + } else { 107 + // Skip this value without fully parsing it 108 + map.next_value::<IgnoredAny>()?; 109 + } 110 + } 111 + 112 + id.map(|id| EventIdOnly { id }) 113 + .ok_or_else(|| de::Error::missing_field("id")) 114 + } 115 + } 116 + 117 + /// A repository record event from TAP. 118 + /// 119 + /// Contains information about a record change in a user's repository, 120 + /// including the action taken and the record data (for creates/updates). 121 + #[derive(Debug, Clone, Serialize, Deserialize)] 122 + pub struct RecordEvent { 123 + /// True if from live firehose, false if from backfill/resync. 124 + /// 125 + /// During initial sync or recovery, TAP delivers historical events 126 + /// with `live: false`. Once caught up, live events have `live: true`. 127 + pub live: bool, 128 + 129 + /// Repository revision identifier. 130 + /// 131 + /// Typically 13 characters, stored inline via CompactString SSO. 132 + pub rev: CompactString, 133 + 134 + /// Actor DID (e.g., "did:plc:xyz123"). 135 + pub did: Box<str>, 136 + 137 + /// Collection NSID (e.g., "app.bsky.feed.post"). 138 + pub collection: Box<str>, 139 + 140 + /// Record key within the collection. 141 + /// 142 + /// Typically a TID (13 characters), stored inline via CompactString SSO. 143 + pub rkey: CompactString, 144 + 145 + /// The action performed on the record. 146 + pub action: RecordAction, 147 + 148 + /// Content identifier (CID) of the record. 149 + /// 150 + /// Present for create and update actions, absent for delete. 151 + #[serde(skip_serializing_if = "Option::is_none")] 152 + pub cid: Option<CompactString>, 153 + 154 + /// Record data as JSON value. 155 + /// 156 + /// Present for create and update actions, absent for delete. 157 + /// Use [`parse_record`](Self::parse_record) to deserialize on demand. 158 + #[serde(skip_serializing_if = "Option::is_none")] 159 + pub record: Option<serde_json::Value>, 160 + } 161 + 162 + impl RecordEvent { 163 + /// Parse the record payload into a typed structure. 164 + /// 165 + /// This method deserializes the raw JSON on demand, avoiding 166 + /// unnecessary allocations when the record data isn't needed. 167 + /// 168 + /// # Errors 169 + /// 170 + /// Returns an error if the record is absent (delete events) or 171 + /// if deserialization fails. 172 + /// 173 + /// # Example 174 + /// 175 + /// ```ignore 176 + /// use serde::Deserialize; 177 + /// 178 + /// #[derive(Deserialize)] 179 + /// struct Post { 180 + /// text: String, 181 + /// #[serde(rename = "createdAt")] 182 + /// created_at: String, 183 + /// } 184 + /// 185 + /// let post: Post = record_event.parse_record()?; 186 + /// println!("Post text: {}", post.text); 187 + /// ``` 188 + pub fn parse_record<T: DeserializeOwned>(&self) -> Result<T, serde_json::Error> { 189 + match &self.record { 190 + Some(value) => serde_json::from_value(value.clone()), 191 + None => Err(serde::de::Error::custom("no record data (delete event)")), 192 + } 193 + } 194 + 195 + /// Returns the record as a JSON Value reference, if present. 196 + pub fn record_value(&self) -> Option<&serde_json::Value> { 197 + self.record.as_ref() 198 + } 199 + 200 + /// Returns true if this is a delete event. 201 + pub fn is_delete(&self) -> bool { 202 + self.action == RecordAction::Delete 203 + } 204 + 205 + /// Returns the AT-URI for this record. 206 + /// 207 + /// Format: `at://{did}/{collection}/{rkey}` 208 + pub fn at_uri(&self) -> String { 209 + format!("at://{}/{}/{}", self.did, self.collection, self.rkey) 210 + } 211 + } 212 + 213 + /// The action performed on a record. 214 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)] 215 + #[serde(rename_all = "lowercase")] 216 + pub enum RecordAction { 217 + /// A new record was created. 218 + Create, 219 + /// An existing record was updated. 220 + Update, 221 + /// A record was deleted. 222 + Delete, 223 + } 224 + 225 + impl std::fmt::Display for RecordAction { 226 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 227 + match self { 228 + RecordAction::Create => write!(f, "create"), 229 + RecordAction::Update => write!(f, "update"), 230 + RecordAction::Delete => write!(f, "delete"), 231 + } 232 + } 233 + } 234 + 235 + /// An identity change event from TAP. 236 + /// 237 + /// Contains information about handle or account status changes. 238 + #[derive(Debug, Clone, Serialize, Deserialize)] 239 + pub struct IdentityEvent { 240 + /// Actor DID. 241 + pub did: Box<str>, 242 + 243 + /// Current handle for the account. 244 + pub handle: Box<str>, 245 + 246 + /// Whether the account is currently active. 247 + #[serde(default)] 248 + pub is_active: bool, 249 + 250 + /// Account status. 251 + #[serde(default)] 252 + pub status: IdentityStatus, 253 + } 254 + 255 + /// Account status in an identity event. 256 + #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)] 257 + #[serde(rename_all = "lowercase")] 258 + pub enum IdentityStatus { 259 + /// Account is active and in good standing. 260 + #[default] 261 + Active, 262 + /// Account has been deactivated by the user. 263 + Deactivated, 264 + /// Account has been suspended. 265 + Suspended, 266 + /// Account has been deleted. 267 + Deleted, 268 + /// Account has been taken down. 269 + Takendown, 270 + } 271 + 272 + impl std::fmt::Display for IdentityStatus { 273 + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 274 + match self { 275 + IdentityStatus::Active => write!(f, "active"), 276 + IdentityStatus::Deactivated => write!(f, "deactivated"), 277 + IdentityStatus::Suspended => write!(f, "suspended"), 278 + IdentityStatus::Deleted => write!(f, "deleted"), 279 + IdentityStatus::Takendown => write!(f, "takendown"), 280 + } 281 + } 282 + } 283 + 284 + #[cfg(test)] 285 + mod tests { 286 + use super::*; 287 + 288 + #[test] 289 + fn test_parse_record_event() { 290 + let json = r#"{ 291 + "id": 12345, 292 + "type": "record", 293 + "record": { 294 + "live": true, 295 + "rev": "3lyileto4q52k", 296 + "did": "did:plc:z72i7hdynmk6r22z27h6tvur", 297 + "collection": "app.bsky.feed.post", 298 + "rkey": "3lyiletddxt2c", 299 + "action": "create", 300 + "cid": "bafyreigroo6vhxt62ufcndhaxzas6btq4jmniuz4egszbwuqgiyisqwqoy", 301 + "record": {"$type": "app.bsky.feed.post", "text": "Hello world!", "createdAt": "2025-01-01T00:00:00Z"} 302 + } 303 + }"#; 304 + 305 + let event: TapEvent = serde_json::from_str(json).expect("Failed to parse"); 306 + 307 + match event { 308 + TapEvent::Record { id, record } => { 309 + assert_eq!(id, 12345); 310 + assert!(record.live); 311 + assert_eq!(record.rev.as_str(), "3lyileto4q52k"); 312 + assert_eq!(&*record.did, "did:plc:z72i7hdynmk6r22z27h6tvur"); 313 + assert_eq!(&*record.collection, "app.bsky.feed.post"); 314 + assert_eq!(record.rkey.as_str(), "3lyiletddxt2c"); 315 + assert_eq!(record.action, RecordAction::Create); 316 + assert!(record.cid.is_some()); 317 + assert!(record.record.is_some()); 318 + 319 + // Test lazy parsing 320 + #[derive(Deserialize)] 321 + struct Post { 322 + text: String, 323 + } 324 + let post: Post = record.parse_record().expect("Failed to parse record"); 325 + assert_eq!(post.text, "Hello world!"); 326 + } 327 + _ => panic!("Expected Record event"), 328 + } 329 + } 330 + 331 + #[test] 332 + fn test_parse_delete_event() { 333 + let json = r#"{ 334 + "id": 12346, 335 + "type": "record", 336 + "record": { 337 + "live": true, 338 + "rev": "3lyileto4q52k", 339 + "did": "did:plc:z72i7hdynmk6r22z27h6tvur", 340 + "collection": "app.bsky.feed.post", 341 + "rkey": "3lyiletddxt2c", 342 + "action": "delete" 343 + } 344 + }"#; 345 + 346 + let event: TapEvent = serde_json::from_str(json).expect("Failed to parse"); 347 + 348 + match event { 349 + TapEvent::Record { id, record } => { 350 + assert_eq!(id, 12346); 351 + assert_eq!(record.action, RecordAction::Delete); 352 + assert!(record.is_delete()); 353 + assert!(record.cid.is_none()); 354 + assert!(record.record.is_none()); 355 + } 356 + _ => panic!("Expected Record event"), 357 + } 358 + } 359 + 360 + #[test] 361 + fn test_parse_identity_event() { 362 + let json = r#"{ 363 + "id": 12347, 364 + "type": "identity", 365 + "identity": { 366 + "did": "did:plc:z72i7hdynmk6r22z27h6tvur", 367 + "handle": "user.bsky.social", 368 + "is_active": true, 369 + "status": "active" 370 + } 371 + }"#; 372 + 373 + let event: TapEvent = serde_json::from_str(json).expect("Failed to parse"); 374 + 375 + match event { 376 + TapEvent::Identity { id, identity } => { 377 + assert_eq!(id, 12347); 378 + assert_eq!(&*identity.did, "did:plc:z72i7hdynmk6r22z27h6tvur"); 379 + assert_eq!(&*identity.handle, "user.bsky.social"); 380 + assert!(identity.is_active); 381 + assert_eq!(identity.status, IdentityStatus::Active); 382 + } 383 + _ => panic!("Expected Identity event"), 384 + } 385 + } 386 + 387 + #[test] 388 + fn test_record_action_display() { 389 + assert_eq!(RecordAction::Create.to_string(), "create"); 390 + assert_eq!(RecordAction::Update.to_string(), "update"); 391 + assert_eq!(RecordAction::Delete.to_string(), "delete"); 392 + } 393 + 394 + #[test] 395 + fn test_identity_status_display() { 396 + assert_eq!(IdentityStatus::Active.to_string(), "active"); 397 + assert_eq!(IdentityStatus::Deactivated.to_string(), "deactivated"); 398 + assert_eq!(IdentityStatus::Suspended.to_string(), "suspended"); 399 + assert_eq!(IdentityStatus::Deleted.to_string(), "deleted"); 400 + assert_eq!(IdentityStatus::Takendown.to_string(), "takendown"); 401 + } 402 + 403 + #[test] 404 + fn test_at_uri() { 405 + let record = RecordEvent { 406 + live: true, 407 + rev: "3lyileto4q52k".into(), 408 + did: "did:plc:xyz".into(), 409 + collection: "app.bsky.feed.post".into(), 410 + rkey: "abc123".into(), 411 + action: RecordAction::Create, 412 + cid: None, 413 + record: None, 414 + }; 415 + 416 + assert_eq!(record.at_uri(), "at://did:plc:xyz/app.bsky.feed.post/abc123"); 417 + } 418 + 419 + #[test] 420 + fn test_event_id() { 421 + let record_event = TapEvent::Record { 422 + id: 100, 423 + record: RecordEvent { 424 + live: true, 425 + rev: "rev".into(), 426 + did: "did".into(), 427 + collection: "col".into(), 428 + rkey: "rkey".into(), 429 + action: RecordAction::Create, 430 + cid: None, 431 + record: None, 432 + }, 433 + }; 434 + assert_eq!(record_event.id(), 100); 435 + 436 + let identity_event = TapEvent::Identity { 437 + id: 200, 438 + identity: IdentityEvent { 439 + did: "did".into(), 440 + handle: "handle".into(), 441 + is_active: true, 442 + status: IdentityStatus::Active, 443 + }, 444 + }; 445 + assert_eq!(identity_event.id(), 200); 446 + } 447 + 448 + #[test] 449 + fn test_extract_event_id_simple() { 450 + let json = r#"{"type":"record","id":12345,"record":{"deeply":"nested"}}"#; 451 + assert_eq!(extract_event_id(json), Some(12345)); 452 + } 453 + 454 + #[test] 455 + fn test_extract_event_id_at_end() { 456 + let json = r#"{"type":"record","record":{"deeply":"nested"},"id":99999}"#; 457 + assert_eq!(extract_event_id(json), Some(99999)); 458 + } 459 + 460 + #[test] 461 + fn test_extract_event_id_missing() { 462 + let json = r#"{"type":"record","record":{"deeply":"nested"}}"#; 463 + assert_eq!(extract_event_id(json), None); 464 + } 465 + 466 + #[test] 467 + fn test_extract_event_id_invalid_json() { 468 + let json = r#"{"type":"record","id":123"#; // Truncated JSON 469 + assert_eq!(extract_event_id(json), None); 470 + } 471 + 472 + #[test] 473 + fn test_extract_event_id_deeply_nested() { 474 + // Create a deeply nested JSON that would exceed serde_json's default recursion limit 475 + let mut json = String::from(r#"{"id":42,"record":{"nested":"#); 476 + for _ in 0..200 { 477 + json.push_str("["); 478 + } 479 + json.push_str("1"); 480 + for _ in 0..200 { 481 + json.push_str("]"); 482 + } 483 + json.push_str("}}"); 484 + 485 + // extract_event_id should still work because it uses IgnoredAny with disabled recursion limit 486 + assert_eq!(extract_event_id(&json), Some(42)); 487 + } 488 + }
+119
crates/atproto-tap/src/lib.rs
··· 1 + //! TAP (Trusted Attestation Protocol) service consumer for AT Protocol. 2 + //! 3 + //! This crate provides a client for consuming events from a TAP service, 4 + //! which delivers filtered, verified AT Protocol repository events. 5 + //! 6 + //! # Overview 7 + //! 8 + //! TAP is a single-tenant service that subscribes to an AT Protocol Relay and 9 + //! outputs filtered, verified events. Key features include: 10 + //! 11 + //! - **Verified Events**: MST integrity checks and signature verification 12 + //! - **Automatic Backfill**: Historical events delivered with `live: false` 13 + //! - **Repository Filtering**: Track specific DIDs or collections 14 + //! - **Acknowledgment Protocol**: At-least-once delivery semantics 15 + //! 16 + //! # Quick Start 17 + //! 18 + //! ```ignore 19 + //! use atproto_tap::{connect_to, TapEvent}; 20 + //! use tokio_stream::StreamExt; 21 + //! 22 + //! #[tokio::main] 23 + //! async fn main() { 24 + //! let mut stream = connect_to("localhost:2480"); 25 + //! 26 + //! while let Some(result) = stream.next().await { 27 + //! match result { 28 + //! Ok(event) => match event.as_ref() { 29 + //! TapEvent::Record { record, .. } => { 30 + //! println!("{} {} {}", record.action, record.collection, record.did); 31 + //! } 32 + //! TapEvent::Identity { identity, .. } => { 33 + //! println!("Identity: {} = {}", identity.did, identity.handle); 34 + //! } 35 + //! }, 36 + //! Err(e) => eprintln!("Error: {}", e), 37 + //! } 38 + //! } 39 + //! } 40 + //! ``` 41 + //! 42 + //! # Using with `tokio::select!` 43 + //! 44 + //! The stream integrates naturally with Tokio's select macro: 45 + //! 46 + //! ```ignore 47 + //! use atproto_tap::{connect, TapConfig}; 48 + //! use tokio_stream::StreamExt; 49 + //! use tokio::signal; 50 + //! 51 + //! #[tokio::main] 52 + //! async fn main() { 53 + //! let config = TapConfig::builder() 54 + //! .hostname("localhost:2480") 55 + //! .admin_password("secret") 56 + //! .build(); 57 + //! 58 + //! let mut stream = connect(config); 59 + //! 60 + //! loop { 61 + //! tokio::select! { 62 + //! Some(result) = stream.next() => { 63 + //! // Process event 64 + //! } 65 + //! _ = signal::ctrl_c() => { 66 + //! break; 67 + //! } 68 + //! } 69 + //! } 70 + //! } 71 + //! ``` 72 + //! 73 + //! # Management API 74 + //! 75 + //! Use [`TapClient`] to manage tracked repositories: 76 + //! 77 + //! ```ignore 78 + //! use atproto_tap::TapClient; 79 + //! 80 + //! let client = TapClient::new("localhost:2480", Some("password".to_string())); 81 + //! 82 + //! // Add repositories to track 83 + //! client.add_repos(&["did:plc:xyz123"]).await?; 84 + //! 85 + //! // Check service health 86 + //! if client.health().await? { 87 + //! println!("TAP service is healthy"); 88 + //! } 89 + //! ``` 90 + //! 91 + //! # Memory Efficiency 92 + //! 93 + //! This crate is optimized for high-throughput event processing: 94 + //! 95 + //! - **Arc-wrapped events**: Events are shared via `Arc` for zero-cost sharing 96 + //! - **CompactString**: Small strings use inline storage (no heap allocation) 97 + //! - **Box<str>**: Immutable strings without capacity overhead 98 + //! - **RawValue**: Record payloads are lazily parsed on demand 99 + //! - **Pre-allocated buffers**: Ack messages avoid per-message allocations 100 + 101 + #![forbid(unsafe_code)] 102 + #![warn(missing_docs)] 103 + 104 + mod client; 105 + mod config; 106 + mod connection; 107 + mod errors; 108 + mod events; 109 + mod stream; 110 + 111 + // Re-export public types 112 + pub use atproto_identity::model::{Document, Service, VerificationMethod}; 113 + pub use client::{RepoInfo, RepoState, TapClient}; 114 + #[allow(deprecated)] 115 + pub use client::RepoStatus; 116 + pub use config::{TapConfig, TapConfigBuilder}; 117 + pub use errors::TapError; 118 + pub use events::{IdentityEvent, IdentityStatus, RecordAction, RecordEvent, TapEvent, extract_event_id}; 119 + pub use stream::{TapStream, connect, connect_to};
+330
crates/atproto-tap/src/stream.rs
··· 1 + //! TAP event stream implementation. 2 + //! 3 + //! This module provides [`TapStream`], an async stream that yields TAP events 4 + //! with automatic connection management and reconnection handling. 5 + //! 6 + //! # Design 7 + //! 8 + //! The stream encapsulates all connection logic, allowing consumers to simply 9 + //! iterate over events using standard stream combinators or `tokio::select!`. 10 + //! 11 + //! Reconnection is handled automatically with exponential backoff. Parse errors 12 + //! are yielded as `Err` items but don't affect connection state - only connection 13 + //! errors trigger reconnection attempts. 14 + 15 + use crate::config::TapConfig; 16 + use crate::connection::TapConnection; 17 + use crate::errors::TapError; 18 + use crate::events::{TapEvent, extract_event_id}; 19 + use futures::Stream; 20 + use std::pin::Pin; 21 + use std::sync::Arc; 22 + use std::task::{Context, Poll}; 23 + use std::time::Duration; 24 + use tokio::sync::mpsc; 25 + 26 + /// An async stream of TAP events with automatic reconnection. 27 + /// 28 + /// `TapStream` implements [`Stream`] and yields `Result<Arc<TapEvent>, TapError>`. 29 + /// Events are wrapped in `Arc` for efficient zero-cost sharing across consumers. 30 + /// 31 + /// # Connection Management 32 + /// 33 + /// The stream automatically: 34 + /// - Connects on first poll 35 + /// - Reconnects with exponential backoff on connection errors 36 + /// - Sends acknowledgments after parsing each message (if enabled) 37 + /// - Yields parse errors without affecting connection state 38 + /// 39 + /// # Example 40 + /// 41 + /// ```ignore 42 + /// use atproto_tap::{TapConfig, TapStream}; 43 + /// use tokio_stream::StreamExt; 44 + /// 45 + /// let config = TapConfig::builder() 46 + /// .hostname("localhost:2480") 47 + /// .build(); 48 + /// 49 + /// let mut stream = TapStream::new(config); 50 + /// 51 + /// while let Some(result) = stream.next().await { 52 + /// match result { 53 + /// Ok(event) => println!("Event: {:?}", event), 54 + /// Err(e) => eprintln!("Error: {}", e), 55 + /// } 56 + /// } 57 + /// ``` 58 + pub struct TapStream { 59 + /// Receiver for events from the background task. 60 + receiver: mpsc::Receiver<Result<Arc<TapEvent>, TapError>>, 61 + /// Handle to request stream closure. 62 + close_sender: Option<mpsc::Sender<()>>, 63 + /// Whether the stream has been closed. 64 + closed: bool, 65 + } 66 + 67 + impl TapStream { 68 + /// Create a new TAP stream with the given configuration. 69 + /// 70 + /// The stream will start connecting immediately in a background task. 71 + pub fn new(config: TapConfig) -> Self { 72 + // Channel for events - buffer a few to handle bursts 73 + let (event_tx, event_rx) = mpsc::channel(32); 74 + // Channel for close signal 75 + let (close_tx, close_rx) = mpsc::channel(1); 76 + 77 + // Spawn background task to manage connection 78 + tokio::spawn(connection_task(config, event_tx, close_rx)); 79 + 80 + Self { 81 + receiver: event_rx, 82 + close_sender: Some(close_tx), 83 + closed: false, 84 + } 85 + } 86 + 87 + /// Close the stream and release resources. 88 + /// 89 + /// After calling this, the stream will yield `None` on the next poll. 90 + pub async fn close(&mut self) { 91 + if let Some(sender) = self.close_sender.take() { 92 + // Signal the background task to close 93 + let _ = sender.send(()).await; 94 + } 95 + self.closed = true; 96 + } 97 + 98 + /// Returns true if the stream is closed. 99 + pub fn is_closed(&self) -> bool { 100 + self.closed 101 + } 102 + } 103 + 104 + impl Stream for TapStream { 105 + type Item = Result<Arc<TapEvent>, TapError>; 106 + 107 + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 108 + if self.closed { 109 + return Poll::Ready(None); 110 + } 111 + 112 + self.receiver.poll_recv(cx) 113 + } 114 + } 115 + 116 + impl Drop for TapStream { 117 + fn drop(&mut self) { 118 + // Drop the close_sender to signal the background task 119 + self.close_sender.take(); 120 + tracing::debug!("TapStream dropped"); 121 + } 122 + } 123 + 124 + /// Background task that manages the WebSocket connection. 125 + async fn connection_task( 126 + config: TapConfig, 127 + event_tx: mpsc::Sender<Result<Arc<TapEvent>, TapError>>, 128 + mut close_rx: mpsc::Receiver<()>, 129 + ) { 130 + let mut current_reconnect_delay = config.initial_reconnect_delay; 131 + let mut attempt: u32 = 0; 132 + 133 + loop { 134 + // Check for close signal 135 + if close_rx.try_recv().is_ok() { 136 + tracing::debug!("Connection task received close signal"); 137 + break; 138 + } 139 + 140 + // Try to connect 141 + tracing::debug!(attempt, hostname = %config.hostname, "Connecting to TAP service"); 142 + let conn_result = TapConnection::connect(&config).await; 143 + 144 + match conn_result { 145 + Ok(mut conn) => { 146 + tracing::info!(hostname = %config.hostname, "TAP stream connected"); 147 + // Reset reconnection state on successful connect 148 + current_reconnect_delay = config.initial_reconnect_delay; 149 + attempt = 0; 150 + 151 + // Event loop for this connection 152 + loop { 153 + tokio::select! { 154 + biased; 155 + 156 + _ = close_rx.recv() => { 157 + tracing::debug!("Connection task received close signal during receive"); 158 + let _ = conn.close().await; 159 + return; 160 + } 161 + 162 + recv_result = conn.recv() => { 163 + match recv_result { 164 + Ok(Some(msg)) => { 165 + // Parse the message 166 + match serde_json::from_str::<TapEvent>(&msg) { 167 + Ok(event) => { 168 + let event_id = event.id(); 169 + 170 + // Send ack if enabled (before sending event to channel) 171 + if config.send_acks 172 + && let Err(err) = conn.send_ack(event_id).await 173 + { 174 + tracing::warn!(error = %err, "Failed to send ack"); 175 + // Don't break connection for ack errors 176 + } 177 + 178 + // Send event to channel 179 + let event = Arc::new(event); 180 + if event_tx.send(Ok(event)).await.is_err() { 181 + // Receiver dropped, exit task 182 + tracing::debug!("Event receiver dropped, closing connection"); 183 + let _ = conn.close().await; 184 + return; 185 + } 186 + } 187 + Err(err) => { 188 + // Parse errors don't affect connection 189 + tracing::warn!(error = %err, "Failed to parse TAP message"); 190 + 191 + // Try to extract just the ID using fallback parser 192 + // so we can still ack the message even if full parsing fails 193 + if config.send_acks { 194 + if let Some(event_id) = extract_event_id(&msg) { 195 + tracing::debug!(event_id, "Extracted event ID via fallback parser"); 196 + if let Err(ack_err) = conn.send_ack(event_id).await { 197 + tracing::warn!(error = %ack_err, "Failed to send ack for unparseable message"); 198 + } 199 + } else { 200 + tracing::warn!("Could not extract event ID from unparseable message"); 201 + } 202 + } 203 + 204 + if event_tx.send(Err(TapError::ParseError(err.to_string()))).await.is_err() { 205 + tracing::debug!("Event receiver dropped, closing connection"); 206 + let _ = conn.close().await; 207 + return; 208 + } 209 + } 210 + } 211 + } 212 + Ok(None) => { 213 + // Connection closed by server 214 + tracing::debug!("TAP connection closed by server"); 215 + break; 216 + } 217 + Err(err) => { 218 + // Connection error 219 + tracing::warn!(error = %err, "TAP connection error"); 220 + break; 221 + } 222 + } 223 + } 224 + } 225 + } 226 + } 227 + Err(err) => { 228 + tracing::warn!(error = %err, attempt, "Failed to connect to TAP service"); 229 + } 230 + } 231 + 232 + // Increment attempt counter 233 + attempt += 1; 234 + 235 + // Check if we've exceeded max attempts 236 + if let Some(max) = config.max_reconnect_attempts 237 + && attempt >= max 238 + { 239 + tracing::error!(attempts = attempt, "Max reconnection attempts exceeded"); 240 + let _ = event_tx 241 + .send(Err(TapError::MaxReconnectAttemptsExceeded(attempt))) 242 + .await; 243 + break; 244 + } 245 + 246 + // Wait before reconnecting with exponential backoff 247 + tracing::debug!( 248 + delay_ms = current_reconnect_delay.as_millis(), 249 + attempt, 250 + "Waiting before reconnection" 251 + ); 252 + 253 + tokio::select! { 254 + _ = close_rx.recv() => { 255 + tracing::debug!("Connection task received close signal during backoff"); 256 + return; 257 + } 258 + _ = tokio::time::sleep(current_reconnect_delay) => { 259 + // Update delay for next attempt 260 + current_reconnect_delay = Duration::from_secs_f64( 261 + (current_reconnect_delay.as_secs_f64() * config.reconnect_backoff_multiplier) 262 + .min(config.max_reconnect_delay.as_secs_f64()), 263 + ); 264 + } 265 + } 266 + } 267 + 268 + tracing::debug!("Connection task exiting"); 269 + } 270 + 271 + /// Create a new TAP stream with the given configuration. 272 + pub fn connect(config: TapConfig) -> TapStream { 273 + TapStream::new(config) 274 + } 275 + 276 + /// Create a new TAP stream connected to the given hostname. 277 + /// 278 + /// Uses default configuration values. 279 + pub fn connect_to(hostname: &str) -> TapStream { 280 + TapStream::new(TapConfig::new(hostname)) 281 + } 282 + 283 + #[cfg(test)] 284 + mod tests { 285 + use super::*; 286 + 287 + #[test] 288 + fn test_stream_initial_state() { 289 + // Note: This test doesn't actually poll the stream, just checks initial state 290 + // Creating a TapStream requires a tokio runtime for the spawn 291 + } 292 + 293 + #[tokio::test] 294 + async fn test_stream_close() { 295 + let mut stream = TapStream::new(TapConfig::new("localhost:9999")); 296 + assert!(!stream.is_closed()); 297 + stream.close().await; 298 + assert!(stream.is_closed()); 299 + } 300 + 301 + #[test] 302 + fn test_connect_functions() { 303 + // These just create configs, actual connection happens in background task 304 + // We can't test without a runtime, so just verify the types compile 305 + let _ = TapConfig::new("localhost:2480"); 306 + } 307 + 308 + #[test] 309 + fn test_reconnect_delay_calculation() { 310 + // Test the delay calculation logic 311 + let initial = Duration::from_secs(1); 312 + let max = Duration::from_secs(10); 313 + let multiplier = 2.0; 314 + 315 + let mut delay = initial; 316 + assert_eq!(delay, Duration::from_secs(1)); 317 + 318 + delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64())); 319 + assert_eq!(delay, Duration::from_secs(2)); 320 + 321 + delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64())); 322 + assert_eq!(delay, Duration::from_secs(4)); 323 + 324 + delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64())); 325 + assert_eq!(delay, Duration::from_secs(8)); 326 + 327 + delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64())); 328 + assert_eq!(delay, Duration::from_secs(10)); // Capped at max 329 + } 330 + }
+13 -13
crates/atproto-xrpcs/README.md
··· 23 23 ### Basic XRPC Service 24 24 25 25 ```rust 26 - use atproto_xrpcs::authorization::ResolvingAuthorization; 26 + use atproto_xrpcs::authorization::Authorization; 27 27 use axum::{Json, Router, extract::Query, routing::get}; 28 28 use serde::Deserialize; 29 29 use serde_json::json; ··· 35 35 36 36 async fn handle_hello( 37 37 params: Query<HelloParams>, 38 - authorization: Option<ResolvingAuthorization>, 38 + authorization: Option<Authorization>, 39 39 ) -> Json<serde_json::Value> { 40 40 let name = params.name.as_deref().unwrap_or("World"); 41 - 41 + 42 42 let message = if authorization.is_some() { 43 43 format!("Hello, authenticated {}!", name) 44 44 } else { 45 45 format!("Hello, {}!", name) 46 46 }; 47 - 47 + 48 48 Json(json!({ "message": message })) 49 49 } 50 50 ··· 56 56 ### JWT Authorization 57 57 58 58 ```rust 59 - use atproto_xrpcs::authorization::ResolvingAuthorization; 59 + use atproto_xrpcs::authorization::Authorization; 60 60 61 61 async fn handle_secure_endpoint( 62 - authorization: ResolvingAuthorization, // Required authorization 62 + authorization: Authorization, // Required authorization 63 63 ) -> Json<serde_json::Value> { 64 - // The ResolvingAuthorization extractor automatically: 64 + // The Authorization extractor automatically: 65 65 // 1. Validates the JWT token 66 - // 2. Resolves the caller's DID document 66 + // 2. Resolves the caller's DID document 67 67 // 3. Verifies the signature against the DID document 68 68 // 4. Provides access to caller identity information 69 - 69 + 70 70 let caller_did = authorization.subject(); 71 71 Json(json!({"caller": caller_did, "status": "authenticated"})) 72 72 } ··· 79 79 use axum::{response::IntoResponse, http::StatusCode}; 80 80 81 81 async fn protected_handler( 82 - authorization: Result<ResolvingAuthorization, AuthorizationError>, 82 + authorization: Result<Authorization, AuthorizationError>, 83 83 ) -> impl IntoResponse { 84 84 match authorization { 85 85 Ok(auth) => (StatusCode::OK, "Access granted").into_response(), 86 - Err(AuthorizationError::InvalidJWTToken { .. }) => { 86 + Err(AuthorizationError::InvalidJWTFormat) => { 87 87 (StatusCode::UNAUTHORIZED, "Invalid token").into_response() 88 88 } 89 - Err(AuthorizationError::DIDDocumentResolutionFailed { .. }) => { 89 + Err(AuthorizationError::SubjectResolutionFailed { .. }) => { 90 90 (StatusCode::FORBIDDEN, "Identity verification failed").into_response() 91 91 } 92 92 Err(_) => { ··· 98 98 99 99 ## Authorization Flow 100 100 101 - The `ResolvingAuthorization` extractor implements: 101 + The `Authorization` extractor implements: 102 102 103 103 1. JWT extraction from HTTP Authorization headers 104 104 2. Token validation (signature and claims structure)
+42 -108
crates/atproto-xrpcs/src/authorization.rs
··· 1 1 //! JWT authorization extractors for XRPC services. 2 2 //! 3 - //! Axum extractors for JWT validation against DID documents with 4 - //! cached and resolving authorization modes. 3 + //! Axum extractors for JWT validation against DID documents resolved 4 + //! via an identity resolver. 5 5 6 6 use anyhow::Result; 7 7 use atproto_identity::key::identify_key; 8 - use atproto_identity::resolve::IdentityResolver; 9 - use atproto_identity::traits::DidDocumentStorage; 8 + use atproto_identity::traits::IdentityResolver; 10 9 use atproto_oauth::jwt::{Claims, Header}; 11 10 use axum::extract::{FromRef, OptionalFromRequestParts}; 12 11 use axum::http::request::Parts; ··· 17 16 18 17 use crate::errors::AuthorizationError; 19 18 20 - /// JWT authorization extractor that validates tokens against cached DID documents. 19 + /// JWT authorization extractor that validates tokens against DID documents. 21 20 /// 22 21 /// Contains JWT header, validated claims, original token, and validation status. 23 - /// Only validates against DID documents already present in storage. 22 + /// Resolves DID documents via the configured identity resolver. 23 + #[derive(Clone)] 24 24 pub struct Authorization(pub Header, pub Claims, pub String, pub bool); 25 25 26 - /// JWT authorization extractor with automatic DID document resolution. 27 - /// 28 - /// Contains JWT header, validated claims, original token, and validation status. 29 - /// Attempts to resolve missing DID documents from authoritative sources when needed. 30 - pub struct ResolvingAuthorization(pub Header, pub Claims, pub String, pub bool); 31 - 32 - impl<S> OptionalFromRequestParts<S> for Authorization 33 - where 34 - S: Send + Sync, 35 - Arc<dyn DidDocumentStorage>: FromRef<S>, 36 - { 37 - type Rejection = Infallible; 38 - 39 - async fn from_request_parts( 40 - parts: &mut Parts, 41 - state: &S, 42 - ) -> Result<Option<Self>, Self::Rejection> { 43 - let auth_header = parts 44 - .headers 45 - .get("authorization") 46 - .and_then(|value| value.to_str().ok()) 47 - .and_then(|s| s.strip_prefix("Bearer ")); 48 - 49 - let token = match auth_header { 50 - Some(token) => token.to_string(), 51 - None => { 52 - return Ok(None); 53 - } 54 - }; 55 - 56 - let did_document_storage = Arc::<dyn DidDocumentStorage>::from_ref(state); 57 - 58 - match validate_jwt(&token, did_document_storage, None).await { 59 - Ok((header, claims)) => Ok(Some(Authorization(header, claims, token, true))), 60 - Err(_) => { 61 - // Return unvalidated authorization so the handler can decide what to do 62 - let header = Header::default(); 63 - let claims = Claims::default(); 64 - Ok(Some(Authorization(header, claims, token, false))) 65 - } 26 + impl Authorization { 27 + /// identity returns the optional issuer claim of the authorization structure. 28 + pub fn identity(&self) -> Option<&str> { 29 + if self.3 { 30 + return self.1.jose.issuer.as_deref(); 66 31 } 32 + None 67 33 } 68 34 } 69 35 70 - impl<S> OptionalFromRequestParts<S> for ResolvingAuthorization 36 + impl<S> OptionalFromRequestParts<S> for Authorization 71 37 where 72 38 S: Send + Sync, 73 - Arc<dyn DidDocumentStorage>: FromRef<S>, 74 39 Arc<dyn IdentityResolver>: FromRef<S>, 75 40 { 76 41 type Rejection = Infallible; ··· 92 57 } 93 58 }; 94 59 95 - let did_document_storage = Arc::<dyn DidDocumentStorage>::from_ref(state); 96 60 let identity_resolver = Arc::<dyn IdentityResolver>::from_ref(state); 97 61 98 - match validate_jwt(&token, did_document_storage, Some(identity_resolver)).await { 99 - Ok((header, claims)) => Ok(Some(ResolvingAuthorization(header, claims, token, true))), 62 + match validate_jwt(&token, identity_resolver).await { 63 + Ok((header, claims)) => Ok(Some(Authorization(header, claims, token, true))), 100 64 Err(_) => { 101 65 // Return unvalidated authorization so the handler can decide what to do 102 66 let header = Header::default(); 103 67 let claims = Claims::default(); 104 - Ok(Some(ResolvingAuthorization(header, claims, token, false))) 68 + Ok(Some(Authorization(header, claims, token, false))) 105 69 } 106 70 } 107 71 } ··· 109 73 110 74 async fn validate_jwt( 111 75 token: &str, 112 - storage: Arc<dyn DidDocumentStorage + Send + Sync>, 113 - identity_resolver: Option<Arc<dyn IdentityResolver>>, 76 + identity_resolver: Arc<dyn IdentityResolver>, 114 77 ) -> Result<(Header, Claims)> { 115 78 // Split and decode JWT 116 79 let parts: Vec<&str> = token.split('.').collect(); ··· 134 97 .as_ref() 135 98 .ok_or_else(|| AuthorizationError::NoIssuerInClaims)?; 136 99 137 - // Try to look up DID document directly first 138 - let mut did_document = storage.get_document_by_did(issuer).await?; 139 - 140 - // If not found, try to resolve the subject 141 - if did_document.is_none() 142 - && let Some(identity_resolver) = identity_resolver 143 - { 144 - did_document = match identity_resolver.resolve(issuer).await { 145 - Ok(value) => { 146 - storage 147 - .store_document(value.clone()) 148 - .await 149 - .map_err(|err| AuthorizationError::DocumentStorageFailed { error: err })?; 150 - 151 - Some(value) 152 - } 153 - Err(err) => { 154 - return Err(AuthorizationError::SubjectResolutionFailed { 155 - issuer: issuer.to_string(), 156 - error: err, 157 - } 158 - .into()); 159 - } 160 - }; 161 - } 162 - 163 - let did_document = did_document.ok_or_else(|| AuthorizationError::DIDDocumentNotFound { 164 - issuer: issuer.to_string(), 100 + // Resolve the DID document via identity resolver 101 + let did_document = identity_resolver.resolve(issuer).await.map_err(|err| { 102 + AuthorizationError::SubjectResolutionFailed { 103 + issuer: issuer.to_string(), 104 + error: err, 105 + } 165 106 })?; 166 107 167 108 // Extract keys from DID document ··· 206 147 mod tests { 207 148 use super::*; 208 149 use atproto_identity::model::{Document, VerificationMethod}; 209 - use atproto_identity::traits::DidDocumentStorage; 210 150 use axum::extract::FromRef; 211 151 use axum::http::{Method, Request}; 212 152 use std::collections::HashMap; 213 153 214 154 #[derive(Clone)] 215 - struct MockStorage { 155 + struct MockResolver { 216 156 document: Document, 217 157 } 218 158 219 159 #[async_trait::async_trait] 220 - impl DidDocumentStorage for MockStorage { 221 - async fn get_document_by_did(&self, did: &str) -> Result<Option<Document>> { 222 - if did == self.document.id { 223 - Ok(Some(self.document.clone())) 160 + impl IdentityResolver for MockResolver { 161 + async fn resolve(&self, subject: &str) -> Result<Document> { 162 + if subject == self.document.id { 163 + Ok(self.document.clone()) 224 164 } else { 225 - Ok(None) 165 + Err(anyhow::anyhow!( 166 + "error-atproto-xrpcs-authorization-1 DID not found: {}", 167 + subject 168 + )) 226 169 } 227 170 } 228 - 229 - async fn store_document(&self, _document: Document) -> Result<()> { 230 - Ok(()) 231 - } 232 - 233 - async fn delete_document_by_did(&self, _did: &str) -> Result<()> { 234 - Ok(()) 235 - } 236 171 } 237 172 238 173 #[derive(Clone)] 239 174 struct TestState { 240 - storage: Arc<dyn DidDocumentStorage + Send + Sync>, 175 + resolver: Arc<dyn IdentityResolver>, 241 176 } 242 177 243 - impl FromRef<TestState> for Arc<dyn DidDocumentStorage> { 178 + impl FromRef<TestState> for Arc<dyn IdentityResolver> { 244 179 fn from_ref(state: &TestState) -> Self { 245 - state.storage.clone() 180 + state.resolver.clone() 246 181 } 247 182 } 248 183 ··· 266 201 extra: HashMap::new(), 267 202 }; 268 203 269 - // Create mock storage 270 - let storage = 271 - Arc::new(MockStorage { document }) as Arc<dyn DidDocumentStorage + Send + Sync>; 272 - let state = TestState { storage }; 204 + // Create mock resolver 205 + let resolver = Arc::new(MockResolver { document }) as Arc<dyn IdentityResolver>; 206 + let state = TestState { resolver }; 273 207 274 208 // Create request with Authorization header 275 209 let request = Request::builder() ··· 307 241 308 242 #[tokio::test] 309 243 async fn test_authorization_no_header() { 310 - // Create mock storage 311 - let storage = Arc::new(MockStorage { 244 + // Create mock resolver 245 + let resolver = Arc::new(MockResolver { 312 246 document: Document { 313 247 context: vec![], 314 248 id: "did:plc:test".to_string(), ··· 317 251 verification_method: vec![], 318 252 extra: HashMap::new(), 319 253 }, 320 - }) as Arc<dyn DidDocumentStorage + Send + Sync>; 321 - let state = TestState { storage }; 254 + }) as Arc<dyn IdentityResolver>; 255 + let state = TestState { resolver }; 322 256 323 257 // Create request without Authorization header 324 258 let request = Request::builder()
+5 -49
crates/atproto-xrpcs/src/errors.rs
··· 42 42 #[error("error-atproto-xrpcs-authorization-4 No issuer found in JWT claims")] 43 43 NoIssuerInClaims, 44 44 45 - /// Occurs when DID document is not found for the issuer 46 - #[error("error-atproto-xrpcs-authorization-5 DID document not found for issuer: {issuer}")] 47 - DIDDocumentNotFound { 48 - /// The issuer DID that was not found 49 - issuer: String, 50 - }, 51 - 52 45 /// Occurs when no verification keys are found in DID document 53 - #[error("error-atproto-xrpcs-authorization-6 No verification keys found in DID document")] 46 + #[error("error-atproto-xrpcs-authorization-5 No verification keys found in DID document")] 54 47 NoVerificationKeys, 55 48 56 49 /// Occurs when JWT header cannot be base64 decoded 57 - #[error("error-atproto-xrpcs-authorization-7 Failed to decode JWT header: {error}")] 50 + #[error("error-atproto-xrpcs-authorization-6 Failed to decode JWT header: {error}")] 58 51 HeaderDecodeError { 59 52 /// The underlying base64 decode error 60 53 error: base64::DecodeError, 61 54 }, 62 55 63 56 /// Occurs when JWT header cannot be parsed as JSON 64 - #[error("error-atproto-xrpcs-authorization-8 Failed to parse JWT header: {error}")] 57 + #[error("error-atproto-xrpcs-authorization-7 Failed to parse JWT header: {error}")] 65 58 HeaderParseError { 66 59 /// The underlying JSON parse error 67 60 error: serde_json::Error, 68 61 }, 69 62 70 63 /// Occurs when JWT validation fails with all available keys 71 - #[error("error-atproto-xrpcs-authorization-9 JWT validation failed with all available keys")] 64 + #[error("error-atproto-xrpcs-authorization-8 JWT validation failed with all available keys")] 72 65 ValidationFailedAllKeys, 73 66 74 67 /// Occurs when subject resolution fails during DID document lookup 75 - #[error("error-atproto-xrpcs-authorization-10 Subject resolution failed: {issuer} {error}")] 68 + #[error("error-atproto-xrpcs-authorization-9 Subject resolution failed: {issuer} {error}")] 76 69 SubjectResolutionFailed { 77 70 /// The issuer that failed to resolve 78 71 issuer: String, 79 72 /// The underlying resolution error 80 - error: anyhow::Error, 81 - }, 82 - 83 - /// Occurs when DID document lookup fails after successful resolution 84 - #[error( 85 - "error-atproto-xrpcs-authorization-11 DID document not found for resolved issuer: {resolved_did}" 86 - )] 87 - ResolvedDIDDocumentNotFound { 88 - /// The resolved DID that was not found in storage 89 - resolved_did: String, 90 - }, 91 - 92 - /// Occurs when PLC directory query fails 93 - #[error("error-atproto-xrpcs-authorization-12 PLC directory query failed: {error}")] 94 - PLCQueryFailed { 95 - /// The underlying PLC query error 96 - error: anyhow::Error, 97 - }, 98 - 99 - /// Occurs when web DID query fails 100 - #[error("error-atproto-xrpcs-authorization-13 Web DID query failed: {error}")] 101 - WebDIDQueryFailed { 102 - /// The underlying web DID query error 103 - error: anyhow::Error, 104 - }, 105 - 106 - /// Occurs when DID document storage operation fails 107 - #[error("error-atproto-xrpcs-authorization-14 DID document storage failed: {error}")] 108 - DocumentStorageFailed { 109 - /// The underlying storage error 110 - error: anyhow::Error, 111 - }, 112 - 113 - /// Occurs when input parsing fails for resolved DID 114 - #[error("error-atproto-xrpcs-authorization-15 Input parsing failed for resolved DID: {error}")] 115 - InputParsingFailed { 116 - /// The underlying parsing error 117 73 error: anyhow::Error, 118 74 }, 119 75 }
+3 -13
crates/atproto-xrpcs-helloworld/src/main.rs
··· 7 7 config::{CertificateBundles, DnsNameservers, default_env, optional_env, require_env, version}, 8 8 key::{KeyData, KeyResolver, identify_key, to_public}, 9 9 resolve::{HickoryDnsResolver, IdentityResolver, InnerIdentityResolver}, 10 - storage_lru::LruDidDocumentStorage, 11 - traits::DidDocumentStorage, 12 10 }; 13 - use atproto_xrpcs::authorization::ResolvingAuthorization; 11 + use atproto_xrpcs::authorization::Authorization; 14 12 use axum::{ 15 13 Json, Router, 16 14 extract::{FromRef, Query, State}, ··· 21 19 use http::{HeaderMap, StatusCode}; 22 20 use serde::Deserialize; 23 21 use serde_json::json; 24 - use std::{collections::HashMap, num::NonZeroUsize, ops::Deref, sync::Arc}; 22 + use std::{collections::HashMap, ops::Deref, sync::Arc}; 25 23 26 24 #[derive(Clone)] 27 25 pub struct SimpleKeyResolver { ··· 61 59 62 60 pub struct InnerWebContext { 63 61 pub http_client: reqwest::Client, 64 - pub document_storage: Arc<dyn DidDocumentStorage>, 65 62 pub key_resolver: Arc<dyn KeyResolver>, 66 63 pub service_document: ServiceDocument, 67 64 pub service_did: ServiceDID, ··· 97 94 } 98 95 } 99 96 100 - impl FromRef<WebContext> for Arc<dyn DidDocumentStorage> { 101 - fn from_ref(context: &WebContext) -> Self { 102 - context.0.document_storage.clone() 103 - } 104 - } 105 - 106 97 impl FromRef<WebContext> for Arc<dyn KeyResolver> { 107 98 fn from_ref(context: &WebContext) -> Self { 108 99 context.0.key_resolver.clone() ··· 216 207 217 208 let web_context = WebContext(Arc::new(InnerWebContext { 218 209 http_client: http_client.clone(), 219 - document_storage: Arc::new(LruDidDocumentStorage::new(NonZeroUsize::new(255).unwrap())), 220 210 key_resolver: Arc::new(SimpleKeyResolver { 221 211 keys: signing_key_storage, 222 212 }), ··· 284 274 async fn handle_xrpc_hello_world( 285 275 parameters: Query<HelloParameters>, 286 276 headers: HeaderMap, 287 - authorization: Option<ResolvingAuthorization>, 277 + authorization: Option<Authorization>, 288 278 ) -> Json<serde_json::Value> { 289 279 println!("headers {headers:?}"); 290 280 let subject = parameters.subject.as_deref().unwrap_or("World");