+91
Cargo.lock
+91
Cargo.lock
···
141
141
]
142
142
143
143
[[package]]
144
+
name = "atproto-extras"
145
+
version = "0.13.0"
146
+
dependencies = [
147
+
"anyhow",
148
+
"async-trait",
149
+
"atproto-identity",
150
+
"atproto-record",
151
+
"clap",
152
+
"regex",
153
+
"reqwest",
154
+
"serde_json",
155
+
"tokio",
156
+
]
157
+
158
+
[[package]]
144
159
name = "atproto-identity"
145
160
version = "0.13.0"
146
161
dependencies = [
···
307
322
]
308
323
309
324
[[package]]
325
+
name = "atproto-tap"
326
+
version = "0.13.0"
327
+
dependencies = [
328
+
"atproto-client",
329
+
"atproto-identity",
330
+
"base64",
331
+
"clap",
332
+
"compact_str",
333
+
"futures",
334
+
"http",
335
+
"itoa",
336
+
"reqwest",
337
+
"serde",
338
+
"serde_json",
339
+
"thiserror 2.0.17",
340
+
"tokio",
341
+
"tokio-stream",
342
+
"tokio-websockets",
343
+
"tracing",
344
+
"tracing-subscriber",
345
+
]
346
+
347
+
[[package]]
310
348
name = "atproto-xrpcs"
311
349
version = "0.13.0"
312
350
dependencies = [
···
491
529
checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
492
530
493
531
[[package]]
532
+
name = "castaway"
533
+
version = "0.2.4"
534
+
source = "registry+https://github.com/rust-lang/crates.io-index"
535
+
checksum = "dec551ab6e7578819132c713a93c022a05d60159dc86e7a7050223577484c55a"
536
+
dependencies = [
537
+
"rustversion",
538
+
]
539
+
540
+
[[package]]
494
541
name = "cbor4ii"
495
542
version = "0.2.14"
496
543
source = "registry+https://github.com/rust-lang/crates.io-index"
···
592
639
version = "1.0.4"
593
640
source = "registry+https://github.com/rust-lang/crates.io-index"
594
641
checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75"
642
+
643
+
[[package]]
644
+
name = "compact_str"
645
+
version = "0.8.1"
646
+
source = "registry+https://github.com/rust-lang/crates.io-index"
647
+
checksum = "3b79c4069c6cad78e2e0cdfcbd26275770669fb39fd308a752dc110e83b9af32"
648
+
dependencies = [
649
+
"castaway",
650
+
"cfg-if",
651
+
"itoa",
652
+
"rustversion",
653
+
"ryu",
654
+
"serde",
655
+
"static_assertions",
656
+
]
595
657
596
658
[[package]]
597
659
name = "const-oid"
···
1879
1941
]
1880
1942
1881
1943
[[package]]
1944
+
name = "regex"
1945
+
version = "1.12.2"
1946
+
source = "registry+https://github.com/rust-lang/crates.io-index"
1947
+
checksum = "843bc0191f75f3e22651ae5f1e72939ab2f72a4bc30fa80a066bd66edefc24d4"
1948
+
dependencies = [
1949
+
"aho-corasick",
1950
+
"memchr",
1951
+
"regex-automata",
1952
+
"regex-syntax",
1953
+
]
1954
+
1955
+
[[package]]
1882
1956
name = "regex-automata"
1883
1957
version = "0.4.13"
1884
1958
source = "registry+https://github.com/rust-lang/crates.io-index"
···
2358
2432
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
2359
2433
2360
2434
[[package]]
2435
+
name = "static_assertions"
2436
+
version = "1.1.0"
2437
+
source = "registry+https://github.com/rust-lang/crates.io-index"
2438
+
checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f"
2439
+
2440
+
[[package]]
2361
2441
name = "strsim"
2362
2442
version = "0.11.1"
2363
2443
source = "registry+https://github.com/rust-lang/crates.io-index"
···
2547
2627
checksum = "1729aa945f29d91ba541258c8df89027d5792d85a8841fb65e8bf0f4ede4ef61"
2548
2628
dependencies = [
2549
2629
"rustls",
2630
+
"tokio",
2631
+
]
2632
+
2633
+
[[package]]
2634
+
name = "tokio-stream"
2635
+
version = "0.1.17"
2636
+
source = "registry+https://github.com/rust-lang/crates.io-index"
2637
+
checksum = "eca58d7bba4a75707817a2c44174253f9236b2d5fbd055602e9d5c07c139a047"
2638
+
dependencies = [
2639
+
"futures-core",
2640
+
"pin-project-lite",
2550
2641
"tokio",
2551
2642
]
2552
2643
+10
-4
Cargo.toml
+10
-4
Cargo.toml
···
1
1
[workspace]
2
2
members = [
3
3
"crates/atproto-client",
4
+
"crates/atproto-extras",
4
5
"crates/atproto-identity",
5
6
"crates/atproto-jetstream",
6
7
"crates/atproto-oauth-aip",
7
8
"crates/atproto-oauth-axum",
8
9
"crates/atproto-oauth",
9
10
"crates/atproto-record",
11
+
"crates/atproto-tap",
10
12
"crates/atproto-xrpcs-helloworld",
11
13
"crates/atproto-xrpcs",
12
14
"crates/atproto-lexicon",
···
24
26
categories = ["command-line-utilities", "web-programming"]
25
27
26
28
[workspace.dependencies]
29
+
atproto-attestation = { version = "0.13.0", path = "crates/atproto-attestation" }
27
30
atproto-client = { version = "0.13.0", path = "crates/atproto-client" }
31
+
atproto-extras = { version = "0.13.0", path = "crates/atproto-extras" }
28
32
atproto-identity = { version = "0.13.0", path = "crates/atproto-identity" }
33
+
atproto-jetstream = { version = "0.13.0", path = "crates/atproto-jetstream" }
29
34
atproto-oauth = { version = "0.13.0", path = "crates/atproto-oauth" }
30
-
atproto-oauth-axum = { version = "0.13.0", path = "crates/atproto-oauth-axum" }
31
35
atproto-oauth-aip = { version = "0.13.0", path = "crates/atproto-oauth-aip" }
36
+
atproto-oauth-axum = { version = "0.13.0", path = "crates/atproto-oauth-axum" }
32
37
atproto-record = { version = "0.13.0", path = "crates/atproto-record" }
38
+
atproto-tap = { version = "0.13.0", path = "crates/atproto-tap" }
33
39
atproto-xrpcs = { version = "0.13.0", path = "crates/atproto-xrpcs" }
34
-
atproto-jetstream = { version = "0.13.0", path = "crates/atproto-jetstream" }
35
-
atproto-attestation = { version = "0.13.0", path = "crates/atproto-attestation" }
36
40
41
+
ammonia = "4.0"
37
42
anyhow = "1.0"
38
43
async-trait = "0.1"
39
44
base64 = "0.22"
···
50
55
p256 = "0.13"
51
56
p384 = "0.13"
52
57
rand = "0.8"
58
+
regex = "1.11"
53
59
reqwest = { version = "0.12", default-features = false, features = ["charset", "http2", "system-proxy", "json", "rustls-tls"] }
54
60
reqwest-chain = "1.0"
55
61
reqwest-middleware = { version = "0.4", features = ["json", "multipart"]}
···
57
63
secrecy = { version = "0.10", features = ["serde"] }
58
64
serde = { version = "1.0", features = ["derive"] }
59
65
serde_ipld_dagcbor = "0.6"
60
-
serde_json = "1.0"
66
+
serde_json = { version = "1.0", features = ["unbounded_depth"] }
61
67
sha2 = "0.10"
62
68
thiserror = "2.0"
63
69
tokio = { version = "1.41", features = ["macros", "rt", "rt-multi-thread"] }
+4
-4
README.md
+4
-4
README.md
···
131
131
### XRPC Service
132
132
133
133
```rust
134
-
use atproto_xrpcs::authorization::ResolvingAuthorization;
134
+
use atproto_xrpcs::authorization::Authorization;
135
135
use axum::{Json, Router, extract::Query, routing::get};
136
136
use serde::Deserialize;
137
137
use serde_json::json;
···
143
143
144
144
async fn handle_hello(
145
145
params: Query<HelloParams>,
146
-
authorization: Option<ResolvingAuthorization>,
146
+
authorization: Option<Authorization>,
147
147
) -> Json<serde_json::Value> {
148
148
let subject = params.subject.as_deref().unwrap_or("World");
149
-
149
+
150
150
let message = if let Some(auth) = authorization {
151
151
format!("Hello, authenticated {}! (caller: {})", subject, auth.subject())
152
152
} else {
153
153
format!("Hello, {}!", subject)
154
154
};
155
-
155
+
156
156
Json(json!({ "message": message }))
157
157
}
158
158
+19
-6
crates/atproto-attestation/src/bin/atproto-attestation-verify.rs
+19
-6
crates/atproto-attestation/src/bin/atproto-attestation-verify.rs
···
47
47
48
48
use anyhow::{Context, Result, anyhow};
49
49
use atproto_attestation::AnyInput;
50
-
use atproto_identity::key::{KeyData, KeyResolver};
50
+
use atproto_identity::key::{KeyData, KeyResolver, identify_key};
51
51
use clap::Parser;
52
52
use serde_json::Value;
53
53
use std::{
···
115
115
attestation: Option<String>,
116
116
}
117
117
118
-
struct FakeKeyResolver {}
118
+
/// A key resolver that supports `did:key:` identifiers directly.
119
+
///
120
+
/// This resolver handles key references that are encoded as `did:key:` strings,
121
+
/// parsing them to extract the cryptographic key data. For other DID methods,
122
+
/// it returns an error since those would require fetching DID documents.
123
+
struct DidKeyResolver {}
119
124
120
125
#[async_trait::async_trait]
121
-
impl KeyResolver for FakeKeyResolver {
122
-
async fn resolve(&self, _subject: &str) -> Result<KeyData> {
123
-
todo!()
126
+
impl KeyResolver for DidKeyResolver {
127
+
async fn resolve(&self, subject: &str) -> Result<KeyData> {
128
+
if subject.starts_with("did:key:") {
129
+
identify_key(subject)
130
+
.map_err(|e| anyhow!("Failed to parse did:key '{}': {}", subject, e))
131
+
} else {
132
+
Err(anyhow!(
133
+
"Subject '{}' is not a did:key: identifier. Only did:key: subjects are supported by this resolver.",
134
+
subject
135
+
))
136
+
}
124
137
}
125
138
}
126
139
···
175
188
identity_resolver,
176
189
};
177
190
178
-
let key_resolver = FakeKeyResolver {};
191
+
let key_resolver = DidKeyResolver {};
179
192
180
193
atproto_attestation::verify_record(
181
194
AnyInput::Serialize(record.clone()),
+6
crates/atproto-client/Cargo.toml
+6
crates/atproto-client/Cargo.toml
+165
crates/atproto-client/src/bin/atproto-client-put-record.rs
+165
crates/atproto-client/src/bin/atproto-client-put-record.rs
···
1
+
//! AT Protocol client tool for writing records to a repository.
2
+
//!
3
+
//! This binary tool creates or updates records in an AT Protocol repository
4
+
//! using app password authentication. It resolves the subject to a DID,
5
+
//! creates a session, and writes the record using the putRecord XRPC method.
6
+
//!
7
+
//! # Usage
8
+
//!
9
+
//! ```text
10
+
//! ATPROTO_PASSWORD=<password> atproto-client-put-record <subject> <record_key> <record_json>
11
+
//! ```
12
+
//!
13
+
//! # Environment Variables
14
+
//!
15
+
//! - `ATPROTO_PASSWORD` - Required. App password for authentication.
16
+
//! - `CERTIFICATE_BUNDLES` - Custom CA certificate bundles.
17
+
//! - `USER_AGENT` - Custom user agent string.
18
+
//! - `DNS_NAMESERVERS` - Custom DNS nameservers.
19
+
//! - `PLC_HOSTNAME` - Override PLC hostname (default: plc.directory).
20
+
21
+
use anyhow::Result;
22
+
use atproto_client::{
23
+
client::{AppPasswordAuth, Auth},
24
+
com::atproto::{
25
+
repo::{put_record, PutRecordRequest, PutRecordResponse},
26
+
server::create_session,
27
+
},
28
+
errors::CliError,
29
+
};
30
+
use atproto_identity::{
31
+
config::{CertificateBundles, DnsNameservers, default_env, optional_env, version},
32
+
plc,
33
+
resolve::{HickoryDnsResolver, resolve_subject},
34
+
web,
35
+
};
36
+
use std::env;
37
+
38
+
fn print_usage() {
39
+
eprintln!("Usage: atproto-client-put-record <subject> <record_key> <record_json>");
40
+
eprintln!();
41
+
eprintln!("Arguments:");
42
+
eprintln!(" <subject> Handle or DID of the repository owner");
43
+
eprintln!(" <record_key> Record key (rkey) for the record");
44
+
eprintln!(" <record_json> JSON record data (must include $type field)");
45
+
eprintln!();
46
+
eprintln!("Environment Variables:");
47
+
eprintln!(" ATPROTO_PASSWORD Required. App password for authentication.");
48
+
eprintln!(" CERTIFICATE_BUNDLES Custom CA certificate bundles.");
49
+
eprintln!(" USER_AGENT Custom user agent string.");
50
+
eprintln!(" DNS_NAMESERVERS Custom DNS nameservers.");
51
+
eprintln!(" PLC_HOSTNAME Override PLC hostname (default: plc.directory).");
52
+
}
53
+
54
+
#[tokio::main]
55
+
async fn main() -> Result<()> {
56
+
let args: Vec<String> = env::args().collect();
57
+
58
+
if args.len() != 4 {
59
+
print_usage();
60
+
std::process::exit(1);
61
+
}
62
+
63
+
let subject = &args[1];
64
+
let record_key = &args[2];
65
+
let record_json = &args[3];
66
+
67
+
// Get password from environment variable
68
+
let password = env::var("ATPROTO_PASSWORD").map_err(|_| {
69
+
anyhow::anyhow!("ATPROTO_PASSWORD environment variable is required")
70
+
})?;
71
+
72
+
// Set up HTTP client configuration
73
+
let certificate_bundles: CertificateBundles = optional_env("CERTIFICATE_BUNDLES").try_into()?;
74
+
let default_user_agent = format!(
75
+
"atproto-identity-rs ({}; +https://tangled.sh/@smokesignal.events/atproto-identity-rs)",
76
+
version()?
77
+
);
78
+
let user_agent = default_env("USER_AGENT", &default_user_agent);
79
+
let dns_nameservers: DnsNameservers = optional_env("DNS_NAMESERVERS").try_into()?;
80
+
let plc_hostname = default_env("PLC_HOSTNAME", "plc.directory");
81
+
82
+
let mut client_builder = reqwest::Client::builder();
83
+
for ca_certificate in certificate_bundles.as_ref() {
84
+
let cert = std::fs::read(ca_certificate)?;
85
+
let cert = reqwest::Certificate::from_pem(&cert)?;
86
+
client_builder = client_builder.add_root_certificate(cert);
87
+
}
88
+
89
+
client_builder = client_builder.user_agent(user_agent);
90
+
let http_client = client_builder.build()?;
91
+
92
+
let dns_resolver = HickoryDnsResolver::create_resolver(dns_nameservers.as_ref());
93
+
94
+
// Parse the record JSON
95
+
let record: serde_json::Value = serde_json::from_str(record_json).map_err(|err| {
96
+
tracing::error!(error = ?err, "Failed to parse record JSON");
97
+
anyhow::anyhow!("Failed to parse record JSON: {}", err)
98
+
})?;
99
+
100
+
// Extract collection from $type field
101
+
let collection = record
102
+
.get("$type")
103
+
.and_then(|v| v.as_str())
104
+
.ok_or_else(|| anyhow::anyhow!("Record must contain a $type field for the collection"))?
105
+
.to_string();
106
+
107
+
// Resolve subject to DID
108
+
let did = resolve_subject(&http_client, &dns_resolver, subject).await?;
109
+
110
+
// Get DID document to find PDS endpoint
111
+
let document = if did.starts_with("did:plc:") {
112
+
plc::query(&http_client, &plc_hostname, &did).await?
113
+
} else if did.starts_with("did:web:") {
114
+
web::query(&http_client, &did).await?
115
+
} else {
116
+
anyhow::bail!("Unsupported DID method: {}", did);
117
+
};
118
+
119
+
// Get PDS endpoint from the DID document
120
+
let pds_endpoints = document.pds_endpoints();
121
+
let pds_endpoint = pds_endpoints
122
+
.first()
123
+
.ok_or_else(|| CliError::NoPdsEndpointFound { did: did.clone() })?;
124
+
125
+
// Create session
126
+
let session = create_session(&http_client, pds_endpoint, &did, &password, None).await?;
127
+
128
+
// Set up app password authentication
129
+
let auth = Auth::AppPassword(AppPasswordAuth {
130
+
access_token: session.access_jwt.clone(),
131
+
});
132
+
133
+
// Create put record request
134
+
let put_request = PutRecordRequest {
135
+
repo: session.did.clone(),
136
+
collection,
137
+
record_key: record_key.clone(),
138
+
validate: true,
139
+
record,
140
+
swap_commit: None,
141
+
swap_record: None,
142
+
};
143
+
144
+
// Execute put record
145
+
let response = put_record(&http_client, &auth, pds_endpoint, put_request).await?;
146
+
147
+
match response {
148
+
PutRecordResponse::StrongRef { uri, cid, .. } => {
149
+
println!(
150
+
"{}",
151
+
serde_json::to_string_pretty(&serde_json::json!({
152
+
"uri": uri,
153
+
"cid": cid
154
+
}))?
155
+
);
156
+
}
157
+
PutRecordResponse::Error(err) => {
158
+
let error_message = err.error_message();
159
+
tracing::error!(error = %error_message, "putRecord failed");
160
+
anyhow::bail!("putRecord failed: {}", error_message);
161
+
}
162
+
}
163
+
164
+
Ok(())
165
+
}
+43
crates/atproto-extras/Cargo.toml
+43
crates/atproto-extras/Cargo.toml
···
1
+
[package]
2
+
name = "atproto-extras"
3
+
version = "0.13.0"
4
+
description = "AT Protocol extras - facet parsing and rich text utilities"
5
+
readme = "README.md"
6
+
homepage = "https://tangled.sh/@smokesignal.events/atproto-identity-rs"
7
+
documentation = "https://docs.rs/atproto-extras"
8
+
9
+
edition.workspace = true
10
+
rust-version.workspace = true
11
+
authors.workspace = true
12
+
repository.workspace = true
13
+
license.workspace = true
14
+
keywords.workspace = true
15
+
categories.workspace = true
16
+
17
+
[dependencies]
18
+
atproto-identity.workspace = true
19
+
atproto-record.workspace = true
20
+
21
+
anyhow.workspace = true
22
+
async-trait.workspace = true
23
+
clap = { workspace = true, optional = true }
24
+
regex.workspace = true
25
+
reqwest = { workspace = true, optional = true }
26
+
serde_json = { workspace = true, optional = true }
27
+
tokio = { workspace = true, optional = true }
28
+
29
+
[dev-dependencies]
30
+
tokio = { workspace = true, features = ["macros", "rt"] }
31
+
32
+
[features]
33
+
default = ["hickory-dns"]
34
+
hickory-dns = ["atproto-identity/hickory-dns"]
35
+
clap = ["dep:clap"]
36
+
cli = ["dep:clap", "dep:serde_json", "dep:tokio", "dep:reqwest"]
37
+
38
+
[[bin]]
39
+
name = "atproto-extras-parse-facets"
40
+
required-features = ["clap", "cli", "hickory-dns"]
41
+
42
+
[lints]
43
+
workspace = true
+128
crates/atproto-extras/README.md
+128
crates/atproto-extras/README.md
···
1
+
# atproto-extras
2
+
3
+
Extra utilities for AT Protocol applications, including rich text facet parsing.
4
+
5
+
## Features
6
+
7
+
- **Facet Parsing**: Extract mentions (`@handle`), URLs, and hashtags (`#tag`) from plain text with correct UTF-8 byte offset calculation
8
+
- **Identity Integration**: Resolve mention handles to DIDs during parsing
9
+
10
+
## Installation
11
+
12
+
Add to your `Cargo.toml`:
13
+
14
+
```toml
15
+
[dependencies]
16
+
atproto-extras = "0.13"
17
+
```
18
+
19
+
## Usage
20
+
21
+
### Parsing Text for Facets
22
+
23
+
```rust
24
+
use atproto_extras::{parse_urls, parse_tags};
25
+
use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature;
26
+
27
+
let text = "Check out https://example.com #rust";
28
+
29
+
// Parse URLs and tags - returns Vec<Facet> directly
30
+
let url_facets = parse_urls(text);
31
+
let tag_facets = parse_tags(text);
32
+
33
+
// Each facet includes byte positions and typed features
34
+
for facet in url_facets {
35
+
if let Some(FacetFeature::Link(link)) = facet.features.first() {
36
+
println!("URL at bytes {}..{}: {}",
37
+
facet.index.byte_start, facet.index.byte_end, link.uri);
38
+
}
39
+
}
40
+
41
+
for facet in tag_facets {
42
+
if let Some(FacetFeature::Tag(tag)) = facet.features.first() {
43
+
println!("Tag at bytes {}..{}: #{}",
44
+
facet.index.byte_start, facet.index.byte_end, tag.tag);
45
+
}
46
+
}
47
+
```
48
+
49
+
### Parsing Mentions
50
+
51
+
Mention parsing requires an `IdentityResolver` to convert handles to DIDs:
52
+
53
+
```rust
54
+
use atproto_extras::{parse_mentions, FacetLimits};
55
+
use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature;
56
+
57
+
let text = "Hello @alice.bsky.social!";
58
+
let limits = FacetLimits::default();
59
+
60
+
// Requires an async context and IdentityResolver
61
+
let facets = parse_mentions(text, &resolver, &limits).await;
62
+
63
+
for facet in facets {
64
+
if let Some(FacetFeature::Mention(mention)) = facet.features.first() {
65
+
println!("Mention at bytes {}..{} resolved to {}",
66
+
facet.index.byte_start, facet.index.byte_end, mention.did);
67
+
}
68
+
}
69
+
```
70
+
71
+
Mentions that cannot be resolved to a valid DID are automatically skipped. Mentions appearing within URLs are also excluded.
72
+
73
+
### Creating AT Protocol Facets
74
+
75
+
```rust
76
+
use atproto_extras::{parse_facets_from_text, FacetLimits};
77
+
78
+
let text = "Hello @alice.bsky.social! Check https://rust-lang.org #rust";
79
+
let limits = FacetLimits::default();
80
+
81
+
// Requires an async context and IdentityResolver
82
+
let facets = parse_facets_from_text(text, &resolver, &limits).await;
83
+
84
+
if let Some(facets) = facets {
85
+
for facet in &facets {
86
+
println!("Facet at {}..{}", facet.index.byte_start, facet.index.byte_end);
87
+
}
88
+
}
89
+
```
90
+
91
+
## Byte Offset Handling
92
+
93
+
AT Protocol facets use UTF-8 byte offsets, not character indices. This is critical for correct handling of multi-byte characters like emojis or non-ASCII text.
94
+
95
+
```rust
96
+
use atproto_extras::parse_urls;
97
+
98
+
// Text with emojis (multi-byte UTF-8 characters)
99
+
let text = "โจ Check https://example.com โจ";
100
+
101
+
let facets = parse_urls(text);
102
+
// Byte positions correctly account for the 4-byte emoji
103
+
assert_eq!(facets[0].index.byte_start, 11); // After "โจ Check " (4 + 1 + 6 = 11 bytes)
104
+
```
105
+
106
+
## Facet Limits
107
+
108
+
Use `FacetLimits` to control the maximum number of facets processed:
109
+
110
+
```rust
111
+
use atproto_extras::FacetLimits;
112
+
113
+
// Default limits
114
+
let limits = FacetLimits::default();
115
+
// mentions_max: 5, tags_max: 5, links_max: 5, max: 10
116
+
117
+
// Custom limits
118
+
let custom = FacetLimits {
119
+
mentions_max: 10,
120
+
tags_max: 10,
121
+
links_max: 10,
122
+
max: 20,
123
+
};
124
+
```
125
+
126
+
## License
127
+
128
+
MIT
+176
crates/atproto-extras/src/bin/atproto-extras-parse-facets.rs
+176
crates/atproto-extras/src/bin/atproto-extras-parse-facets.rs
···
1
+
//! Command-line tool for generating AT Protocol facet arrays from text.
2
+
//!
3
+
//! This tool parses a string and outputs the facet array in JSON format.
4
+
//! Facets include mentions (@handle), URLs (https://...), and hashtags (#tag).
5
+
//!
6
+
//! By default, mentions are detected but output with placeholder DIDs. Use
7
+
//! `--resolve-mentions` to resolve handles to actual DIDs (requires network access).
8
+
//!
9
+
//! # Usage
10
+
//!
11
+
//! ```bash
12
+
//! # Parse facets without resolving mentions
13
+
//! cargo run --features clap,serde_json,tokio,hickory-dns --bin atproto-extras-parse-facets -- "Check out https://example.com and #rust"
14
+
//!
15
+
//! # Resolve mentions to DIDs
16
+
//! cargo run --features clap,serde_json,tokio,hickory-dns --bin atproto-extras-parse-facets -- --resolve-mentions "Hello @bsky.app!"
17
+
//! ```
18
+
19
+
use atproto_extras::{FacetLimits, parse_mentions, parse_tags, parse_urls};
20
+
use atproto_identity::resolve::{HickoryDnsResolver, InnerIdentityResolver};
21
+
use atproto_record::lexicon::app::bsky::richtext::facet::{
22
+
ByteSlice, Facet, FacetFeature, Mention,
23
+
};
24
+
use clap::Parser;
25
+
use regex::bytes::Regex;
26
+
use std::sync::Arc;
27
+
28
+
/// Parse text and output AT Protocol facets as JSON.
29
+
#[derive(Parser)]
30
+
#[command(
31
+
name = "atproto-extras-parse-facets",
32
+
version,
33
+
about = "Parse text and output AT Protocol facets as JSON",
34
+
long_about = "This tool parses a string for mentions, URLs, and hashtags,\n\
35
+
then outputs the corresponding AT Protocol facet array in JSON format.\n\n\
36
+
By default, mentions are detected but output with placeholder DIDs.\n\
37
+
Use --resolve-mentions to resolve handles to actual DIDs (requires network)."
38
+
)]
39
+
struct Args {
40
+
/// The text to parse for facets
41
+
text: String,
42
+
43
+
/// Resolve mention handles to DIDs (requires network access)
44
+
#[arg(long)]
45
+
resolve_mentions: bool,
46
+
47
+
/// Show debug information on stderr
48
+
#[arg(long, short = 'd')]
49
+
debug: bool,
50
+
}
51
+
52
+
/// Parse mention spans from text without resolution (returns placeholder DIDs).
53
+
fn parse_mention_spans(text: &str) -> Vec<Facet> {
54
+
let mut facets = Vec::new();
55
+
56
+
// Get URL ranges to exclude mentions within URLs
57
+
let url_facets = parse_urls(text);
58
+
59
+
// Same regex pattern as parse_mentions
60
+
let mention_regex = Regex::new(
61
+
r"(?:^|[^\w])(@([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)",
62
+
)
63
+
.expect("Invalid mention regex");
64
+
65
+
let text_bytes = text.as_bytes();
66
+
67
+
for capture in mention_regex.captures_iter(text_bytes) {
68
+
if let Some(mention_match) = capture.get(1) {
69
+
let start = mention_match.start();
70
+
let end = mention_match.end();
71
+
72
+
// Check if this mention overlaps with any URL
73
+
let overlaps_url = url_facets.iter().any(|facet| {
74
+
(start >= facet.index.byte_start && start < facet.index.byte_end)
75
+
|| (end > facet.index.byte_start && end <= facet.index.byte_end)
76
+
});
77
+
78
+
if !overlaps_url {
79
+
let handle = std::str::from_utf8(&mention_match.as_bytes()[1..])
80
+
.unwrap_or_default()
81
+
.to_string();
82
+
83
+
facets.push(Facet {
84
+
index: ByteSlice {
85
+
byte_start: start,
86
+
byte_end: end,
87
+
},
88
+
features: vec![FacetFeature::Mention(Mention {
89
+
did: format!("did:plc:<unresolved:{}>", handle),
90
+
})],
91
+
});
92
+
}
93
+
}
94
+
}
95
+
96
+
facets
97
+
}
98
+
99
+
#[tokio::main]
100
+
async fn main() {
101
+
let args = Args::parse();
102
+
let text = &args.text;
103
+
let mut facets: Vec<Facet> = Vec::new();
104
+
let limits = FacetLimits::default();
105
+
106
+
// Parse mentions (either resolved or with placeholders)
107
+
if args.resolve_mentions {
108
+
let http_client = reqwest::Client::new();
109
+
let dns_resolver = HickoryDnsResolver::create_resolver(&[]);
110
+
let resolver = InnerIdentityResolver {
111
+
http_client,
112
+
dns_resolver: Arc::new(dns_resolver),
113
+
plc_hostname: "plc.directory".to_string(),
114
+
};
115
+
let mention_facets = parse_mentions(text, &resolver, &limits).await;
116
+
facets.extend(mention_facets);
117
+
} else {
118
+
let mention_facets = parse_mention_spans(text);
119
+
facets.extend(mention_facets);
120
+
}
121
+
122
+
// Parse URLs
123
+
let url_facets = parse_urls(text);
124
+
facets.extend(url_facets);
125
+
126
+
// Parse hashtags
127
+
let tag_facets = parse_tags(text);
128
+
facets.extend(tag_facets);
129
+
130
+
// Sort facets by byte_start for consistent output
131
+
facets.sort_by_key(|f| f.index.byte_start);
132
+
133
+
// Output as JSON
134
+
if facets.is_empty() {
135
+
println!("null");
136
+
} else {
137
+
match serde_json::to_string_pretty(&facets) {
138
+
Ok(json) => println!("{}", json),
139
+
Err(e) => {
140
+
eprintln!(
141
+
"error-atproto-extras-parse-facets-1 Error serializing facets: {}",
142
+
e
143
+
);
144
+
std::process::exit(1);
145
+
}
146
+
}
147
+
}
148
+
149
+
// Show debug info if requested
150
+
if args.debug {
151
+
eprintln!();
152
+
eprintln!("--- Debug Info ---");
153
+
eprintln!("Input text: {:?}", text);
154
+
eprintln!("Text length: {} bytes", text.len());
155
+
eprintln!("Facets found: {}", facets.len());
156
+
eprintln!("Mentions resolved: {}", args.resolve_mentions);
157
+
158
+
// Show byte slice verification
159
+
let text_bytes = text.as_bytes();
160
+
for (i, facet) in facets.iter().enumerate() {
161
+
let start = facet.index.byte_start;
162
+
let end = facet.index.byte_end;
163
+
let slice_text =
164
+
std::str::from_utf8(&text_bytes[start..end]).unwrap_or("<invalid utf8>");
165
+
let feature_type = match &facet.features[0] {
166
+
FacetFeature::Mention(_) => "mention",
167
+
FacetFeature::Link(_) => "link",
168
+
FacetFeature::Tag(_) => "tag",
169
+
};
170
+
eprintln!(
171
+
" [{}] {} @ bytes {}..{}: {:?}",
172
+
i, feature_type, start, end, slice_text
173
+
);
174
+
}
175
+
}
176
+
}
+942
crates/atproto-extras/src/facets.rs
+942
crates/atproto-extras/src/facets.rs
···
1
+
//! Rich text facet parsing for AT Protocol.
2
+
//!
3
+
//! This module provides functionality for extracting semantic annotations (facets)
4
+
//! from plain text. Facets include mentions, links (URLs), and hashtags.
5
+
//!
6
+
//! # Overview
7
+
//!
8
+
//! AT Protocol rich text uses "facets" to annotate specific byte ranges within text with
9
+
//! semantic meaning. This module handles:
10
+
//!
11
+
//! - **Parsing**: Extract mentions, URLs, and hashtags from plain text
12
+
//! - **Facet Creation**: Build proper AT Protocol facet structures with resolved DIDs
13
+
//!
14
+
//! # Byte Offset Calculation
15
+
//!
16
+
//! This implementation correctly uses UTF-8 byte offsets as required by AT Protocol.
17
+
//! The facets use "inclusive start and exclusive end" byte ranges. All parsing is done
18
+
//! using `regex::bytes::Regex` which operates on byte slices and returns byte positions,
19
+
//! ensuring correct handling of multi-byte UTF-8 characters (emojis, CJK, accented chars).
20
+
//!
21
+
//! # Example
22
+
//!
23
+
//! ```ignore
24
+
//! use atproto_extras::facets::{parse_urls, parse_tags, FacetLimits};
25
+
//! use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature;
26
+
//!
27
+
//! let text = "Check out https://example.com #rust";
28
+
//!
29
+
//! // Parse URLs and tags as Facet objects
30
+
//! let url_facets = parse_urls(text);
31
+
//! let tag_facets = parse_tags(text);
32
+
//!
33
+
//! // Access facet data directly
34
+
//! for facet in url_facets {
35
+
//! if let Some(FacetFeature::Link(link)) = facet.features.first() {
36
+
//! println!("URL at bytes {}..{}: {}",
37
+
//! facet.index.byte_start, facet.index.byte_end, link.uri);
38
+
//! }
39
+
//! }
40
+
//! ```
41
+
42
+
use atproto_identity::resolve::IdentityResolver;
43
+
use atproto_record::lexicon::app::bsky::richtext::facet::{
44
+
ByteSlice, Facet, FacetFeature, Link, Mention, Tag,
45
+
};
46
+
use regex::bytes::Regex;
47
+
48
+
/// Configuration for facet parsing limits.
49
+
///
50
+
/// These limits protect against abuse by capping the number of facets
51
+
/// that will be processed. This is important for both performance and
52
+
/// security when handling user-generated content.
53
+
///
54
+
/// # Example
55
+
///
56
+
/// ```
57
+
/// use atproto_extras::FacetLimits;
58
+
///
59
+
/// // Use defaults
60
+
/// let limits = FacetLimits::default();
61
+
///
62
+
/// // Or customize
63
+
/// let custom = FacetLimits {
64
+
/// mentions_max: 10,
65
+
/// tags_max: 10,
66
+
/// links_max: 10,
67
+
/// max: 20,
68
+
/// };
69
+
/// ```
70
+
#[derive(Debug, Clone, Copy)]
71
+
pub struct FacetLimits {
72
+
/// Maximum number of mention facets to process (default: 5)
73
+
pub mentions_max: usize,
74
+
/// Maximum number of tag facets to process (default: 5)
75
+
pub tags_max: usize,
76
+
/// Maximum number of link facets to process (default: 5)
77
+
pub links_max: usize,
78
+
/// Maximum total number of facets to process (default: 10)
79
+
pub max: usize,
80
+
}
81
+
82
+
impl Default for FacetLimits {
83
+
fn default() -> Self {
84
+
Self {
85
+
mentions_max: 5,
86
+
tags_max: 5,
87
+
links_max: 5,
88
+
max: 10,
89
+
}
90
+
}
91
+
}
92
+
93
+
/// Parse mentions from text and return them as Facet objects with resolved DIDs.
94
+
///
95
+
/// This function extracts AT Protocol handle mentions (e.g., `@alice.bsky.social`)
96
+
/// from text, resolves each handle to a DID using the provided identity resolver,
97
+
/// and returns AT Protocol Facet objects with Mention features.
98
+
///
99
+
/// Mentions that cannot be resolved to a valid DID are skipped. Mentions that
100
+
/// appear within URLs are also excluded to avoid false positives.
101
+
///
102
+
/// # Arguments
103
+
///
104
+
/// * `text` - The text to parse for mentions
105
+
/// * `identity_resolver` - Resolver for converting handles to DIDs
106
+
/// * `limits` - Configuration for maximum mentions to process
107
+
///
108
+
/// # Returns
109
+
///
110
+
/// A vector of Facet objects for successfully resolved mentions.
111
+
///
112
+
/// # Example
113
+
///
114
+
/// ```ignore
115
+
/// use atproto_extras::{parse_mentions, FacetLimits};
116
+
/// use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature;
117
+
///
118
+
/// let text = "Hello @alice.bsky.social!";
119
+
/// let limits = FacetLimits::default();
120
+
///
121
+
/// // Requires an async context and identity resolver
122
+
/// let facets = parse_mentions(text, &resolver, &limits).await;
123
+
///
124
+
/// for facet in facets {
125
+
/// if let Some(FacetFeature::Mention(mention)) = facet.features.first() {
126
+
/// println!("Mention {} resolved to {}",
127
+
/// &text[facet.index.byte_start..facet.index.byte_end],
128
+
/// mention.did);
129
+
/// }
130
+
/// }
131
+
/// ```
132
+
pub async fn parse_mentions(
133
+
text: &str,
134
+
identity_resolver: &dyn IdentityResolver,
135
+
limits: &FacetLimits,
136
+
) -> Vec<Facet> {
137
+
let mut facets = Vec::new();
138
+
139
+
// First, parse all URLs to exclude mention matches within them
140
+
let url_facets = parse_urls(text);
141
+
142
+
// Regex based on: https://atproto.com/specs/handle#handle-identifier-syntax
143
+
// Pattern: [$|\W](@([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)
144
+
let mention_regex = Regex::new(
145
+
r"(?:^|[^\w])(@([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?)",
146
+
)
147
+
.unwrap();
148
+
149
+
let text_bytes = text.as_bytes();
150
+
let mut mention_count = 0;
151
+
152
+
for capture in mention_regex.captures_iter(text_bytes) {
153
+
if mention_count >= limits.mentions_max {
154
+
break;
155
+
}
156
+
157
+
if let Some(mention_match) = capture.get(1) {
158
+
let start = mention_match.start();
159
+
let end = mention_match.end();
160
+
161
+
// Check if this mention overlaps with any URL
162
+
let overlaps_url = url_facets.iter().any(|facet| {
163
+
// Check if mention is within or overlaps the URL span
164
+
(start >= facet.index.byte_start && start < facet.index.byte_end)
165
+
|| (end > facet.index.byte_start && end <= facet.index.byte_end)
166
+
});
167
+
168
+
// Only process the mention if it doesn't overlap with a URL
169
+
if !overlaps_url {
170
+
let handle = std::str::from_utf8(&mention_match.as_bytes()[1..])
171
+
.unwrap_or_default()
172
+
.to_string();
173
+
174
+
// Try to resolve the handle to a DID
175
+
// First try with at:// prefix, then without
176
+
let at_uri = format!("at://{}", handle);
177
+
let did_result = match identity_resolver.resolve(&at_uri).await {
178
+
Ok(doc) => Ok(doc),
179
+
Err(_) => identity_resolver.resolve(&handle).await,
180
+
};
181
+
182
+
// Only add the mention facet if we successfully resolved the DID
183
+
if let Ok(did_doc) = did_result {
184
+
facets.push(Facet {
185
+
index: ByteSlice {
186
+
byte_start: start,
187
+
byte_end: end,
188
+
},
189
+
features: vec![FacetFeature::Mention(Mention {
190
+
did: did_doc.id.to_string(),
191
+
})],
192
+
});
193
+
mention_count += 1;
194
+
}
195
+
}
196
+
}
197
+
}
198
+
199
+
facets
200
+
}
201
+
202
+
/// Parse URLs from text and return them as Facet objects.
203
+
///
204
+
/// This function extracts HTTP and HTTPS URLs from text with correct
205
+
/// byte position tracking for UTF-8 text, returning AT Protocol Facet objects
206
+
/// with Link features.
207
+
///
208
+
/// # Supported URL Patterns
209
+
///
210
+
/// - HTTP URLs: `http://example.com`
211
+
/// - HTTPS URLs: `https://example.com`
212
+
/// - URLs with paths, query strings, and fragments
213
+
/// - URLs with subdomains: `https://www.example.com`
214
+
///
215
+
/// # Example
216
+
///
217
+
/// ```
218
+
/// use atproto_extras::parse_urls;
219
+
/// use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature;
220
+
///
221
+
/// let text = "Visit https://example.com/path?query=1 for more info";
222
+
/// let facets = parse_urls(text);
223
+
///
224
+
/// assert_eq!(facets.len(), 1);
225
+
/// assert_eq!(facets[0].index.byte_start, 6);
226
+
/// assert_eq!(facets[0].index.byte_end, 38);
227
+
/// if let Some(FacetFeature::Link(link)) = facets[0].features.first() {
228
+
/// assert_eq!(link.uri, "https://example.com/path?query=1");
229
+
/// }
230
+
/// ```
231
+
///
232
+
/// # Multi-byte Character Handling
233
+
///
234
+
/// Byte positions are correctly calculated even with emojis and other
235
+
/// multi-byte UTF-8 characters:
236
+
///
237
+
/// ```
238
+
/// use atproto_extras::parse_urls;
239
+
/// use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature;
240
+
///
241
+
/// let text = "Check out https://example.com now!";
242
+
/// let facets = parse_urls(text);
243
+
/// let text_bytes = text.as_bytes();
244
+
///
245
+
/// // The byte slice matches the URL
246
+
/// let url_bytes = &text_bytes[facets[0].index.byte_start..facets[0].index.byte_end];
247
+
/// assert_eq!(std::str::from_utf8(url_bytes).unwrap(), "https://example.com");
248
+
/// ```
249
+
pub fn parse_urls(text: &str) -> Vec<Facet> {
250
+
let mut facets = Vec::new();
251
+
252
+
// Partial/naive URL regex based on: https://stackoverflow.com/a/3809435
253
+
// Pattern: [$|\W](https?:\/\/(www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]+\b([-a-zA-Z0-9()@:%_\+.~#?&//=]*[-a-zA-Z0-9@%_\+~#//=])?)
254
+
// Modified to use + instead of {1,6} to support longer TLDs and multi-level subdomains
255
+
let url_regex = Regex::new(
256
+
r"(?:^|[^\w])(https?://(?:www\.)?[-a-zA-Z0-9@:%._\+~#=]{1,256}\.[a-zA-Z0-9()]+\b(?:[-a-zA-Z0-9()@:%_\+.~#?&//=]*[-a-zA-Z0-9@%_\+~#//=])?)"
257
+
).unwrap();
258
+
259
+
let text_bytes = text.as_bytes();
260
+
for capture in url_regex.captures_iter(text_bytes) {
261
+
if let Some(url_match) = capture.get(1) {
262
+
let url = std::str::from_utf8(url_match.as_bytes())
263
+
.unwrap_or_default()
264
+
.to_string();
265
+
266
+
facets.push(Facet {
267
+
index: ByteSlice {
268
+
byte_start: url_match.start(),
269
+
byte_end: url_match.end(),
270
+
},
271
+
features: vec![FacetFeature::Link(Link { uri: url })],
272
+
});
273
+
}
274
+
}
275
+
276
+
facets
277
+
}
278
+
279
+
/// Parse hashtags from text and return them as Facet objects.
280
+
///
281
+
/// This function extracts hashtags (e.g., `#rust`, `#ATProto`) from text,
282
+
/// returning AT Protocol Facet objects with Tag features.
283
+
/// It supports both standard `#` and full-width `๏ผ` (U+FF03) hash symbols.
284
+
///
285
+
/// # Tag Syntax
286
+
///
287
+
/// - Tags must start with `#` or `๏ผ` (full-width)
288
+
/// - Tag content follows word character rules (`\w`)
289
+
/// - Purely numeric tags (e.g., `#123`) are excluded
290
+
///
291
+
/// # Example
292
+
///
293
+
/// ```
294
+
/// use atproto_extras::parse_tags;
295
+
/// use atproto_record::lexicon::app::bsky::richtext::facet::FacetFeature;
296
+
///
297
+
/// let text = "Learning #rust and #golang today! #100DaysOfCode";
298
+
/// let facets = parse_tags(text);
299
+
///
300
+
/// assert_eq!(facets.len(), 3);
301
+
/// if let Some(FacetFeature::Tag(tag)) = facets[0].features.first() {
302
+
/// assert_eq!(tag.tag, "rust");
303
+
/// }
304
+
/// if let Some(FacetFeature::Tag(tag)) = facets[1].features.first() {
305
+
/// assert_eq!(tag.tag, "golang");
306
+
/// }
307
+
/// if let Some(FacetFeature::Tag(tag)) = facets[2].features.first() {
308
+
/// assert_eq!(tag.tag, "100DaysOfCode");
309
+
/// }
310
+
/// ```
311
+
///
312
+
/// # Numeric Tags
313
+
///
314
+
/// Purely numeric tags are excluded:
315
+
///
316
+
/// ```
317
+
/// use atproto_extras::parse_tags;
318
+
///
319
+
/// let text = "Item #42 is special";
320
+
/// let facets = parse_tags(text);
321
+
///
322
+
/// // #42 is not extracted because it's purely numeric
323
+
/// assert_eq!(facets.len(), 0);
324
+
/// ```
325
+
pub fn parse_tags(text: &str) -> Vec<Facet> {
326
+
let mut facets = Vec::new();
327
+
328
+
// Regex based on: https://github.com/bluesky-social/atproto/blob/d91988fe79030b61b556dd6f16a46f0c3b9d0b44/packages/api/src/rich-text/util.ts
329
+
// Simplified for Rust - matches hashtags at word boundaries
330
+
// Pattern matches: start of string or non-word char, then # or ๏ผ, then tag content
331
+
let tag_regex = Regex::new(r"(?:^|[^\w])([#\xEF\xBC\x83])([\w]+(?:[\w]*)*)").unwrap();
332
+
333
+
let text_bytes = text.as_bytes();
334
+
335
+
// Work with bytes for proper position tracking
336
+
for capture in tag_regex.captures_iter(text_bytes) {
337
+
if let (Some(full_match), Some(hash_match), Some(tag_match)) =
338
+
(capture.get(0), capture.get(1), capture.get(2))
339
+
{
340
+
// Calculate the absolute byte position of the hash symbol
341
+
// The full match includes the preceding character (if any)
342
+
// so we need to adjust for that
343
+
let match_start = full_match.start();
344
+
let hash_offset = hash_match.start() - full_match.start();
345
+
let start = match_start + hash_offset;
346
+
let end = match_start + hash_offset + hash_match.len() + tag_match.len();
347
+
348
+
// Extract just the tag text (without the hash symbol)
349
+
let tag = std::str::from_utf8(tag_match.as_bytes()).unwrap_or_default();
350
+
351
+
// Only include tags that are not purely numeric
352
+
if !tag.chars().all(|c| c.is_ascii_digit()) {
353
+
facets.push(Facet {
354
+
index: ByteSlice {
355
+
byte_start: start,
356
+
byte_end: end,
357
+
},
358
+
features: vec![FacetFeature::Tag(Tag {
359
+
tag: tag.to_string(),
360
+
})],
361
+
});
362
+
}
363
+
}
364
+
}
365
+
366
+
facets
367
+
}
368
+
369
+
/// Parse facets from text and return a vector of Facet objects.
370
+
///
371
+
/// This function extracts mentions, URLs, and hashtags from the provided text
372
+
/// and creates AT Protocol facets with proper byte indices.
373
+
///
374
+
/// Mentions are resolved to actual DIDs using the provided identity resolver.
375
+
/// If a handle cannot be resolved to a DID, the mention facet is skipped.
376
+
///
377
+
/// # Arguments
378
+
///
379
+
/// * `text` - The text to extract facets from
380
+
/// * `identity_resolver` - Resolver for converting handles to DIDs
381
+
/// * `limits` - Configuration for maximum facets per type and total
382
+
///
383
+
/// # Returns
384
+
///
385
+
/// Optional vector of facets. Returns `None` if no facets were found.
386
+
///
387
+
/// # Example
388
+
///
389
+
/// ```ignore
390
+
/// use atproto_extras::{parse_facets_from_text, FacetLimits};
391
+
///
392
+
/// let text = "Hello @alice.bsky.social! Check #rust at https://rust-lang.org";
393
+
/// let limits = FacetLimits::default();
394
+
///
395
+
/// // Requires an async context and identity resolver
396
+
/// let facets = parse_facets_from_text(text, &resolver, &limits).await;
397
+
///
398
+
/// if let Some(facets) = facets {
399
+
/// for facet in &facets {
400
+
/// println!("Facet at {}..{}", facet.index.byte_start, facet.index.byte_end);
401
+
/// }
402
+
/// }
403
+
/// ```
404
+
///
405
+
/// # Mention Resolution
406
+
///
407
+
/// Mentions are only included if the handle resolves to a valid DID:
408
+
///
409
+
/// ```ignore
410
+
/// let text = "@valid.handle.com and @invalid.handle.xyz";
411
+
/// let facets = parse_facets_from_text(text, &resolver, &limits).await;
412
+
///
413
+
/// // Only @valid.handle.com appears as a facet if @invalid.handle.xyz
414
+
/// // cannot be resolved to a DID
415
+
/// ```
416
+
pub async fn parse_facets_from_text(
417
+
text: &str,
418
+
identity_resolver: &dyn IdentityResolver,
419
+
limits: &FacetLimits,
420
+
) -> Option<Vec<Facet>> {
421
+
let mut facets = Vec::new();
422
+
423
+
// Parse mentions (already limited by mentions_max in parse_mentions)
424
+
let mention_facets = parse_mentions(text, identity_resolver, limits).await;
425
+
facets.extend(mention_facets);
426
+
427
+
// Parse URLs (limited by links_max)
428
+
let url_facets = parse_urls(text);
429
+
for (idx, facet) in url_facets.into_iter().enumerate() {
430
+
if idx >= limits.links_max {
431
+
break;
432
+
}
433
+
facets.push(facet);
434
+
}
435
+
436
+
// Parse hashtags (limited by tags_max)
437
+
let tag_facets = parse_tags(text);
438
+
for (idx, facet) in tag_facets.into_iter().enumerate() {
439
+
if idx >= limits.tags_max {
440
+
break;
441
+
}
442
+
facets.push(facet);
443
+
}
444
+
445
+
// Apply global facet limit (truncate if exceeds max)
446
+
if facets.len() > limits.max {
447
+
facets.truncate(limits.max);
448
+
}
449
+
450
+
// Only return facets if we found any
451
+
if !facets.is_empty() {
452
+
Some(facets)
453
+
} else {
454
+
None
455
+
}
456
+
}
457
+
458
+
#[cfg(test)]
459
+
mod tests {
460
+
use async_trait::async_trait;
461
+
use atproto_identity::model::Document;
462
+
use std::collections::HashMap;
463
+
464
+
use super::*;
465
+
466
+
/// Mock identity resolver for testing
467
+
struct MockIdentityResolver {
468
+
handles_to_dids: HashMap<String, String>,
469
+
}
470
+
471
+
impl MockIdentityResolver {
472
+
fn new() -> Self {
473
+
let mut handles_to_dids = HashMap::new();
474
+
handles_to_dids.insert(
475
+
"alice.bsky.social".to_string(),
476
+
"did:plc:alice123".to_string(),
477
+
);
478
+
handles_to_dids.insert(
479
+
"at://alice.bsky.social".to_string(),
480
+
"did:plc:alice123".to_string(),
481
+
);
482
+
Self { handles_to_dids }
483
+
}
484
+
485
+
fn add_identity(&mut self, handle: &str, did: &str) {
486
+
self.handles_to_dids
487
+
.insert(handle.to_string(), did.to_string());
488
+
self.handles_to_dids
489
+
.insert(format!("at://{}", handle), did.to_string());
490
+
}
491
+
}
492
+
493
+
#[async_trait]
494
+
impl IdentityResolver for MockIdentityResolver {
495
+
async fn resolve(&self, handle: &str) -> anyhow::Result<Document> {
496
+
let handle_key = handle.to_string();
497
+
498
+
if let Some(did) = self.handles_to_dids.get(&handle_key) {
499
+
Ok(Document {
500
+
context: vec![],
501
+
id: did.clone(),
502
+
also_known_as: vec![format!("at://{}", handle_key.trim_start_matches("at://"))],
503
+
verification_method: vec![],
504
+
service: vec![],
505
+
extra: HashMap::new(),
506
+
})
507
+
} else {
508
+
Err(anyhow::anyhow!("Handle not found"))
509
+
}
510
+
}
511
+
}
512
+
513
+
#[tokio::test]
514
+
async fn test_parse_facets_from_text_comprehensive() {
515
+
let mut resolver = MockIdentityResolver::new();
516
+
resolver.add_identity("bob.test.com", "did:plc:bob456");
517
+
518
+
let limits = FacetLimits::default();
519
+
let text = "Join @alice.bsky.social and @bob.test.com at https://example.com #rust #golang";
520
+
let facets = parse_facets_from_text(text, &resolver, &limits).await;
521
+
522
+
assert!(facets.is_some());
523
+
let facets = facets.unwrap();
524
+
assert_eq!(facets.len(), 5); // 2 mentions, 1 URL, 2 hashtags
525
+
526
+
// Check first mention
527
+
assert_eq!(facets[0].index.byte_start, 5);
528
+
assert_eq!(facets[0].index.byte_end, 23);
529
+
if let FacetFeature::Mention(ref mention) = facets[0].features[0] {
530
+
assert_eq!(mention.did, "did:plc:alice123");
531
+
} else {
532
+
panic!("Expected Mention feature");
533
+
}
534
+
535
+
// Check second mention
536
+
assert_eq!(facets[1].index.byte_start, 28);
537
+
assert_eq!(facets[1].index.byte_end, 41);
538
+
if let FacetFeature::Mention(mention) = &facets[1].features[0] {
539
+
assert_eq!(mention.did, "did:plc:bob456");
540
+
} else {
541
+
panic!("Expected Mention feature");
542
+
}
543
+
544
+
// Check URL
545
+
assert_eq!(facets[2].index.byte_start, 45);
546
+
assert_eq!(facets[2].index.byte_end, 64);
547
+
if let FacetFeature::Link(link) = &facets[2].features[0] {
548
+
assert_eq!(link.uri, "https://example.com");
549
+
} else {
550
+
panic!("Expected Link feature");
551
+
}
552
+
553
+
// Check first hashtag
554
+
assert_eq!(facets[3].index.byte_start, 65);
555
+
assert_eq!(facets[3].index.byte_end, 70);
556
+
if let FacetFeature::Tag(tag) = &facets[3].features[0] {
557
+
assert_eq!(tag.tag, "rust");
558
+
} else {
559
+
panic!("Expected Tag feature");
560
+
}
561
+
562
+
// Check second hashtag
563
+
assert_eq!(facets[4].index.byte_start, 71);
564
+
assert_eq!(facets[4].index.byte_end, 78);
565
+
if let FacetFeature::Tag(tag) = &facets[4].features[0] {
566
+
assert_eq!(tag.tag, "golang");
567
+
} else {
568
+
panic!("Expected Tag feature");
569
+
}
570
+
}
571
+
572
+
#[tokio::test]
573
+
async fn test_parse_facets_from_text_with_unresolvable_mention() {
574
+
let resolver = MockIdentityResolver::new();
575
+
let limits = FacetLimits::default();
576
+
577
+
// Only alice.bsky.social is in the resolver, not unknown.handle.com
578
+
let text = "Contact @unknown.handle.com for details #rust";
579
+
let facets = parse_facets_from_text(text, &resolver, &limits).await;
580
+
581
+
assert!(facets.is_some());
582
+
let facets = facets.unwrap();
583
+
// Should only have 1 facet (the hashtag) since the mention couldn't be resolved
584
+
assert_eq!(facets.len(), 1);
585
+
586
+
// Check that it's the hashtag facet
587
+
if let FacetFeature::Tag(tag) = &facets[0].features[0] {
588
+
assert_eq!(tag.tag, "rust");
589
+
} else {
590
+
panic!("Expected Tag feature");
591
+
}
592
+
}
593
+
594
+
#[tokio::test]
595
+
async fn test_parse_facets_from_text_empty() {
596
+
let resolver = MockIdentityResolver::new();
597
+
let limits = FacetLimits::default();
598
+
let text = "No mentions, URLs, or hashtags here";
599
+
let facets = parse_facets_from_text(text, &resolver, &limits).await;
600
+
assert!(facets.is_none());
601
+
}
602
+
603
+
#[tokio::test]
604
+
async fn test_parse_facets_from_text_url_with_at_mention() {
605
+
let resolver = MockIdentityResolver::new();
606
+
let limits = FacetLimits::default();
607
+
608
+
// URLs with @ should not create mention facets
609
+
let text = "Tangled https://tangled.org/@smokesignal.events";
610
+
let facets = parse_facets_from_text(text, &resolver, &limits).await;
611
+
612
+
assert!(facets.is_some());
613
+
let facets = facets.unwrap();
614
+
615
+
// Should have exactly 1 facet (the URL), not 2 (URL + mention)
616
+
assert_eq!(
617
+
facets.len(),
618
+
1,
619
+
"Expected 1 facet (URL only), got {}",
620
+
facets.len()
621
+
);
622
+
623
+
// Verify it's a link facet, not a mention
624
+
if let FacetFeature::Link(link) = &facets[0].features[0] {
625
+
assert_eq!(link.uri, "https://tangled.org/@smokesignal.events");
626
+
} else {
627
+
panic!("Expected Link feature, got Mention or Tag instead");
628
+
}
629
+
}
630
+
631
+
#[tokio::test]
632
+
async fn test_parse_facets_with_mention_limit() {
633
+
let mut resolver = MockIdentityResolver::new();
634
+
resolver.add_identity("bob.test.com", "did:plc:bob456");
635
+
resolver.add_identity("charlie.test.com", "did:plc:charlie789");
636
+
637
+
// Limit to 2 mentions
638
+
let limits = FacetLimits {
639
+
mentions_max: 2,
640
+
tags_max: 5,
641
+
links_max: 5,
642
+
max: 10,
643
+
};
644
+
645
+
let text = "Join @alice.bsky.social @bob.test.com @charlie.test.com";
646
+
let facets = parse_facets_from_text(text, &resolver, &limits).await;
647
+
648
+
assert!(facets.is_some());
649
+
let facets = facets.unwrap();
650
+
// Should only have 2 mentions (alice and bob), charlie should be skipped
651
+
assert_eq!(facets.len(), 2);
652
+
653
+
// Verify they're both mentions
654
+
for facet in &facets {
655
+
assert!(matches!(facet.features[0], FacetFeature::Mention(_)));
656
+
}
657
+
}
658
+
659
+
#[tokio::test]
660
+
async fn test_parse_facets_with_global_limit() {
661
+
let mut resolver = MockIdentityResolver::new();
662
+
resolver.add_identity("bob.test.com", "did:plc:bob456");
663
+
664
+
// Very restrictive global limit
665
+
let limits = FacetLimits {
666
+
mentions_max: 5,
667
+
tags_max: 5,
668
+
links_max: 5,
669
+
max: 3, // Only allow 3 total facets
670
+
};
671
+
672
+
let text =
673
+
"Join @alice.bsky.social @bob.test.com at https://example.com #rust #golang #python";
674
+
let facets = parse_facets_from_text(text, &resolver, &limits).await;
675
+
676
+
assert!(facets.is_some());
677
+
let facets = facets.unwrap();
678
+
// Should be truncated to 3 facets total
679
+
assert_eq!(facets.len(), 3);
680
+
}
681
+
682
+
#[test]
683
+
fn test_parse_urls_multiple_links() {
684
+
let text = "IETF124 is happening in Montreal, Nov 1st to 7th https://www.ietf.org/meeting/124/\n\nWe're confirmed for two days of ATProto community sessions on Monday, Nov 3rd & Tuesday, Mov 4th at ECTO Co-Op. Many of us will also be participating in the free-to-attend IETF hackathon on Sunday, Nov 2nd.\n\nLatest updates and attendees in the forum https://discourse.atprotocol.community/t/update-on-timing-and-plan-for-montreal/164";
685
+
686
+
let facets = parse_urls(text);
687
+
688
+
// Should find both URLs
689
+
assert_eq!(
690
+
facets.len(),
691
+
2,
692
+
"Expected 2 URLs but found {}",
693
+
facets.len()
694
+
);
695
+
696
+
// Check first URL
697
+
if let Some(FacetFeature::Link(link)) = facets[0].features.first() {
698
+
assert_eq!(link.uri, "https://www.ietf.org/meeting/124/");
699
+
} else {
700
+
panic!("Expected Link feature");
701
+
}
702
+
703
+
// Check second URL
704
+
if let Some(FacetFeature::Link(link)) = facets[1].features.first() {
705
+
assert_eq!(
706
+
link.uri,
707
+
"https://discourse.atprotocol.community/t/update-on-timing-and-plan-for-montreal/164"
708
+
);
709
+
} else {
710
+
panic!("Expected Link feature");
711
+
}
712
+
}
713
+
714
+
#[test]
715
+
fn test_parse_urls_with_html_entity() {
716
+
// Test with the HTML entity & in the text
717
+
let text = "IETF124 is happening in Montreal, Nov 1st to 7th https://www.ietf.org/meeting/124/\n\nWe're confirmed for two days of ATProto community sessions on Monday, Nov 3rd & Tuesday, Mov 4th at ECTO Co-Op. Many of us will also be participating in the free-to-attend IETF hackathon on Sunday, Nov 2nd.\n\nLatest updates and attendees in the forum https://discourse.atprotocol.community/t/update-on-timing-and-plan-for-montreal/164";
718
+
719
+
let facets = parse_urls(text);
720
+
721
+
// Should find both URLs
722
+
assert_eq!(
723
+
facets.len(),
724
+
2,
725
+
"Expected 2 URLs but found {}",
726
+
facets.len()
727
+
);
728
+
729
+
// Check first URL
730
+
if let Some(FacetFeature::Link(link)) = facets[0].features.first() {
731
+
assert_eq!(link.uri, "https://www.ietf.org/meeting/124/");
732
+
} else {
733
+
panic!("Expected Link feature");
734
+
}
735
+
736
+
// Check second URL
737
+
if let Some(FacetFeature::Link(link)) = facets[1].features.first() {
738
+
assert_eq!(
739
+
link.uri,
740
+
"https://discourse.atprotocol.community/t/update-on-timing-and-plan-for-montreal/164"
741
+
);
742
+
} else {
743
+
panic!("Expected Link feature");
744
+
}
745
+
}
746
+
747
+
#[test]
748
+
fn test_byte_offset_with_html_entities() {
749
+
// This test demonstrates that HTML entity escaping shifts byte positions.
750
+
// The byte positions shift:
751
+
// In original: '&' is at byte 8 (1 byte)
752
+
// In escaped: '&' starts at byte 8 (5 bytes)
753
+
// This causes facet byte offsets to be misaligned if text is escaped before rendering.
754
+
755
+
// If we have a URL after the ampersand in the original:
756
+
let original_with_url = "Nov 3rd & Tuesday https://example.com";
757
+
let escaped_with_url = "Nov 3rd & Tuesday https://example.com";
758
+
759
+
// Parse URLs from both versions
760
+
let original_facets = parse_urls(original_with_url);
761
+
let escaped_facets = parse_urls(escaped_with_url);
762
+
763
+
// Both should find the URL, but at different byte positions
764
+
assert_eq!(original_facets.len(), 1);
765
+
assert_eq!(escaped_facets.len(), 1);
766
+
767
+
// The byte positions will be different
768
+
assert_eq!(original_facets[0].index.byte_start, 18); // After "Nov 3rd & Tuesday "
769
+
assert_eq!(escaped_facets[0].index.byte_start, 22); // After "Nov 3rd & Tuesday " (4 extra bytes for &)
770
+
}
771
+
772
+
#[test]
773
+
fn test_parse_urls_from_atproto_record_text() {
774
+
// Test parsing URLs from real AT Protocol record description text.
775
+
// This demonstrates the correct byte positions that should be used for facets.
776
+
let text = "Dev, Power Users, and Generally inquisitive folks get a completely unprofessionally amateur interview. Just a yap sesh where chat is part of the call!\n\nโจthe danielโจ & I will be on a Zoom call and I will stream out to https://stream.place/psingletary.com\n\nSubscribe to the publications! https://atprotocalls.leaflet.pub/";
777
+
778
+
let facets = parse_urls(text);
779
+
780
+
assert_eq!(facets.len(), 2, "Should find 2 URLs");
781
+
782
+
// First URL: https://stream.place/psingletary.com
783
+
assert_eq!(facets[0].index.byte_start, 221);
784
+
assert_eq!(facets[0].index.byte_end, 257);
785
+
if let Some(FacetFeature::Link(link)) = facets[0].features.first() {
786
+
assert_eq!(link.uri, "https://stream.place/psingletary.com");
787
+
}
788
+
789
+
// Second URL: https://atprotocalls.leaflet.pub/
790
+
assert_eq!(facets[1].index.byte_start, 290);
791
+
assert_eq!(facets[1].index.byte_end, 323);
792
+
if let Some(FacetFeature::Link(link)) = facets[1].features.first() {
793
+
assert_eq!(link.uri, "https://atprotocalls.leaflet.pub/");
794
+
}
795
+
796
+
// Verify the byte slices match the expected text
797
+
let text_bytes = text.as_bytes();
798
+
assert_eq!(
799
+
std::str::from_utf8(&text_bytes[221..257]).unwrap(),
800
+
"https://stream.place/psingletary.com"
801
+
);
802
+
assert_eq!(
803
+
std::str::from_utf8(&text_bytes[290..323]).unwrap(),
804
+
"https://atprotocalls.leaflet.pub/"
805
+
);
806
+
}
807
+
808
+
#[tokio::test]
809
+
async fn test_parse_mentions_basic() {
810
+
let resolver = MockIdentityResolver::new();
811
+
let limits = FacetLimits::default();
812
+
let text = "Hello @alice.bsky.social!";
813
+
let facets = parse_mentions(text, &resolver, &limits).await;
814
+
815
+
assert_eq!(facets.len(), 1);
816
+
assert_eq!(facets[0].index.byte_start, 6);
817
+
assert_eq!(facets[0].index.byte_end, 24);
818
+
if let Some(FacetFeature::Mention(mention)) = facets[0].features.first() {
819
+
assert_eq!(mention.did, "did:plc:alice123");
820
+
} else {
821
+
panic!("Expected Mention feature");
822
+
}
823
+
}
824
+
825
+
#[tokio::test]
826
+
async fn test_parse_mentions_multiple() {
827
+
let mut resolver = MockIdentityResolver::new();
828
+
resolver.add_identity("bob.example.com", "did:plc:bob456");
829
+
let limits = FacetLimits::default();
830
+
let text = "CC @alice.bsky.social and @bob.example.com";
831
+
let facets = parse_mentions(text, &resolver, &limits).await;
832
+
833
+
assert_eq!(facets.len(), 2);
834
+
if let Some(FacetFeature::Mention(mention)) = facets[0].features.first() {
835
+
assert_eq!(mention.did, "did:plc:alice123");
836
+
}
837
+
if let Some(FacetFeature::Mention(mention)) = facets[1].features.first() {
838
+
assert_eq!(mention.did, "did:plc:bob456");
839
+
}
840
+
}
841
+
842
+
#[tokio::test]
843
+
async fn test_parse_mentions_unresolvable() {
844
+
let resolver = MockIdentityResolver::new();
845
+
let limits = FacetLimits::default();
846
+
// unknown.handle.com is not in the resolver
847
+
let text = "Hello @unknown.handle.com!";
848
+
let facets = parse_mentions(text, &resolver, &limits).await;
849
+
850
+
// Should be empty since the handle can't be resolved
851
+
assert_eq!(facets.len(), 0);
852
+
}
853
+
854
+
#[tokio::test]
855
+
async fn test_parse_mentions_in_url_excluded() {
856
+
let resolver = MockIdentityResolver::new();
857
+
let limits = FacetLimits::default();
858
+
// The @smokesignal.events is inside a URL and should not be parsed as a mention
859
+
let text = "Check https://tangled.org/@smokesignal.events";
860
+
let facets = parse_mentions(text, &resolver, &limits).await;
861
+
862
+
// Should be empty since the mention is inside a URL
863
+
assert_eq!(facets.len(), 0);
864
+
}
865
+
866
+
#[test]
867
+
fn test_parse_tags_basic() {
868
+
let text = "Learning #rust today!";
869
+
let facets = parse_tags(text);
870
+
871
+
assert_eq!(facets.len(), 1);
872
+
assert_eq!(facets[0].index.byte_start, 9);
873
+
assert_eq!(facets[0].index.byte_end, 14);
874
+
if let Some(FacetFeature::Tag(tag)) = facets[0].features.first() {
875
+
assert_eq!(tag.tag, "rust");
876
+
} else {
877
+
panic!("Expected Tag feature");
878
+
}
879
+
}
880
+
881
+
#[test]
882
+
fn test_parse_tags_multiple() {
883
+
let text = "#rust #golang #python are great!";
884
+
let facets = parse_tags(text);
885
+
886
+
assert_eq!(facets.len(), 3);
887
+
if let Some(FacetFeature::Tag(tag)) = facets[0].features.first() {
888
+
assert_eq!(tag.tag, "rust");
889
+
}
890
+
if let Some(FacetFeature::Tag(tag)) = facets[1].features.first() {
891
+
assert_eq!(tag.tag, "golang");
892
+
}
893
+
if let Some(FacetFeature::Tag(tag)) = facets[2].features.first() {
894
+
assert_eq!(tag.tag, "python");
895
+
}
896
+
}
897
+
898
+
#[test]
899
+
fn test_parse_tags_excludes_numeric() {
900
+
let text = "Item #42 is special #test123";
901
+
let facets = parse_tags(text);
902
+
903
+
// #42 should be excluded (purely numeric), #test123 should be included
904
+
assert_eq!(facets.len(), 1);
905
+
if let Some(FacetFeature::Tag(tag)) = facets[0].features.first() {
906
+
assert_eq!(tag.tag, "test123");
907
+
}
908
+
}
909
+
910
+
#[test]
911
+
fn test_parse_urls_basic() {
912
+
let text = "Visit https://example.com today!";
913
+
let facets = parse_urls(text);
914
+
915
+
assert_eq!(facets.len(), 1);
916
+
assert_eq!(facets[0].index.byte_start, 6);
917
+
assert_eq!(facets[0].index.byte_end, 25);
918
+
if let Some(FacetFeature::Link(link)) = facets[0].features.first() {
919
+
assert_eq!(link.uri, "https://example.com");
920
+
}
921
+
}
922
+
923
+
#[test]
924
+
fn test_parse_urls_with_path() {
925
+
let text = "Check https://example.com/path/to/page?query=1#section";
926
+
let facets = parse_urls(text);
927
+
928
+
assert_eq!(facets.len(), 1);
929
+
if let Some(FacetFeature::Link(link)) = facets[0].features.first() {
930
+
assert_eq!(link.uri, "https://example.com/path/to/page?query=1#section");
931
+
}
932
+
}
933
+
934
+
#[test]
935
+
fn test_facet_limits_default() {
936
+
let limits = FacetLimits::default();
937
+
assert_eq!(limits.mentions_max, 5);
938
+
assert_eq!(limits.tags_max, 5);
939
+
assert_eq!(limits.links_max, 5);
940
+
assert_eq!(limits.max, 10);
941
+
}
942
+
}
+50
crates/atproto-extras/src/lib.rs
+50
crates/atproto-extras/src/lib.rs
···
1
+
//! Extra utilities for AT Protocol applications.
2
+
//!
3
+
//! This crate provides additional utilities that complement the core AT Protocol
4
+
//! identity and record crates. Currently, it focuses on rich text facet parsing.
5
+
//!
6
+
//! ## Features
7
+
//!
8
+
//! - **Facet Parsing**: Extract mentions, URLs, and hashtags from plain text
9
+
//! with correct UTF-8 byte offset calculation
10
+
//! - **Identity Integration**: Resolve mention handles to DIDs during parsing
11
+
//!
12
+
//! ## Example
13
+
//!
14
+
//! ```ignore
15
+
//! use atproto_extras::{parse_facets_from_text, FacetLimits};
16
+
//!
17
+
//! // Parse facets from text (requires an IdentityResolver)
18
+
//! let text = "Hello @alice.bsky.social! Check out https://example.com #rust";
19
+
//! let limits = FacetLimits::default();
20
+
//! let facets = parse_facets_from_text(text, &resolver, &limits).await;
21
+
//! ```
22
+
//!
23
+
//! ## Byte Offset Calculation
24
+
//!
25
+
//! This implementation correctly uses UTF-8 byte offsets as required by AT Protocol.
26
+
//! The facets use "inclusive start and exclusive end" byte ranges. All parsing is done
27
+
//! using `regex::bytes::Regex` which operates on byte slices and returns byte positions,
28
+
//! ensuring correct handling of multi-byte UTF-8 characters (emojis, CJK, accented chars).
29
+
30
+
#![forbid(unsafe_code)]
31
+
#![warn(missing_docs)]
32
+
33
+
/// Rich text facet parsing for AT Protocol.
34
+
///
35
+
/// This module provides functionality for extracting semantic annotations (facets)
36
+
/// from plain text. Facets include:
37
+
///
38
+
/// - **Mentions**: User handles prefixed with `@` (e.g., `@alice.bsky.social`)
39
+
/// - **Links**: HTTP/HTTPS URLs
40
+
/// - **Tags**: Hashtags prefixed with `#` or `๏ผ` (e.g., `#rust`)
41
+
///
42
+
/// ## Byte Offsets
43
+
///
44
+
/// All facet indices use UTF-8 byte offsets, not character indices. This is
45
+
/// critical for correct handling of multi-byte characters like emojis or
46
+
/// non-ASCII text.
47
+
pub mod facets;
48
+
49
+
/// Re-export commonly used types for convenience.
50
+
pub use facets::{FacetLimits, parse_facets_from_text, parse_mentions, parse_tags, parse_urls};
+19
-1
crates/atproto-identity/src/model.rs
+19
-1
crates/atproto-identity/src/model.rs
···
70
70
/// The DID identifier (e.g., "did:plc:abc123").
71
71
pub id: String,
72
72
/// Alternative identifiers like handles and domains.
73
+
#[serde(default)]
73
74
pub also_known_as: Vec<String>,
74
75
/// Available services for this identity.
76
+
#[serde(default)]
75
77
pub service: Vec<Service>,
76
78
77
79
/// Cryptographic verification methods.
78
-
#[serde(alias = "verificationMethod")]
80
+
#[serde(alias = "verificationMethod", default)]
79
81
pub verification_method: Vec<VerificationMethod>,
80
82
81
83
/// Additional document properties not explicitly defined.
···
402
404
let document = document.unwrap();
403
405
assert_eq!(document.id, "did:plc:cbkjy5n7bk3ax2wplmtjofq2");
404
406
}
407
+
}
408
+
409
+
#[test]
410
+
fn test_deserialize_service_did_document() {
411
+
// DID document from api.bsky.app - a service DID without alsoKnownAs
412
+
let document = serde_json::from_str::<Document>(
413
+
r##"{"@context":["https://www.w3.org/ns/did/v1","https://w3id.org/security/multikey/v1"],"id":"did:web:api.bsky.app","verificationMethod":[{"id":"did:web:api.bsky.app#atproto","type":"Multikey","controller":"did:web:api.bsky.app","publicKeyMultibase":"zQ3shpRzb2NDriwCSSsce6EqGxG23kVktHZc57C3NEcuNy1jg"}],"service":[{"id":"#bsky_notif","type":"BskyNotificationService","serviceEndpoint":"https://api.bsky.app"},{"id":"#bsky_appview","type":"BskyAppView","serviceEndpoint":"https://api.bsky.app"}]}"##,
414
+
);
415
+
assert!(document.is_ok(), "Failed to parse: {:?}", document.err());
416
+
417
+
let document = document.unwrap();
418
+
assert_eq!(document.id, "did:web:api.bsky.app");
419
+
assert!(document.also_known_as.is_empty());
420
+
assert_eq!(document.service.len(), 2);
421
+
assert_eq!(document.service[0].id, "#bsky_notif");
422
+
assert_eq!(document.service[1].id, "#bsky_appview");
405
423
}
406
424
}
+75
-24
crates/atproto-jetstream/src/consumer.rs
+75
-24
crates/atproto-jetstream/src/consumer.rs
···
2
2
//!
3
3
//! WebSocket event consumption with background processing and
4
4
//! customizable event handler dispatch.
5
+
//!
6
+
//! ## Memory Efficiency
7
+
//!
8
+
//! This module is optimized for high-throughput event processing with minimal allocations:
9
+
//!
10
+
//! - **Arc-based event sharing**: Events are wrapped in `Arc` and shared across all handlers,
11
+
//! avoiding expensive clones of event data structures.
12
+
//! - **Zero-copy handler IDs**: Handler identifiers use string slices to avoid allocations
13
+
//! during registration and dispatch.
14
+
//! - **Optimized query building**: WebSocket query strings are built with pre-allocated
15
+
//! capacity to minimize reallocations.
16
+
//!
17
+
//! ## Usage
18
+
//!
19
+
//! Implement the `EventHandler` trait to process events:
20
+
//!
21
+
//! ```rust
22
+
//! use atproto_jetstream::{EventHandler, JetstreamEvent};
23
+
//! use async_trait::async_trait;
24
+
//! use std::sync::Arc;
25
+
//! use anyhow::Result;
26
+
//!
27
+
//! struct MyHandler;
28
+
//!
29
+
//! #[async_trait]
30
+
//! impl EventHandler for MyHandler {
31
+
//! async fn handle_event(&self, event: Arc<JetstreamEvent>) -> Result<()> {
32
+
//! // Process event without cloning
33
+
//! Ok(())
34
+
//! }
35
+
//!
36
+
//! fn handler_id(&self) -> &str {
37
+
//! "my-handler"
38
+
//! }
39
+
//! }
40
+
//! ```
5
41
6
42
use crate::errors::ConsumerError;
7
43
use anyhow::Result;
···
133
169
#[async_trait]
134
170
pub trait EventHandler: Send + Sync {
135
171
/// Handle a received event
136
-
async fn handle_event(&self, event: JetstreamEvent) -> Result<()>;
172
+
///
173
+
/// Events are wrapped in Arc to enable efficient sharing across multiple handlers
174
+
/// without cloning the entire event data structure.
175
+
async fn handle_event(&self, event: Arc<JetstreamEvent>) -> Result<()>;
137
176
138
177
/// Get the handler's identifier
139
-
fn handler_id(&self) -> String;
178
+
///
179
+
/// Returns a string slice to avoid unnecessary allocations.
180
+
fn handler_id(&self) -> &str;
140
181
}
141
182
142
183
#[cfg_attr(debug_assertions, derive(Debug))]
···
167
208
pub struct Consumer {
168
209
config: ConsumerTaskConfig,
169
210
handlers: Arc<RwLock<HashMap<String, Arc<dyn EventHandler>>>>,
170
-
event_sender: Arc<RwLock<Option<broadcast::Sender<JetstreamEvent>>>>,
211
+
event_sender: Arc<RwLock<Option<broadcast::Sender<Arc<JetstreamEvent>>>>>,
171
212
}
172
213
173
214
impl Consumer {
···
185
226
let handler_id = handler.handler_id();
186
227
let mut handlers = self.handlers.write().await;
187
228
188
-
if handlers.contains_key(&handler_id) {
229
+
if handlers.contains_key(handler_id) {
189
230
return Err(ConsumerError::HandlerRegistrationFailed(format!(
190
231
"Handler with ID '{}' already registered",
191
232
handler_id
···
193
234
.into());
194
235
}
195
236
196
-
handlers.insert(handler_id.clone(), handler);
237
+
handlers.insert(handler_id.to_string(), handler);
197
238
Ok(())
198
239
}
199
240
···
205
246
}
206
247
207
248
/// Get a broadcast receiver for events
208
-
pub async fn get_event_receiver(&self) -> Result<broadcast::Receiver<JetstreamEvent>> {
249
+
///
250
+
/// Events are wrapped in Arc to enable efficient sharing without cloning.
251
+
pub async fn get_event_receiver(&self) -> Result<broadcast::Receiver<Arc<JetstreamEvent>>> {
209
252
let sender_guard = self.event_sender.read().await;
210
253
match sender_guard.as_ref() {
211
254
Some(sender) => Ok(sender.subscribe()),
···
249
292
tracing::info!("Starting Jetstream consumer");
250
293
251
294
// Build WebSocket URL with query parameters
252
-
let mut query_params = vec![];
295
+
// Pre-allocate capacity to avoid reallocations during string building
296
+
let capacity = 50 // Base parameters
297
+
+ self.config.collections.len() * 30 // Estimate per collection
298
+
+ self.config.dids.len() * 60; // Estimate per DID
299
+
let mut query_string = String::with_capacity(capacity);
253
300
254
301
// Add compression parameter
255
-
query_params.push(format!("compress={}", self.config.compression));
302
+
query_string.push_str("compress=");
303
+
query_string.push_str(if self.config.compression { "true" } else { "false" });
256
304
257
305
// Add requireHello parameter
258
-
query_params.push(format!("requireHello={}", self.config.require_hello));
306
+
query_string.push_str("&requireHello=");
307
+
query_string.push_str(if self.config.require_hello { "true" } else { "false" });
259
308
260
309
// Add wantedCollections if specified (each collection as a separate query parameter)
261
310
if !self.config.collections.is_empty() && !self.config.require_hello {
262
311
for collection in &self.config.collections {
263
-
query_params.push(format!(
264
-
"wantedCollections={}",
265
-
urlencoding::encode(collection)
266
-
));
312
+
query_string.push_str("&wantedCollections=");
313
+
query_string.push_str(&urlencoding::encode(collection));
267
314
}
268
315
}
269
316
270
317
// Add wantedDids if specified (each DID as a separate query parameter)
271
318
if !self.config.dids.is_empty() && !self.config.require_hello {
272
319
for did in &self.config.dids {
273
-
query_params.push(format!("wantedDids={}", urlencoding::encode(did)));
320
+
query_string.push_str("&wantedDids=");
321
+
query_string.push_str(&urlencoding::encode(did));
274
322
}
275
323
}
276
324
277
325
// Add maxMessageSizeBytes if specified
278
326
if let Some(max_size) = self.config.max_message_size_bytes {
279
-
query_params.push(format!("maxMessageSizeBytes={}", max_size));
327
+
use std::fmt::Write;
328
+
write!(&mut query_string, "&maxMessageSizeBytes={}", max_size).unwrap();
280
329
}
281
330
282
331
// Add cursor if specified
283
332
if let Some(cursor) = self.config.cursor {
284
-
query_params.push(format!("cursor={}", cursor));
333
+
use std::fmt::Write;
334
+
write!(&mut query_string, "&cursor={}", cursor).unwrap();
285
335
}
286
-
287
-
let query_string = query_params.join("&");
288
336
let ws_url = Uri::from_str(&format!(
289
337
"wss://{}/subscribe?{}",
290
338
self.config.jetstream_hostname, query_string
···
335
383
break;
336
384
},
337
385
() = &mut sleeper => {
338
-
// consumer_control_insert(&self.pool, &self.config.jetstream_hostname, time_usec).await?;
339
-
340
386
sleeper.as_mut().reset(Instant::now() + interval);
341
387
},
342
388
item = client.next() => {
···
404
450
}
405
451
406
452
/// Dispatch event to all registered handlers
453
+
///
454
+
/// Wraps the event in Arc once and shares it across all handlers,
455
+
/// avoiding expensive clones of the event data structure.
407
456
async fn dispatch_to_handlers(&self, event: JetstreamEvent) -> Result<()> {
408
457
let handlers = self.handlers.read().await;
458
+
let event = Arc::new(event);
409
459
410
460
for (handler_id, handler) in handlers.iter() {
411
461
let handler_span = tracing::debug_span!("handler_dispatch", handler_id = %handler_id);
462
+
let event_ref = Arc::clone(&event);
412
463
async {
413
-
if let Err(err) = handler.handle_event(event.clone()).await {
464
+
if let Err(err) = handler.handle_event(event_ref).await {
414
465
tracing::error!(
415
466
error = ?err,
416
467
handler_id = %handler_id,
···
440
491
441
492
#[async_trait]
442
493
impl EventHandler for LoggingHandler {
443
-
async fn handle_event(&self, _event: JetstreamEvent) -> Result<()> {
494
+
async fn handle_event(&self, _event: Arc<JetstreamEvent>) -> Result<()> {
444
495
Ok(())
445
496
}
446
497
447
-
fn handler_id(&self) -> String {
448
-
self.id.clone()
498
+
fn handler_id(&self) -> &str {
499
+
&self.id
449
500
}
450
501
}
451
502
+374
-5
crates/atproto-oauth/src/scopes.rs
+374
-5
crates/atproto-oauth/src/scopes.rs
···
38
38
Atproto,
39
39
/// Transition scope for migration operations
40
40
Transition(TransitionScope),
41
+
/// Include scope for referencing permission sets by NSID
42
+
Include(IncludeScope),
41
43
/// OpenID Connect scope - required for OpenID Connect authentication
42
44
OpenId,
43
45
/// Profile scope - access to user profile information
···
91
93
Generic,
92
94
/// Email transition operations
93
95
Email,
96
+
}
97
+
98
+
/// Include scope for referencing permission sets by NSID
99
+
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
100
+
pub struct IncludeScope {
101
+
/// The permission set NSID (e.g., "app.example.authFull")
102
+
pub nsid: String,
103
+
/// Optional audience DID for inherited RPC permissions
104
+
pub aud: Option<String>,
94
105
}
95
106
96
107
/// Blob scope with mime type constraints
···
310
321
"rpc",
311
322
"atproto",
312
323
"transition",
324
+
"include",
313
325
"openid",
314
326
"profile",
315
327
"email",
···
349
361
"rpc" => Self::parse_rpc(suffix),
350
362
"atproto" => Self::parse_atproto(suffix),
351
363
"transition" => Self::parse_transition(suffix),
364
+
"include" => Self::parse_include(suffix),
352
365
"openid" => Self::parse_openid(suffix),
353
366
"profile" => Self::parse_profile(suffix),
354
367
"email" => Self::parse_email(suffix),
···
573
586
Ok(Scope::Transition(scope))
574
587
}
575
588
589
+
fn parse_include(suffix: Option<&str>) -> Result<Self, ParseError> {
590
+
let (nsid, params) = match suffix {
591
+
Some(s) => {
592
+
if let Some(pos) = s.find('?') {
593
+
(&s[..pos], Some(&s[pos + 1..]))
594
+
} else {
595
+
(s, None)
596
+
}
597
+
}
598
+
None => return Err(ParseError::MissingResource),
599
+
};
600
+
601
+
if nsid.is_empty() {
602
+
return Err(ParseError::MissingResource);
603
+
}
604
+
605
+
let aud = if let Some(params) = params {
606
+
let parsed_params = parse_query_string(params);
607
+
parsed_params
608
+
.get("aud")
609
+
.and_then(|v| v.first())
610
+
.map(|s| url_decode(s))
611
+
} else {
612
+
None
613
+
};
614
+
615
+
Ok(Scope::Include(IncludeScope {
616
+
nsid: nsid.to_string(),
617
+
aud,
618
+
}))
619
+
}
620
+
576
621
fn parse_openid(suffix: Option<&str>) -> Result<Self, ParseError> {
577
622
if suffix.is_some() {
578
623
return Err(ParseError::InvalidResource(
···
677
722
if let Some(lxm) = scope.lxm.iter().next() {
678
723
match lxm {
679
724
RpcLexicon::All => "rpc:*".to_string(),
680
-
RpcLexicon::Nsid(nsid) => format!("rpc:{}", nsid),
725
+
RpcLexicon::Nsid(nsid) => format!("rpc:{}?aud=*", nsid),
726
+
}
727
+
} else {
728
+
"rpc:*".to_string()
729
+
}
730
+
} else if scope.lxm.len() == 1 && scope.aud.len() == 1 {
731
+
// Single lxm and single aud (aud is not All, handled above)
732
+
if let (Some(lxm), Some(aud)) =
733
+
(scope.lxm.iter().next(), scope.aud.iter().next())
734
+
{
735
+
match (lxm, aud) {
736
+
(RpcLexicon::Nsid(nsid), RpcAudience::Did(did)) => {
737
+
format!("rpc:{}?aud={}", nsid, did)
738
+
}
739
+
(RpcLexicon::All, RpcAudience::Did(did)) => {
740
+
format!("rpc:*?aud={}", did)
741
+
}
742
+
_ => "rpc:*".to_string(),
681
743
}
682
744
} else {
683
745
"rpc:*".to_string()
···
713
775
TransitionScope::Generic => "transition:generic".to_string(),
714
776
TransitionScope::Email => "transition:email".to_string(),
715
777
},
778
+
Scope::Include(scope) => {
779
+
if let Some(ref aud) = scope.aud {
780
+
format!("include:{}?aud={}", scope.nsid, url_encode(aud))
781
+
} else {
782
+
format!("include:{}", scope.nsid)
783
+
}
784
+
}
716
785
Scope::OpenId => "openid".to_string(),
717
786
Scope::Profile => "profile".to_string(),
718
787
Scope::Email => "email".to_string(),
···
732
801
// Other scopes don't grant transition scopes
733
802
(_, Scope::Transition(_)) => false,
734
803
(Scope::Transition(_), _) => false,
804
+
// Include scopes only grant themselves (exact match including aud)
805
+
(Scope::Include(a), Scope::Include(b)) => a == b,
806
+
// Other scopes don't grant include scopes
807
+
(_, Scope::Include(_)) => false,
808
+
(Scope::Include(_), _) => false,
735
809
// OpenID Connect scopes only grant themselves
736
810
(Scope::OpenId, Scope::OpenId) => true,
737
811
(Scope::OpenId, _) => false,
···
873
947
params
874
948
}
875
949
950
+
/// Decode a percent-encoded string
951
+
fn url_decode(s: &str) -> String {
952
+
let mut result = String::with_capacity(s.len());
953
+
let mut chars = s.chars().peekable();
954
+
955
+
while let Some(c) = chars.next() {
956
+
if c == '%' {
957
+
let hex: String = chars.by_ref().take(2).collect();
958
+
if hex.len() == 2 {
959
+
if let Ok(byte) = u8::from_str_radix(&hex, 16) {
960
+
result.push(byte as char);
961
+
continue;
962
+
}
963
+
}
964
+
result.push('%');
965
+
result.push_str(&hex);
966
+
} else {
967
+
result.push(c);
968
+
}
969
+
}
970
+
971
+
result
972
+
}
973
+
974
+
/// Encode a string for use in a URL query parameter
975
+
fn url_encode(s: &str) -> String {
976
+
let mut result = String::with_capacity(s.len() * 3);
977
+
978
+
for c in s.chars() {
979
+
match c {
980
+
'A'..='Z' | 'a'..='z' | '0'..='9' | '-' | '_' | '.' | '~' | ':' => {
981
+
result.push(c);
982
+
}
983
+
_ => {
984
+
for byte in c.to_string().as_bytes() {
985
+
result.push_str(&format!("%{:02X}", byte));
986
+
}
987
+
}
988
+
}
989
+
}
990
+
991
+
result
992
+
}
993
+
876
994
/// Error type for scope parsing
877
995
#[derive(Debug, Clone, PartialEq, Eq)]
878
996
pub enum ParseError {
···
1056
1174
("repo:foo.bar", "repo:foo.bar"),
1057
1175
("repo:foo.bar?action=create", "repo:foo.bar?action=create"),
1058
1176
("rpc:*", "rpc:*"),
1177
+
("rpc:com.example.service", "rpc:com.example.service?aud=*"),
1178
+
(
1179
+
"rpc:com.example.service?aud=did:example:123",
1180
+
"rpc:com.example.service?aud=did:example:123",
1181
+
),
1059
1182
];
1060
1183
1061
1184
for (input, expected) in tests {
···
1677
1800
1678
1801
// Test with complex scopes including query parameters
1679
1802
let scopes = vec![
1680
-
Scope::parse("rpc:com.example.service?aud=did:example:123&lxm=com.example.method")
1681
-
.unwrap(),
1803
+
Scope::parse("rpc:com.example.service?aud=did:example:123").unwrap(),
1682
1804
Scope::parse("repo:foo.bar?action=create&action=update").unwrap(),
1683
1805
Scope::parse("blob:image/*?accept=image/png&accept=image/jpeg").unwrap(),
1684
1806
];
1685
1807
let result = Scope::serialize_multiple(&scopes);
1686
1808
// The result should be sorted alphabetically
1687
-
// Note: RPC scope with query params is serialized as "rpc?aud=...&lxm=..."
1809
+
// Single lxm + single aud is serialized as "rpc:[lxm]?aud=[aud]"
1688
1810
assert!(result.starts_with("blob:"));
1689
1811
assert!(result.contains(" repo:"));
1690
-
assert!(result.contains("rpc?aud=did:example:123&lxm=com.example.service"));
1812
+
assert!(result.contains("rpc:com.example.service?aud=did:example:123"));
1691
1813
1692
1814
// Test with transition scopes
1693
1815
let scopes = vec![
···
1835
1957
assert!(!result.contains(&Scope::parse("account:email").unwrap()));
1836
1958
assert!(result.contains(&Scope::parse("account:email?action=manage").unwrap()));
1837
1959
assert!(result.contains(&Scope::parse("account:repo").unwrap()));
1960
+
}
1961
+
1962
+
#[test]
1963
+
fn test_repo_nsid_with_wildcard_suffix() {
1964
+
// Test parsing "repo:app.bsky.feed.*" - the asterisk is treated as a literal part of the NSID,
1965
+
// not as a wildcard pattern. Only "repo:*" has special wildcard behavior for ALL collections.
1966
+
let scope = Scope::parse("repo:app.bsky.feed.*").unwrap();
1967
+
1968
+
// Verify it parses as a specific NSID, not as a wildcard
1969
+
assert_eq!(
1970
+
scope,
1971
+
Scope::Repo(RepoScope {
1972
+
collection: RepoCollection::Nsid("app.bsky.feed.*".to_string()),
1973
+
actions: {
1974
+
let mut actions = BTreeSet::new();
1975
+
actions.insert(RepoAction::Create);
1976
+
actions.insert(RepoAction::Update);
1977
+
actions.insert(RepoAction::Delete);
1978
+
actions
1979
+
}
1980
+
})
1981
+
);
1982
+
1983
+
// Verify normalization preserves the literal NSID
1984
+
assert_eq!(scope.to_string_normalized(), "repo:app.bsky.feed.*");
1985
+
1986
+
// Test that it does NOT grant access to "app.bsky.feed.post"
1987
+
// (because "app.bsky.feed.*" is a literal NSID, not a pattern)
1988
+
let specific_feed = Scope::parse("repo:app.bsky.feed.post").unwrap();
1989
+
assert!(!scope.grants(&specific_feed));
1990
+
1991
+
// Test that only "repo:*" grants access to "app.bsky.feed.*"
1992
+
let repo_all = Scope::parse("repo:*").unwrap();
1993
+
assert!(repo_all.grants(&scope));
1994
+
1995
+
// Test that "repo:app.bsky.feed.*" only grants itself
1996
+
assert!(scope.grants(&scope));
1997
+
1998
+
// Test with actions
1999
+
let scope_with_create = Scope::parse("repo:app.bsky.feed.*?action=create").unwrap();
2000
+
assert_eq!(
2001
+
scope_with_create,
2002
+
Scope::Repo(RepoScope {
2003
+
collection: RepoCollection::Nsid("app.bsky.feed.*".to_string()),
2004
+
actions: {
2005
+
let mut actions = BTreeSet::new();
2006
+
actions.insert(RepoAction::Create);
2007
+
actions
2008
+
}
2009
+
})
2010
+
);
2011
+
2012
+
// The full scope (with all actions) grants the create-only scope
2013
+
assert!(scope.grants(&scope_with_create));
2014
+
// But the create-only scope does NOT grant the full scope
2015
+
assert!(!scope_with_create.grants(&scope));
2016
+
2017
+
// Test parsing multiple scopes with NSID wildcards
2018
+
let scopes = Scope::parse_multiple("repo:app.bsky.feed.* repo:app.bsky.graph.* repo:*").unwrap();
2019
+
assert_eq!(scopes.len(), 3);
2020
+
2021
+
// Test that parse_multiple_reduced properly reduces when "repo:*" is present
2022
+
let reduced = Scope::parse_multiple_reduced("repo:app.bsky.feed.* repo:app.bsky.graph.* repo:*").unwrap();
2023
+
assert_eq!(reduced.len(), 1);
2024
+
assert_eq!(reduced[0], repo_all);
2025
+
}
2026
+
2027
+
#[test]
2028
+
fn test_include_scope_parsing() {
2029
+
// Test basic include scope
2030
+
let scope = Scope::parse("include:app.example.authFull").unwrap();
2031
+
assert_eq!(
2032
+
scope,
2033
+
Scope::Include(IncludeScope {
2034
+
nsid: "app.example.authFull".to_string(),
2035
+
aud: None,
2036
+
})
2037
+
);
2038
+
2039
+
// Test include scope with audience
2040
+
let scope = Scope::parse("include:app.example.authFull?aud=did:web:api.example.com").unwrap();
2041
+
assert_eq!(
2042
+
scope,
2043
+
Scope::Include(IncludeScope {
2044
+
nsid: "app.example.authFull".to_string(),
2045
+
aud: Some("did:web:api.example.com".to_string()),
2046
+
})
2047
+
);
2048
+
2049
+
// Test include scope with URL-encoded audience (with fragment)
2050
+
let scope = Scope::parse("include:app.example.authFull?aud=did:web:api.example.com%23svc_chat").unwrap();
2051
+
assert_eq!(
2052
+
scope,
2053
+
Scope::Include(IncludeScope {
2054
+
nsid: "app.example.authFull".to_string(),
2055
+
aud: Some("did:web:api.example.com#svc_chat".to_string()),
2056
+
})
2057
+
);
2058
+
2059
+
// Test missing NSID
2060
+
assert!(matches!(
2061
+
Scope::parse("include"),
2062
+
Err(ParseError::MissingResource)
2063
+
));
2064
+
2065
+
// Test empty NSID with query params
2066
+
assert!(matches!(
2067
+
Scope::parse("include:?aud=did:example:123"),
2068
+
Err(ParseError::MissingResource)
2069
+
));
2070
+
}
2071
+
2072
+
#[test]
2073
+
fn test_include_scope_normalization() {
2074
+
// Test normalization without audience
2075
+
let scope = Scope::parse("include:com.example.authBasic").unwrap();
2076
+
assert_eq!(scope.to_string_normalized(), "include:com.example.authBasic");
2077
+
2078
+
// Test normalization with audience (no special chars)
2079
+
let scope = Scope::parse("include:com.example.authBasic?aud=did:plc:xyz123").unwrap();
2080
+
assert_eq!(
2081
+
scope.to_string_normalized(),
2082
+
"include:com.example.authBasic?aud=did:plc:xyz123"
2083
+
);
2084
+
2085
+
// Test normalization with URL encoding (fragment needs encoding)
2086
+
let scope = Scope::parse("include:app.example.authFull?aud=did:web:api.example.com%23svc_chat").unwrap();
2087
+
let normalized = scope.to_string_normalized();
2088
+
assert_eq!(
2089
+
normalized,
2090
+
"include:app.example.authFull?aud=did:web:api.example.com%23svc_chat"
2091
+
);
2092
+
}
2093
+
2094
+
#[test]
2095
+
fn test_include_scope_grants() {
2096
+
let include1 = Scope::parse("include:app.example.authFull").unwrap();
2097
+
let include2 = Scope::parse("include:app.example.authBasic").unwrap();
2098
+
let include1_with_aud = Scope::parse("include:app.example.authFull?aud=did:plc:xyz").unwrap();
2099
+
let account = Scope::parse("account:email").unwrap();
2100
+
2101
+
// Include scopes only grant themselves (exact match)
2102
+
assert!(include1.grants(&include1));
2103
+
assert!(!include1.grants(&include2));
2104
+
assert!(!include1.grants(&include1_with_aud)); // Different because aud differs
2105
+
assert!(include1_with_aud.grants(&include1_with_aud));
2106
+
2107
+
// Include scopes don't grant other scope types
2108
+
assert!(!include1.grants(&account));
2109
+
assert!(!account.grants(&include1));
2110
+
2111
+
// Include scopes don't grant atproto or transition
2112
+
let atproto = Scope::parse("atproto").unwrap();
2113
+
let transition = Scope::parse("transition:generic").unwrap();
2114
+
assert!(!include1.grants(&atproto));
2115
+
assert!(!include1.grants(&transition));
2116
+
assert!(!atproto.grants(&include1));
2117
+
assert!(!transition.grants(&include1));
2118
+
}
2119
+
2120
+
#[test]
2121
+
fn test_parse_multiple_with_include() {
2122
+
let scopes = Scope::parse_multiple("atproto include:app.example.auth repo:*").unwrap();
2123
+
assert_eq!(scopes.len(), 3);
2124
+
assert_eq!(scopes[0], Scope::Atproto);
2125
+
assert!(matches!(scopes[1], Scope::Include(_)));
2126
+
assert!(matches!(scopes[2], Scope::Repo(_)));
2127
+
2128
+
// Test with URL-encoded audience
2129
+
let scopes = Scope::parse_multiple(
2130
+
"include:app.example.auth?aud=did:web:api.example.com%23svc account:email"
2131
+
).unwrap();
2132
+
assert_eq!(scopes.len(), 2);
2133
+
if let Scope::Include(inc) = &scopes[0] {
2134
+
assert_eq!(inc.nsid, "app.example.auth");
2135
+
assert_eq!(inc.aud, Some("did:web:api.example.com#svc".to_string()));
2136
+
} else {
2137
+
panic!("Expected Include scope");
2138
+
}
2139
+
}
2140
+
2141
+
#[test]
2142
+
fn test_parse_multiple_reduced_with_include() {
2143
+
// Include scopes don't reduce each other (each is distinct)
2144
+
let scopes = Scope::parse_multiple_reduced(
2145
+
"include:app.example.auth include:app.example.other include:app.example.auth"
2146
+
).unwrap();
2147
+
assert_eq!(scopes.len(), 2); // Duplicates are removed
2148
+
assert!(scopes.contains(&Scope::Include(IncludeScope {
2149
+
nsid: "app.example.auth".to_string(),
2150
+
aud: None,
2151
+
})));
2152
+
assert!(scopes.contains(&Scope::Include(IncludeScope {
2153
+
nsid: "app.example.other".to_string(),
2154
+
aud: None,
2155
+
})));
2156
+
2157
+
// Include scopes with different audiences are not duplicates
2158
+
let scopes = Scope::parse_multiple_reduced(
2159
+
"include:app.example.auth include:app.example.auth?aud=did:plc:xyz"
2160
+
).unwrap();
2161
+
assert_eq!(scopes.len(), 2);
2162
+
}
2163
+
2164
+
#[test]
2165
+
fn test_serialize_multiple_with_include() {
2166
+
let scopes = vec![
2167
+
Scope::parse("repo:*").unwrap(),
2168
+
Scope::parse("include:app.example.authFull").unwrap(),
2169
+
Scope::Atproto,
2170
+
];
2171
+
let result = Scope::serialize_multiple(&scopes);
2172
+
assert_eq!(result, "atproto include:app.example.authFull repo:*");
2173
+
2174
+
// Test with URL-encoded audience
2175
+
let scopes = vec![
2176
+
Scope::Include(IncludeScope {
2177
+
nsid: "app.example.auth".to_string(),
2178
+
aud: Some("did:web:api.example.com#svc".to_string()),
2179
+
}),
2180
+
];
2181
+
let result = Scope::serialize_multiple(&scopes);
2182
+
assert_eq!(result, "include:app.example.auth?aud=did:web:api.example.com%23svc");
2183
+
}
2184
+
2185
+
#[test]
2186
+
fn test_remove_scope_with_include() {
2187
+
let scopes = vec![
2188
+
Scope::Atproto,
2189
+
Scope::parse("include:app.example.auth").unwrap(),
2190
+
Scope::parse("account:email").unwrap(),
2191
+
];
2192
+
let to_remove = Scope::parse("include:app.example.auth").unwrap();
2193
+
let result = Scope::remove_scope(&scopes, &to_remove);
2194
+
assert_eq!(result.len(), 2);
2195
+
assert!(!result.contains(&to_remove));
2196
+
assert!(result.contains(&Scope::Atproto));
2197
+
}
2198
+
2199
+
#[test]
2200
+
fn test_include_scope_roundtrip() {
2201
+
// Test that parse and serialize are inverses
2202
+
let original = "include:com.example.authBasicFeatures?aud=did:web:api.example.com%23svc_appview";
2203
+
let scope = Scope::parse(original).unwrap();
2204
+
let serialized = scope.to_string_normalized();
2205
+
let reparsed = Scope::parse(&serialized).unwrap();
2206
+
assert_eq!(scope, reparsed);
1838
2207
}
1839
2208
}
+53
crates/atproto-tap/Cargo.toml
+53
crates/atproto-tap/Cargo.toml
···
1
+
[package]
2
+
name = "atproto-tap"
3
+
version = "0.13.0"
4
+
description = "AT Protocol TAP (Trusted Attestation Protocol) service consumer"
5
+
readme = "README.md"
6
+
homepage = "https://tangled.sh/@smokesignal.events/atproto-identity-rs"
7
+
documentation = "https://docs.rs/atproto-tap"
8
+
9
+
edition.workspace = true
10
+
rust-version.workspace = true
11
+
authors.workspace = true
12
+
repository.workspace = true
13
+
license.workspace = true
14
+
keywords.workspace = true
15
+
categories.workspace = true
16
+
17
+
[dependencies]
18
+
tokio = { workspace = true, features = ["sync", "time"] }
19
+
tokio-stream = "0.1"
20
+
tokio-websockets = { workspace = true }
21
+
futures = { workspace = true }
22
+
reqwest = { workspace = true }
23
+
serde = { workspace = true }
24
+
serde_json = { workspace = true }
25
+
thiserror = { workspace = true }
26
+
tracing = { workspace = true }
27
+
http = { workspace = true }
28
+
base64 = { workspace = true }
29
+
atproto-identity.workspace = true
30
+
atproto-client = { workspace = true, optional = true }
31
+
32
+
# Memory efficiency
33
+
compact_str = { version = "0.8", features = ["serde"] }
34
+
itoa = "1.0"
35
+
36
+
# Optional for CLI
37
+
clap = { workspace = true, optional = true }
38
+
tracing-subscriber = { version = "0.3", features = ["env-filter"], optional = true }
39
+
40
+
[features]
41
+
default = []
42
+
clap = ["dep:clap", "dep:tracing-subscriber", "dep:atproto-client", "tokio/rt-multi-thread", "tokio/macros", "tokio/signal"]
43
+
44
+
[[bin]]
45
+
name = "atproto-tap-client"
46
+
required-features = ["clap"]
47
+
48
+
[[bin]]
49
+
name = "atproto-tap-extras"
50
+
required-features = ["clap"]
51
+
52
+
[lints]
53
+
workspace = true
+351
crates/atproto-tap/src/bin/atproto-tap-client.rs
+351
crates/atproto-tap/src/bin/atproto-tap-client.rs
···
1
+
//! Command-line client for TAP services.
2
+
//!
3
+
//! This tool provides commands for consuming TAP events and managing tracked repositories.
4
+
//!
5
+
//! # Usage
6
+
//!
7
+
//! ```bash
8
+
//! # Stream events from a TAP service
9
+
//! cargo run --features cli --bin atproto-tap-client -- localhost:2480 read
10
+
//!
11
+
//! # Stream with authentication and filters
12
+
//! cargo run --features cli --bin atproto-tap-client -- localhost:2480 -p secret read --live-only
13
+
//!
14
+
//! # Add repositories to track
15
+
//! cargo run --features cli --bin atproto-tap-client -- localhost:2480 -p secret repos add did:plc:xyz did:plc:abc
16
+
//!
17
+
//! # Remove repositories from tracking
18
+
//! cargo run --features cli --bin atproto-tap-client -- localhost:2480 -p secret repos remove did:plc:xyz
19
+
//!
20
+
//! # Resolve a DID to its DID document
21
+
//! cargo run --features cli --bin atproto-tap-client -- localhost:2480 resolve did:plc:xyz
22
+
//!
23
+
//! # Resolve a DID and only output the handle
24
+
//! cargo run --features cli --bin atproto-tap-client -- localhost:2480 resolve did:plc:xyz --handle-only
25
+
//!
26
+
//! # Get repository tracking info
27
+
//! cargo run --features cli --bin atproto-tap-client -- localhost:2480 info did:plc:xyz
28
+
//! ```
29
+
30
+
use atproto_tap::{TapClient, TapConfig, TapEvent, connect};
31
+
use clap::{Parser, Subcommand};
32
+
use std::time::Duration;
33
+
use tokio_stream::StreamExt;
34
+
35
+
/// TAP service client for consuming events and managing repositories.
36
+
#[derive(Parser)]
37
+
#[command(
38
+
name = "atproto-tap-client",
39
+
version,
40
+
about = "TAP service client for AT Protocol",
41
+
long_about = "Connect to a TAP service to stream repository/identity events or manage tracked repositories.\n\n\
42
+
Events are printed to stdout as JSON, one per line.\n\
43
+
Use Ctrl+C to gracefully stop the consumer."
44
+
)]
45
+
struct Args {
46
+
/// TAP service hostname (e.g., localhost:2480)
47
+
hostname: String,
48
+
49
+
/// Admin password for authentication
50
+
#[arg(short, long, global = true)]
51
+
password: Option<String>,
52
+
53
+
#[command(subcommand)]
54
+
command: Command,
55
+
}
56
+
57
+
#[derive(Subcommand)]
58
+
enum Command {
59
+
/// Connect to TAP and stream events as JSON
60
+
Read {
61
+
/// Disable acknowledgments
62
+
#[arg(long)]
63
+
no_acks: bool,
64
+
65
+
/// Maximum reconnection attempts (0 = unlimited)
66
+
#[arg(long, default_value = "0")]
67
+
max_reconnects: u32,
68
+
69
+
/// Print debug information to stderr
70
+
#[arg(short, long)]
71
+
debug: bool,
72
+
73
+
/// Filter to specific collections (comma-separated)
74
+
#[arg(long)]
75
+
collections: Option<String>,
76
+
77
+
/// Only show live events (skip backfill)
78
+
#[arg(long)]
79
+
live_only: bool,
80
+
},
81
+
82
+
/// Manage tracked repositories
83
+
Repos {
84
+
#[command(subcommand)]
85
+
action: ReposAction,
86
+
},
87
+
88
+
/// Resolve a DID to its DID document
89
+
Resolve {
90
+
/// DID to resolve (e.g., did:plc:xyz123)
91
+
did: String,
92
+
93
+
/// Only output the handle (instead of full DID document)
94
+
#[arg(long)]
95
+
handle_only: bool,
96
+
},
97
+
98
+
/// Get tracking info for a repository
99
+
Info {
100
+
/// DID to get info for (e.g., did:plc:xyz123)
101
+
did: String,
102
+
},
103
+
}
104
+
105
+
#[derive(Subcommand)]
106
+
enum ReposAction {
107
+
/// Add repositories to track
108
+
Add {
109
+
/// DIDs to add (e.g., did:plc:xyz123)
110
+
#[arg(required = true)]
111
+
dids: Vec<String>,
112
+
},
113
+
114
+
/// Remove repositories from tracking
115
+
Remove {
116
+
/// DIDs to remove
117
+
#[arg(required = true)]
118
+
dids: Vec<String>,
119
+
},
120
+
}
121
+
122
+
#[tokio::main]
123
+
async fn main() {
124
+
let args = Args::parse();
125
+
126
+
match args.command {
127
+
Command::Read {
128
+
no_acks,
129
+
max_reconnects,
130
+
debug,
131
+
collections,
132
+
live_only,
133
+
} => {
134
+
run_read(
135
+
&args.hostname,
136
+
args.password,
137
+
no_acks,
138
+
max_reconnects,
139
+
debug,
140
+
collections,
141
+
live_only,
142
+
)
143
+
.await;
144
+
}
145
+
Command::Repos { action } => {
146
+
run_repos(&args.hostname, args.password, action).await;
147
+
}
148
+
Command::Resolve { did, handle_only } => {
149
+
run_resolve(&args.hostname, args.password, &did, handle_only).await;
150
+
}
151
+
Command::Info { did } => {
152
+
run_info(&args.hostname, args.password, &did).await;
153
+
}
154
+
}
155
+
}
156
+
157
+
async fn run_read(
158
+
hostname: &str,
159
+
password: Option<String>,
160
+
no_acks: bool,
161
+
max_reconnects: u32,
162
+
debug: bool,
163
+
collections: Option<String>,
164
+
live_only: bool,
165
+
) {
166
+
// Initialize tracing if debug mode
167
+
if debug {
168
+
tracing_subscriber::fmt()
169
+
.with_env_filter("atproto_tap=debug")
170
+
.with_writer(std::io::stderr)
171
+
.init();
172
+
}
173
+
174
+
// Build configuration
175
+
let mut config_builder = TapConfig::builder()
176
+
.hostname(hostname)
177
+
.send_acks(!no_acks);
178
+
179
+
if let Some(password) = password {
180
+
config_builder = config_builder.admin_password(password);
181
+
}
182
+
183
+
if max_reconnects > 0 {
184
+
config_builder = config_builder.max_reconnect_attempts(Some(max_reconnects));
185
+
}
186
+
187
+
// Set reasonable defaults for CLI usage
188
+
config_builder = config_builder
189
+
.initial_reconnect_delay(Duration::from_secs(1))
190
+
.max_reconnect_delay(Duration::from_secs(30));
191
+
192
+
let config = config_builder.build();
193
+
194
+
eprintln!("Connecting to TAP service at {}...", hostname);
195
+
196
+
let mut stream = connect(config);
197
+
198
+
// Parse collection filters
199
+
let collection_filters: Vec<String> = collections
200
+
.map(|c| c.split(',').map(|s| s.trim().to_string()).collect())
201
+
.unwrap_or_default();
202
+
203
+
// Handle Ctrl+C
204
+
let ctrl_c = tokio::signal::ctrl_c();
205
+
tokio::pin!(ctrl_c);
206
+
207
+
loop {
208
+
tokio::select! {
209
+
Some(result) = stream.next() => {
210
+
match result {
211
+
Ok(event) => {
212
+
// Apply filters
213
+
let should_print = match event.as_ref() {
214
+
TapEvent::Record { record, .. } => {
215
+
// Filter by live flag
216
+
if live_only && !record.live {
217
+
false
218
+
}
219
+
// Filter by collection
220
+
else if !collection_filters.is_empty() {
221
+
collection_filters.iter().any(|c| record.collection.as_ref() == c)
222
+
} else {
223
+
true
224
+
}
225
+
}
226
+
TapEvent::Identity { .. } => !live_only, // Always show identity unless live_only
227
+
};
228
+
229
+
if should_print {
230
+
// Print as JSON to stdout
231
+
match serde_json::to_string(event.as_ref()) {
232
+
Ok(json) => println!("{}", json),
233
+
Err(e) => {
234
+
eprintln!("Failed to serialize event: {}", e);
235
+
}
236
+
}
237
+
}
238
+
}
239
+
Err(e) => {
240
+
eprintln!("Error: {}", e);
241
+
242
+
// Exit on fatal errors
243
+
if e.is_fatal() {
244
+
eprintln!("Fatal error, exiting");
245
+
std::process::exit(1);
246
+
}
247
+
}
248
+
}
249
+
}
250
+
_ = &mut ctrl_c => {
251
+
eprintln!("\nReceived Ctrl+C, shutting down...");
252
+
stream.close().await;
253
+
break;
254
+
}
255
+
}
256
+
}
257
+
258
+
eprintln!("Client stopped");
259
+
}
260
+
261
+
async fn run_repos(hostname: &str, password: Option<String>, action: ReposAction) {
262
+
let client = TapClient::new(hostname, password);
263
+
264
+
match action {
265
+
ReposAction::Add { dids } => {
266
+
let did_refs: Vec<&str> = dids.iter().map(|s| s.as_str()).collect();
267
+
268
+
match client.add_repos(&did_refs).await {
269
+
Ok(()) => {
270
+
eprintln!("Added {} repository(ies) to tracking", dids.len());
271
+
for did in &dids {
272
+
println!("{}", did);
273
+
}
274
+
}
275
+
Err(e) => {
276
+
eprintln!("Failed to add repositories: {}", e);
277
+
std::process::exit(1);
278
+
}
279
+
}
280
+
}
281
+
ReposAction::Remove { dids } => {
282
+
let did_refs: Vec<&str> = dids.iter().map(|s| s.as_str()).collect();
283
+
284
+
match client.remove_repos(&did_refs).await {
285
+
Ok(()) => {
286
+
eprintln!("Removed {} repository(ies) from tracking", dids.len());
287
+
for did in &dids {
288
+
println!("{}", did);
289
+
}
290
+
}
291
+
Err(e) => {
292
+
eprintln!("Failed to remove repositories: {}", e);
293
+
std::process::exit(1);
294
+
}
295
+
}
296
+
}
297
+
}
298
+
}
299
+
300
+
async fn run_resolve(hostname: &str, password: Option<String>, did: &str, handle_only: bool) {
301
+
let client = TapClient::new(hostname, password);
302
+
303
+
match client.resolve(did).await {
304
+
Ok(doc) => {
305
+
if handle_only {
306
+
// Use the handles() method from atproto_identity::model::Document
307
+
match doc.handles() {
308
+
Some(handle) => println!("{}", handle),
309
+
None => {
310
+
eprintln!("No handle found in DID document");
311
+
std::process::exit(1);
312
+
}
313
+
}
314
+
} else {
315
+
// Print full DID document as JSON
316
+
match serde_json::to_string_pretty(&doc) {
317
+
Ok(json) => println!("{}", json),
318
+
Err(e) => {
319
+
eprintln!("Failed to serialize DID document: {}", e);
320
+
std::process::exit(1);
321
+
}
322
+
}
323
+
}
324
+
}
325
+
Err(e) => {
326
+
eprintln!("Failed to resolve DID: {}", e);
327
+
std::process::exit(1);
328
+
}
329
+
}
330
+
}
331
+
332
+
async fn run_info(hostname: &str, password: Option<String>, did: &str) {
333
+
let client = TapClient::new(hostname, password);
334
+
335
+
match client.info(did).await {
336
+
Ok(info) => {
337
+
// Print as JSON for easy parsing
338
+
match serde_json::to_string_pretty(&info) {
339
+
Ok(json) => println!("{}", json),
340
+
Err(e) => {
341
+
eprintln!("Failed to serialize info: {}", e);
342
+
std::process::exit(1);
343
+
}
344
+
}
345
+
}
346
+
Err(e) => {
347
+
eprintln!("Failed to get repository info: {}", e);
348
+
std::process::exit(1);
349
+
}
350
+
}
351
+
}
+214
crates/atproto-tap/src/bin/atproto-tap-extras.rs
+214
crates/atproto-tap/src/bin/atproto-tap-extras.rs
···
1
+
//! Additional TAP client utilities for AT Protocol.
2
+
//!
3
+
//! This tool provides extra commands for managing TAP tracked repositories
4
+
//! based on social graph data.
5
+
//!
6
+
//! # Usage
7
+
//!
8
+
//! ```bash
9
+
//! # Add all accounts followed by a DID to TAP tracking
10
+
//! cargo run --features cli --bin atproto-tap-extras -- localhost:2480 repos-add-followers did:plc:xyz
11
+
//!
12
+
//! # With authentication
13
+
//! cargo run --features cli --bin atproto-tap-extras -- localhost:2480 -p secret repos-add-followers did:plc:xyz
14
+
//! ```
15
+
16
+
use atproto_client::client::Auth;
17
+
use atproto_client::com::atproto::repo::{ListRecordsParams, list_records};
18
+
use atproto_identity::plc::query as plc_query;
19
+
use atproto_tap::TapClient;
20
+
use clap::{Parser, Subcommand};
21
+
use serde::Deserialize;
22
+
23
+
/// TAP extras utility for managing tracked repositories.
24
+
#[derive(Parser)]
25
+
#[command(
26
+
name = "atproto-tap-extras",
27
+
version,
28
+
about = "TAP extras utility for AT Protocol",
29
+
long_about = "Additional utilities for managing TAP tracked repositories based on social graph data."
30
+
)]
31
+
struct Args {
32
+
/// TAP service hostname (e.g., localhost:2480)
33
+
hostname: String,
34
+
35
+
/// Admin password for TAP authentication
36
+
#[arg(short, long, global = true)]
37
+
password: Option<String>,
38
+
39
+
/// PLC directory hostname for DID resolution
40
+
#[arg(long, default_value = "plc.directory", global = true)]
41
+
plc_hostname: String,
42
+
43
+
#[command(subcommand)]
44
+
command: Command,
45
+
}
46
+
47
+
#[derive(Subcommand)]
48
+
enum Command {
49
+
/// Add accounts followed by a DID to TAP tracking.
50
+
///
51
+
/// Fetches all app.bsky.graph.follow records from the specified DID's repository
52
+
/// and adds the followed DIDs to TAP for tracking.
53
+
ReposAddFollowers {
54
+
/// DID to read followers from (e.g., did:plc:xyz123)
55
+
did: String,
56
+
57
+
/// Batch size for adding repos to TAP
58
+
#[arg(long, default_value = "100")]
59
+
batch_size: usize,
60
+
61
+
/// Dry run - print DIDs without adding to TAP
62
+
#[arg(long)]
63
+
dry_run: bool,
64
+
},
65
+
}
66
+
67
+
/// Follow record structure from app.bsky.graph.follow.
68
+
#[derive(Debug, Deserialize)]
69
+
struct FollowRecord {
70
+
/// The DID of the account being followed.
71
+
subject: String,
72
+
}
73
+
74
+
#[tokio::main]
75
+
async fn main() {
76
+
let args = Args::parse();
77
+
78
+
match args.command {
79
+
Command::ReposAddFollowers {
80
+
did,
81
+
batch_size,
82
+
dry_run,
83
+
} => {
84
+
run_repos_add_followers(
85
+
&args.hostname,
86
+
args.password,
87
+
&args.plc_hostname,
88
+
&did,
89
+
batch_size,
90
+
dry_run,
91
+
)
92
+
.await;
93
+
}
94
+
}
95
+
}
96
+
97
+
async fn run_repos_add_followers(
98
+
tap_hostname: &str,
99
+
tap_password: Option<String>,
100
+
plc_hostname: &str,
101
+
did: &str,
102
+
batch_size: usize,
103
+
dry_run: bool,
104
+
) {
105
+
let http_client = reqwest::Client::new();
106
+
107
+
// Resolve the DID to get the PDS endpoint
108
+
eprintln!("Resolving DID: {}", did);
109
+
let document = match plc_query(&http_client, plc_hostname, did).await {
110
+
Ok(doc) => doc,
111
+
Err(e) => {
112
+
eprintln!("Failed to resolve DID: {}", e);
113
+
std::process::exit(1);
114
+
}
115
+
};
116
+
117
+
let pds_endpoints = document.pds_endpoints();
118
+
if pds_endpoints.is_empty() {
119
+
eprintln!("No PDS endpoint found in DID document");
120
+
std::process::exit(1);
121
+
}
122
+
let pds_url = pds_endpoints[0];
123
+
eprintln!("Using PDS: {}", pds_url);
124
+
125
+
// Collect all followed DIDs
126
+
let mut followed_dids: Vec<String> = Vec::new();
127
+
let mut cursor: Option<String> = None;
128
+
let collection = "app.bsky.graph.follow".to_string();
129
+
130
+
eprintln!("Fetching follow records...");
131
+
132
+
loop {
133
+
let params = if let Some(c) = cursor.take() {
134
+
ListRecordsParams::new().limit(100).cursor(c)
135
+
} else {
136
+
ListRecordsParams::new().limit(100)
137
+
};
138
+
139
+
let response = match list_records::<FollowRecord>(
140
+
&http_client,
141
+
&Auth::None,
142
+
pds_url,
143
+
did.to_string(),
144
+
collection.clone(),
145
+
params,
146
+
)
147
+
.await
148
+
{
149
+
Ok(resp) => resp,
150
+
Err(e) => {
151
+
eprintln!("Failed to list records: {}", e);
152
+
std::process::exit(1);
153
+
}
154
+
};
155
+
156
+
for record in &response.records {
157
+
followed_dids.push(record.value.subject.clone());
158
+
}
159
+
160
+
eprintln!(
161
+
" Fetched {} records (total: {})",
162
+
response.records.len(),
163
+
followed_dids.len()
164
+
);
165
+
166
+
match response.cursor {
167
+
Some(c) if !response.records.is_empty() => {
168
+
cursor = Some(c);
169
+
}
170
+
_ => break,
171
+
}
172
+
}
173
+
174
+
if followed_dids.is_empty() {
175
+
eprintln!("No follow records found");
176
+
return;
177
+
}
178
+
179
+
eprintln!("Found {} followed accounts", followed_dids.len());
180
+
181
+
if dry_run {
182
+
eprintln!("\nDry run - would add these DIDs to TAP:");
183
+
for did in &followed_dids {
184
+
println!("{}", did);
185
+
}
186
+
return;
187
+
}
188
+
189
+
// Add to TAP in batches
190
+
let tap_client = TapClient::new(tap_hostname, tap_password);
191
+
let mut added = 0;
192
+
193
+
for chunk in followed_dids.chunks(batch_size) {
194
+
let did_refs: Vec<&str> = chunk.iter().map(|s| s.as_str()).collect();
195
+
196
+
match tap_client.add_repos(&did_refs).await {
197
+
Ok(()) => {
198
+
added += chunk.len();
199
+
eprintln!("Added {} DIDs to TAP (total: {})", chunk.len(), added);
200
+
}
201
+
Err(e) => {
202
+
eprintln!("Failed to add repos to TAP: {}", e);
203
+
std::process::exit(1);
204
+
}
205
+
}
206
+
}
207
+
208
+
eprintln!("Successfully added {} DIDs to TAP", added);
209
+
210
+
// Print all added DIDs
211
+
for did in &followed_dids {
212
+
println!("{}", did);
213
+
}
214
+
}
+371
crates/atproto-tap/src/client.rs
+371
crates/atproto-tap/src/client.rs
···
1
+
//! HTTP client for TAP management API.
2
+
//!
3
+
//! This module provides [`TapClient`] for interacting with the TAP service's
4
+
//! HTTP management endpoints, including adding/removing tracked repositories.
5
+
6
+
use crate::errors::TapError;
7
+
use atproto_identity::model::Document;
8
+
use base64::Engine;
9
+
use base64::engine::general_purpose::STANDARD as BASE64;
10
+
use reqwest::header::{AUTHORIZATION, CONTENT_TYPE, HeaderMap, HeaderValue};
11
+
use serde::{Deserialize, Serialize};
12
+
13
+
/// HTTP client for TAP management API.
14
+
///
15
+
/// Provides methods for managing which repositories the TAP service tracks,
16
+
/// checking service health, and querying repository status.
17
+
///
18
+
/// # Example
19
+
///
20
+
/// ```ignore
21
+
/// use atproto_tap::TapClient;
22
+
///
23
+
/// let client = TapClient::new("localhost:2480", Some("admin_password".to_string()));
24
+
///
25
+
/// // Add repositories to track
26
+
/// client.add_repos(&["did:plc:xyz123", "did:plc:abc456"]).await?;
27
+
///
28
+
/// // Check health
29
+
/// if client.health().await? {
30
+
/// println!("TAP service is healthy");
31
+
/// }
32
+
/// ```
33
+
#[derive(Debug, Clone)]
34
+
pub struct TapClient {
35
+
http_client: reqwest::Client,
36
+
base_url: String,
37
+
auth_header: Option<HeaderValue>,
38
+
}
39
+
40
+
impl TapClient {
41
+
/// Create a new TAP management client.
42
+
///
43
+
/// # Arguments
44
+
///
45
+
/// * `hostname` - TAP service hostname (e.g., "localhost:2480")
46
+
/// * `admin_password` - Optional admin password for authentication
47
+
pub fn new(hostname: &str, admin_password: Option<String>) -> Self {
48
+
let auth_header = admin_password.map(|password| {
49
+
let credentials = format!("admin:{}", password);
50
+
let encoded = BASE64.encode(credentials.as_bytes());
51
+
HeaderValue::from_str(&format!("Basic {}", encoded))
52
+
.expect("Invalid auth header value")
53
+
});
54
+
55
+
Self {
56
+
http_client: reqwest::Client::new(),
57
+
base_url: format!("http://{}", hostname),
58
+
auth_header,
59
+
}
60
+
}
61
+
62
+
/// Create default headers for requests.
63
+
fn default_headers(&self) -> HeaderMap {
64
+
let mut headers = HeaderMap::new();
65
+
headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
66
+
if let Some(auth) = &self.auth_header {
67
+
headers.insert(AUTHORIZATION, auth.clone());
68
+
}
69
+
headers
70
+
}
71
+
72
+
/// Add repositories to track.
73
+
///
74
+
/// Sends a POST request to `/repos/add` with the list of DIDs.
75
+
///
76
+
/// # Arguments
77
+
///
78
+
/// * `dids` - Slice of DID strings to track
79
+
///
80
+
/// # Example
81
+
///
82
+
/// ```ignore
83
+
/// client.add_repos(&[
84
+
/// "did:plc:z72i7hdynmk6r22z27h6tvur",
85
+
/// "did:plc:ewvi7nxzyoun6zhxrhs64oiz",
86
+
/// ]).await?;
87
+
/// ```
88
+
pub async fn add_repos(&self, dids: &[&str]) -> Result<(), TapError> {
89
+
let url = format!("{}/repos/add", self.base_url);
90
+
let body = AddReposRequest {
91
+
dids: dids.iter().map(|s| s.to_string()).collect(),
92
+
};
93
+
94
+
let response = self
95
+
.http_client
96
+
.post(&url)
97
+
.headers(self.default_headers())
98
+
.json(&body)
99
+
.send()
100
+
.await?;
101
+
102
+
if response.status().is_success() {
103
+
tracing::debug!(count = dids.len(), "Added repositories to TAP");
104
+
Ok(())
105
+
} else {
106
+
let status = response.status().as_u16();
107
+
let message = response.text().await.unwrap_or_default();
108
+
Err(TapError::HttpResponseError { status, message })
109
+
}
110
+
}
111
+
112
+
/// Remove repositories from tracking.
113
+
///
114
+
/// Sends a POST request to `/repos/remove` with the list of DIDs.
115
+
///
116
+
/// # Arguments
117
+
///
118
+
/// * `dids` - Slice of DID strings to stop tracking
119
+
pub async fn remove_repos(&self, dids: &[&str]) -> Result<(), TapError> {
120
+
let url = format!("{}/repos/remove", self.base_url);
121
+
let body = AddReposRequest {
122
+
dids: dids.iter().map(|s| s.to_string()).collect(),
123
+
};
124
+
125
+
let response = self
126
+
.http_client
127
+
.post(&url)
128
+
.headers(self.default_headers())
129
+
.json(&body)
130
+
.send()
131
+
.await?;
132
+
133
+
if response.status().is_success() {
134
+
tracing::debug!(count = dids.len(), "Removed repositories from TAP");
135
+
Ok(())
136
+
} else {
137
+
let status = response.status().as_u16();
138
+
let message = response.text().await.unwrap_or_default();
139
+
Err(TapError::HttpResponseError { status, message })
140
+
}
141
+
}
142
+
143
+
/// Check service health.
144
+
///
145
+
/// Sends a GET request to `/health`.
146
+
///
147
+
/// # Returns
148
+
///
149
+
/// `true` if the service is healthy, `false` otherwise.
150
+
pub async fn health(&self) -> Result<bool, TapError> {
151
+
let url = format!("{}/health", self.base_url);
152
+
153
+
let response = self
154
+
.http_client
155
+
.get(&url)
156
+
.headers(self.default_headers())
157
+
.send()
158
+
.await?;
159
+
160
+
Ok(response.status().is_success())
161
+
}
162
+
163
+
/// Resolve a DID to its DID document.
164
+
///
165
+
/// Sends a GET request to `/resolve/:did`.
166
+
///
167
+
/// # Arguments
168
+
///
169
+
/// * `did` - The DID to resolve
170
+
///
171
+
/// # Returns
172
+
///
173
+
/// The DID document for the identity.
174
+
pub async fn resolve(&self, did: &str) -> Result<Document, TapError> {
175
+
let url = format!("{}/resolve/{}", self.base_url, did);
176
+
177
+
let response = self
178
+
.http_client
179
+
.get(&url)
180
+
.headers(self.default_headers())
181
+
.send()
182
+
.await?;
183
+
184
+
if response.status().is_success() {
185
+
let doc: Document = response.json().await?;
186
+
Ok(doc)
187
+
} else {
188
+
let status = response.status().as_u16();
189
+
let message = response.text().await.unwrap_or_default();
190
+
Err(TapError::HttpResponseError { status, message })
191
+
}
192
+
}
193
+
194
+
/// Get info about a tracked repository.
195
+
///
196
+
/// Sends a GET request to `/info/:did`.
197
+
///
198
+
/// # Arguments
199
+
///
200
+
/// * `did` - The DID to get info for
201
+
///
202
+
/// # Returns
203
+
///
204
+
/// Repository tracking information.
205
+
pub async fn info(&self, did: &str) -> Result<RepoInfo, TapError> {
206
+
let url = format!("{}/info/{}", self.base_url, did);
207
+
208
+
let response = self
209
+
.http_client
210
+
.get(&url)
211
+
.headers(self.default_headers())
212
+
.send()
213
+
.await?;
214
+
215
+
if response.status().is_success() {
216
+
let info: RepoInfo = response.json().await?;
217
+
Ok(info)
218
+
} else {
219
+
let status = response.status().as_u16();
220
+
let message = response.text().await.unwrap_or_default();
221
+
Err(TapError::HttpResponseError { status, message })
222
+
}
223
+
}
224
+
}
225
+
226
+
/// Request body for adding/removing repositories.
227
+
#[derive(Debug, Serialize)]
228
+
struct AddReposRequest {
229
+
dids: Vec<String>,
230
+
}
231
+
232
+
/// Repository tracking information.
233
+
#[derive(Debug, Clone, Serialize, Deserialize)]
234
+
pub struct RepoInfo {
235
+
/// The repository DID.
236
+
pub did: Box<str>,
237
+
/// Current sync state.
238
+
pub state: RepoState,
239
+
/// The handle for the repository.
240
+
#[serde(default)]
241
+
pub handle: Option<Box<str>>,
242
+
/// Number of records in the repository.
243
+
#[serde(default)]
244
+
pub records: u64,
245
+
/// Current repository revision.
246
+
#[serde(default)]
247
+
pub rev: Option<Box<str>>,
248
+
/// Number of retries for syncing.
249
+
#[serde(default)]
250
+
pub retries: u32,
251
+
/// Error message if any.
252
+
#[serde(default)]
253
+
pub error: Option<Box<str>>,
254
+
/// Additional fields may be present depending on TAP version.
255
+
#[serde(flatten)]
256
+
pub extra: serde_json::Value,
257
+
}
258
+
259
+
/// Repository sync state.
260
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
261
+
#[serde(rename_all = "lowercase")]
262
+
pub enum RepoState {
263
+
/// Repository is active and synced.
264
+
Active,
265
+
/// Repository is currently syncing.
266
+
Syncing,
267
+
/// Repository is fully synced.
268
+
Synced,
269
+
/// Sync failed for this repository.
270
+
Failed,
271
+
/// Repository is queued for sync.
272
+
Queued,
273
+
/// Unknown state.
274
+
#[serde(other)]
275
+
Unknown,
276
+
}
277
+
278
+
/// Deprecated alias for RepoState.
279
+
#[deprecated(since = "0.13.0", note = "Use RepoState instead")]
280
+
pub type RepoStatus = RepoState;
281
+
282
+
impl std::fmt::Display for RepoState {
283
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
284
+
match self {
285
+
RepoState::Active => write!(f, "active"),
286
+
RepoState::Syncing => write!(f, "syncing"),
287
+
RepoState::Synced => write!(f, "synced"),
288
+
RepoState::Failed => write!(f, "failed"),
289
+
RepoState::Queued => write!(f, "queued"),
290
+
RepoState::Unknown => write!(f, "unknown"),
291
+
}
292
+
}
293
+
}
294
+
295
+
#[cfg(test)]
296
+
mod tests {
297
+
use super::*;
298
+
299
+
#[test]
300
+
fn test_client_creation() {
301
+
let client = TapClient::new("localhost:2480", None);
302
+
assert_eq!(client.base_url, "http://localhost:2480");
303
+
assert!(client.auth_header.is_none());
304
+
305
+
let client = TapClient::new("localhost:2480", Some("secret".to_string()));
306
+
assert!(client.auth_header.is_some());
307
+
}
308
+
309
+
#[test]
310
+
fn test_repo_state_display() {
311
+
assert_eq!(RepoState::Active.to_string(), "active");
312
+
assert_eq!(RepoState::Syncing.to_string(), "syncing");
313
+
assert_eq!(RepoState::Synced.to_string(), "synced");
314
+
assert_eq!(RepoState::Failed.to_string(), "failed");
315
+
assert_eq!(RepoState::Queued.to_string(), "queued");
316
+
assert_eq!(RepoState::Unknown.to_string(), "unknown");
317
+
}
318
+
319
+
#[test]
320
+
fn test_repo_state_deserialize() {
321
+
let json = r#""active""#;
322
+
let state: RepoState = serde_json::from_str(json).unwrap();
323
+
assert_eq!(state, RepoState::Active);
324
+
325
+
let json = r#""syncing""#;
326
+
let state: RepoState = serde_json::from_str(json).unwrap();
327
+
assert_eq!(state, RepoState::Syncing);
328
+
329
+
let json = r#""some_new_state""#;
330
+
let state: RepoState = serde_json::from_str(json).unwrap();
331
+
assert_eq!(state, RepoState::Unknown);
332
+
}
333
+
334
+
#[test]
335
+
fn test_repo_info_deserialize() {
336
+
let json = r#"{"did":"did:plc:cbkjy5n7bk3ax2wplmtjofq2","error":"","handle":"ngerakines.me","records":21382,"retries":0,"rev":"3mam4aazabs2m","state":"active"}"#;
337
+
let info: RepoInfo = serde_json::from_str(json).unwrap();
338
+
assert_eq!(&*info.did, "did:plc:cbkjy5n7bk3ax2wplmtjofq2");
339
+
assert_eq!(info.state, RepoState::Active);
340
+
assert_eq!(info.handle.as_deref(), Some("ngerakines.me"));
341
+
assert_eq!(info.records, 21382);
342
+
assert_eq!(info.retries, 0);
343
+
assert_eq!(info.rev.as_deref(), Some("3mam4aazabs2m"));
344
+
// Empty string deserializes as Some("")
345
+
assert_eq!(info.error.as_deref(), Some(""));
346
+
}
347
+
348
+
#[test]
349
+
fn test_repo_info_deserialize_minimal() {
350
+
// Test with only required fields
351
+
let json = r#"{"did":"did:plc:test","state":"syncing"}"#;
352
+
let info: RepoInfo = serde_json::from_str(json).unwrap();
353
+
assert_eq!(&*info.did, "did:plc:test");
354
+
assert_eq!(info.state, RepoState::Syncing);
355
+
assert_eq!(info.handle, None);
356
+
assert_eq!(info.records, 0);
357
+
assert_eq!(info.retries, 0);
358
+
assert_eq!(info.rev, None);
359
+
assert_eq!(info.error, None);
360
+
}
361
+
362
+
#[test]
363
+
fn test_add_repos_request_serialize() {
364
+
let req = AddReposRequest {
365
+
dids: vec!["did:plc:xyz".to_string(), "did:plc:abc".to_string()],
366
+
};
367
+
let json = serde_json::to_string(&req).unwrap();
368
+
assert!(json.contains("dids"));
369
+
assert!(json.contains("did:plc:xyz"));
370
+
}
371
+
}
+220
crates/atproto-tap/src/config.rs
+220
crates/atproto-tap/src/config.rs
···
1
+
//! Configuration for TAP stream connections.
2
+
//!
3
+
//! This module provides the [`TapConfig`] struct for configuring TAP stream
4
+
//! connections, including hostname, authentication, and reconnection behavior.
5
+
6
+
use std::time::Duration;
7
+
8
+
/// Configuration for a TAP stream connection.
9
+
///
10
+
/// Use [`TapConfig::builder()`] for ergonomic construction with defaults.
11
+
///
12
+
/// # Example
13
+
///
14
+
/// ```
15
+
/// use atproto_tap::TapConfig;
16
+
/// use std::time::Duration;
17
+
///
18
+
/// let config = TapConfig::builder()
19
+
/// .hostname("localhost:2480")
20
+
/// .admin_password("secret")
21
+
/// .send_acks(true)
22
+
/// .max_reconnect_attempts(Some(10))
23
+
/// .build();
24
+
/// ```
25
+
#[derive(Debug, Clone)]
26
+
pub struct TapConfig {
27
+
/// TAP service hostname (e.g., "localhost:2480").
28
+
///
29
+
/// The WebSocket URL is constructed as `ws://{hostname}/channel`.
30
+
pub hostname: String,
31
+
32
+
/// Optional admin password for authentication.
33
+
///
34
+
/// If set, HTTP Basic Auth is used with username "admin".
35
+
pub admin_password: Option<String>,
36
+
37
+
/// Whether to send acknowledgments for received messages.
38
+
///
39
+
/// Default: `true`. Set to `false` if the TAP service has acks disabled.
40
+
pub send_acks: bool,
41
+
42
+
/// User-Agent header value for WebSocket connections.
43
+
pub user_agent: String,
44
+
45
+
/// Maximum reconnection attempts before giving up.
46
+
///
47
+
/// `None` means unlimited reconnection attempts (default).
48
+
pub max_reconnect_attempts: Option<u32>,
49
+
50
+
/// Initial delay before first reconnection attempt.
51
+
///
52
+
/// Default: 1 second.
53
+
pub initial_reconnect_delay: Duration,
54
+
55
+
/// Maximum delay between reconnection attempts.
56
+
///
57
+
/// Default: 60 seconds.
58
+
pub max_reconnect_delay: Duration,
59
+
60
+
/// Multiplier for exponential backoff between reconnections.
61
+
///
62
+
/// Default: 2.0 (doubles the delay each attempt).
63
+
pub reconnect_backoff_multiplier: f64,
64
+
}
65
+
66
+
impl Default for TapConfig {
67
+
fn default() -> Self {
68
+
Self {
69
+
hostname: "localhost:2480".to_string(),
70
+
admin_password: None,
71
+
send_acks: true,
72
+
user_agent: format!("atproto-tap/{}", env!("CARGO_PKG_VERSION")),
73
+
max_reconnect_attempts: None,
74
+
initial_reconnect_delay: Duration::from_secs(1),
75
+
max_reconnect_delay: Duration::from_secs(60),
76
+
reconnect_backoff_multiplier: 2.0,
77
+
}
78
+
}
79
+
}
80
+
81
+
impl TapConfig {
82
+
/// Create a new configuration builder with defaults.
83
+
pub fn builder() -> TapConfigBuilder {
84
+
TapConfigBuilder::default()
85
+
}
86
+
87
+
/// Create a minimal configuration for the given hostname.
88
+
pub fn new(hostname: impl Into<String>) -> Self {
89
+
Self {
90
+
hostname: hostname.into(),
91
+
..Default::default()
92
+
}
93
+
}
94
+
95
+
/// Returns the WebSocket URL for the TAP channel.
96
+
pub fn ws_url(&self) -> String {
97
+
format!("ws://{}/channel", self.hostname)
98
+
}
99
+
100
+
/// Returns the HTTP base URL for the TAP management API.
101
+
pub fn http_base_url(&self) -> String {
102
+
format!("http://{}", self.hostname)
103
+
}
104
+
}
105
+
106
+
/// Builder for [`TapConfig`].
107
+
#[derive(Debug, Clone, Default)]
108
+
pub struct TapConfigBuilder {
109
+
config: TapConfig,
110
+
}
111
+
112
+
impl TapConfigBuilder {
113
+
/// Set the TAP service hostname.
114
+
pub fn hostname(mut self, hostname: impl Into<String>) -> Self {
115
+
self.config.hostname = hostname.into();
116
+
self
117
+
}
118
+
119
+
/// Set the admin password for authentication.
120
+
pub fn admin_password(mut self, password: impl Into<String>) -> Self {
121
+
self.config.admin_password = Some(password.into());
122
+
self
123
+
}
124
+
125
+
/// Set whether to send acknowledgments.
126
+
pub fn send_acks(mut self, send_acks: bool) -> Self {
127
+
self.config.send_acks = send_acks;
128
+
self
129
+
}
130
+
131
+
/// Set the User-Agent header value.
132
+
pub fn user_agent(mut self, user_agent: impl Into<String>) -> Self {
133
+
self.config.user_agent = user_agent.into();
134
+
self
135
+
}
136
+
137
+
/// Set the maximum reconnection attempts.
138
+
///
139
+
/// `None` means unlimited attempts.
140
+
pub fn max_reconnect_attempts(mut self, max: Option<u32>) -> Self {
141
+
self.config.max_reconnect_attempts = max;
142
+
self
143
+
}
144
+
145
+
/// Set the initial reconnection delay.
146
+
pub fn initial_reconnect_delay(mut self, delay: Duration) -> Self {
147
+
self.config.initial_reconnect_delay = delay;
148
+
self
149
+
}
150
+
151
+
/// Set the maximum reconnection delay.
152
+
pub fn max_reconnect_delay(mut self, delay: Duration) -> Self {
153
+
self.config.max_reconnect_delay = delay;
154
+
self
155
+
}
156
+
157
+
/// Set the reconnection backoff multiplier.
158
+
pub fn reconnect_backoff_multiplier(mut self, multiplier: f64) -> Self {
159
+
self.config.reconnect_backoff_multiplier = multiplier;
160
+
self
161
+
}
162
+
163
+
/// Build the configuration.
164
+
pub fn build(self) -> TapConfig {
165
+
self.config
166
+
}
167
+
}
168
+
169
+
#[cfg(test)]
170
+
mod tests {
171
+
use super::*;
172
+
173
+
#[test]
174
+
fn test_default_config() {
175
+
let config = TapConfig::default();
176
+
assert_eq!(config.hostname, "localhost:2480");
177
+
assert!(config.admin_password.is_none());
178
+
assert!(config.send_acks);
179
+
assert!(config.max_reconnect_attempts.is_none());
180
+
assert_eq!(config.initial_reconnect_delay, Duration::from_secs(1));
181
+
assert_eq!(config.max_reconnect_delay, Duration::from_secs(60));
182
+
assert!((config.reconnect_backoff_multiplier - 2.0).abs() < f64::EPSILON);
183
+
}
184
+
185
+
#[test]
186
+
fn test_builder() {
187
+
let config = TapConfig::builder()
188
+
.hostname("tap.example.com:2480")
189
+
.admin_password("secret123")
190
+
.send_acks(false)
191
+
.max_reconnect_attempts(Some(5))
192
+
.initial_reconnect_delay(Duration::from_millis(500))
193
+
.max_reconnect_delay(Duration::from_secs(30))
194
+
.reconnect_backoff_multiplier(1.5)
195
+
.build();
196
+
197
+
assert_eq!(config.hostname, "tap.example.com:2480");
198
+
assert_eq!(config.admin_password, Some("secret123".to_string()));
199
+
assert!(!config.send_acks);
200
+
assert_eq!(config.max_reconnect_attempts, Some(5));
201
+
assert_eq!(config.initial_reconnect_delay, Duration::from_millis(500));
202
+
assert_eq!(config.max_reconnect_delay, Duration::from_secs(30));
203
+
assert!((config.reconnect_backoff_multiplier - 1.5).abs() < f64::EPSILON);
204
+
}
205
+
206
+
#[test]
207
+
fn test_ws_url() {
208
+
let config = TapConfig::new("localhost:2480");
209
+
assert_eq!(config.ws_url(), "ws://localhost:2480/channel");
210
+
211
+
let config = TapConfig::new("tap.example.com:8080");
212
+
assert_eq!(config.ws_url(), "ws://tap.example.com:8080/channel");
213
+
}
214
+
215
+
#[test]
216
+
fn test_http_base_url() {
217
+
let config = TapConfig::new("localhost:2480");
218
+
assert_eq!(config.http_base_url(), "http://localhost:2480");
219
+
}
220
+
}
+168
crates/atproto-tap/src/connection.rs
+168
crates/atproto-tap/src/connection.rs
···
1
+
//! WebSocket connection management for TAP streams.
2
+
//!
3
+
//! This module handles the low-level WebSocket connection to a TAP service,
4
+
//! including authentication and message sending/receiving.
5
+
6
+
use crate::config::TapConfig;
7
+
use crate::errors::TapError;
8
+
use base64::Engine;
9
+
use base64::engine::general_purpose::STANDARD as BASE64;
10
+
use futures::{SinkExt, StreamExt};
11
+
use http::Uri;
12
+
use std::str::FromStr;
13
+
use tokio_websockets::{ClientBuilder, Message, WebSocketStream};
14
+
use tokio_websockets::MaybeTlsStream;
15
+
use tokio::net::TcpStream;
16
+
17
+
/// WebSocket connection to a TAP service.
18
+
pub(crate) struct TapConnection {
19
+
/// The underlying WebSocket stream.
20
+
ws: WebSocketStream<MaybeTlsStream<TcpStream>>,
21
+
/// Pre-allocated buffer for acknowledgment messages.
22
+
ack_buffer: Vec<u8>,
23
+
}
24
+
25
+
impl TapConnection {
26
+
/// Establish a new WebSocket connection to the TAP service.
27
+
pub async fn connect(config: &TapConfig) -> Result<Self, TapError> {
28
+
let uri = Uri::from_str(&config.ws_url())
29
+
.map_err(|e| TapError::InvalidUrl(e.to_string()))?;
30
+
31
+
let mut builder = ClientBuilder::from_uri(uri);
32
+
33
+
// Add User-Agent header
34
+
builder = builder
35
+
.add_header(
36
+
http::header::USER_AGENT,
37
+
http::HeaderValue::from_str(&config.user_agent)
38
+
.map_err(|e| TapError::ConnectionFailed(format!("Invalid user agent: {}", e)))?,
39
+
)
40
+
.map_err(|e| TapError::ConnectionFailed(format!("Failed to add header: {}", e)))?;
41
+
42
+
// Add Basic Auth header if password is configured
43
+
if let Some(password) = &config.admin_password {
44
+
let credentials = format!("admin:{}", password);
45
+
let encoded = BASE64.encode(credentials.as_bytes());
46
+
let auth_value = format!("Basic {}", encoded);
47
+
48
+
builder = builder
49
+
.add_header(
50
+
http::header::AUTHORIZATION,
51
+
http::HeaderValue::from_str(&auth_value)
52
+
.map_err(|e| TapError::ConnectionFailed(format!("Invalid auth header: {}", e)))?,
53
+
)
54
+
.map_err(|e| TapError::ConnectionFailed(format!("Failed to add auth header: {}", e)))?;
55
+
}
56
+
57
+
// Connect
58
+
let (ws, _response) = builder
59
+
.connect()
60
+
.await
61
+
.map_err(|e| TapError::ConnectionFailed(e.to_string()))?;
62
+
63
+
tracing::debug!(hostname = %config.hostname, "Connected to TAP service");
64
+
65
+
Ok(Self {
66
+
ws,
67
+
ack_buffer: Vec::with_capacity(48), // {"type":"ack","id":18446744073709551615} is 40 bytes max
68
+
})
69
+
}
70
+
71
+
/// Receive the next message from the WebSocket.
72
+
///
73
+
/// Returns `None` if the connection was closed cleanly.
74
+
pub async fn recv(&mut self) -> Result<Option<String>, TapError> {
75
+
match self.ws.next().await {
76
+
Some(Ok(msg)) => {
77
+
if msg.is_text() {
78
+
msg.as_text()
79
+
.map(|s| Some(s.to_string()))
80
+
.ok_or_else(|| TapError::ParseError("Failed to get text from message".into()))
81
+
} else if msg.is_close() {
82
+
tracing::debug!("Received close frame from TAP service");
83
+
Ok(None)
84
+
} else {
85
+
// Ignore ping/pong and binary messages
86
+
tracing::trace!("Received non-text message, ignoring");
87
+
// Recurse to get the next text message
88
+
Box::pin(self.recv()).await
89
+
}
90
+
}
91
+
Some(Err(e)) => Err(TapError::ConnectionFailed(e.to_string())),
92
+
None => {
93
+
tracing::debug!("WebSocket stream ended");
94
+
Ok(None)
95
+
}
96
+
}
97
+
}
98
+
99
+
/// Send an acknowledgment for the given event ID.
100
+
///
101
+
/// Uses a pre-allocated buffer and itoa for allocation-free formatting.
102
+
/// Format: `{"type":"ack","id":12345}`
103
+
pub async fn send_ack(&mut self, id: u64) -> Result<(), TapError> {
104
+
self.ack_buffer.clear();
105
+
self.ack_buffer.extend_from_slice(b"{\"type\":\"ack\",\"id\":");
106
+
let mut itoa_buf = itoa::Buffer::new();
107
+
self.ack_buffer.extend_from_slice(itoa_buf.format(id).as_bytes());
108
+
self.ack_buffer.push(b'}');
109
+
110
+
// All bytes are ASCII so this is always valid UTF-8
111
+
let msg = std::str::from_utf8(&self.ack_buffer)
112
+
.expect("ack buffer contains only ASCII");
113
+
114
+
self.ws
115
+
.send(Message::text(msg.to_string()))
116
+
.await
117
+
.map_err(|e| TapError::AckFailed(e.to_string()))?;
118
+
119
+
// Flush to ensure the ack is sent immediately
120
+
self.ws
121
+
.flush()
122
+
.await
123
+
.map_err(|e| TapError::AckFailed(format!("Failed to flush ack: {}", e)))?;
124
+
125
+
tracing::trace!(id, "Sent ack");
126
+
Ok(())
127
+
}
128
+
129
+
/// Close the WebSocket connection gracefully.
130
+
pub async fn close(&mut self) -> Result<(), TapError> {
131
+
self.ws
132
+
.close()
133
+
.await
134
+
.map_err(|e| TapError::ConnectionFailed(format!("Failed to close: {}", e)))?;
135
+
Ok(())
136
+
}
137
+
}
138
+
139
+
#[cfg(test)]
140
+
mod tests {
141
+
#[test]
142
+
fn test_ack_buffer_format() {
143
+
// Test that our manual JSON formatting is correct
144
+
// Format: {"type":"ack","id":12345}
145
+
let mut buffer = Vec::with_capacity(64);
146
+
147
+
let id: u64 = 12345;
148
+
buffer.clear();
149
+
buffer.extend_from_slice(b"{\"type\":\"ack\",\"id\":");
150
+
let mut itoa_buf = itoa::Buffer::new();
151
+
buffer.extend_from_slice(itoa_buf.format(id).as_bytes());
152
+
buffer.push(b'}');
153
+
154
+
let result = std::str::from_utf8(&buffer).unwrap();
155
+
assert_eq!(result, r#"{"type":"ack","id":12345}"#);
156
+
157
+
// Test max u64
158
+
let id: u64 = u64::MAX;
159
+
buffer.clear();
160
+
buffer.extend_from_slice(b"{\"type\":\"ack\",\"id\":");
161
+
buffer.extend_from_slice(itoa_buf.format(id).as_bytes());
162
+
buffer.push(b'}');
163
+
164
+
let result = std::str::from_utf8(&buffer).unwrap();
165
+
assert_eq!(result, r#"{"type":"ack","id":18446744073709551615}"#);
166
+
assert!(buffer.len() <= 64); // Fits in our pre-allocated buffer
167
+
}
168
+
}
+143
crates/atproto-tap/src/errors.rs
+143
crates/atproto-tap/src/errors.rs
···
1
+
//! Error types for TAP operations.
2
+
//!
3
+
//! This module defines the error types returned by TAP stream and client operations.
4
+
5
+
use thiserror::Error;
6
+
7
+
/// Errors that can occur during TAP operations.
8
+
#[derive(Debug, Error)]
9
+
pub enum TapError {
10
+
/// WebSocket connection failed.
11
+
#[error("error-atproto-tap-connection-1 WebSocket connection failed: {0}")]
12
+
ConnectionFailed(String),
13
+
14
+
/// Connection was closed unexpectedly.
15
+
#[error("error-atproto-tap-connection-2 Connection closed unexpectedly")]
16
+
ConnectionClosed,
17
+
18
+
/// Maximum reconnection attempts exceeded.
19
+
#[error("error-atproto-tap-connection-3 Maximum reconnection attempts exceeded after {0} attempts")]
20
+
MaxReconnectAttemptsExceeded(u32),
21
+
22
+
/// Authentication failed.
23
+
#[error("error-atproto-tap-auth-1 Authentication failed: {0}")]
24
+
AuthenticationFailed(String),
25
+
26
+
/// Failed to parse a message from the server.
27
+
#[error("error-atproto-tap-parse-1 Failed to parse message: {0}")]
28
+
ParseError(String),
29
+
30
+
/// Failed to send an acknowledgment.
31
+
#[error("error-atproto-tap-ack-1 Failed to send acknowledgment: {0}")]
32
+
AckFailed(String),
33
+
34
+
/// HTTP request failed.
35
+
#[error("error-atproto-tap-http-1 HTTP request failed: {0}")]
36
+
HttpError(String),
37
+
38
+
/// HTTP response indicated an error.
39
+
#[error("error-atproto-tap-http-2 HTTP error response: {status} - {message}")]
40
+
HttpResponseError {
41
+
/// HTTP status code.
42
+
status: u16,
43
+
/// Error message from response.
44
+
message: String,
45
+
},
46
+
47
+
/// Invalid URL.
48
+
#[error("error-atproto-tap-url-1 Invalid URL: {0}")]
49
+
InvalidUrl(String),
50
+
51
+
/// I/O error.
52
+
#[error("error-atproto-tap-io-1 I/O error: {0}")]
53
+
IoError(#[from] std::io::Error),
54
+
55
+
/// JSON serialization/deserialization error.
56
+
#[error("error-atproto-tap-json-1 JSON error: {0}")]
57
+
JsonError(#[from] serde_json::Error),
58
+
59
+
/// Stream has been closed and cannot be used.
60
+
#[error("error-atproto-tap-stream-1 Stream is closed")]
61
+
StreamClosed,
62
+
63
+
/// Operation timed out.
64
+
#[error("error-atproto-tap-timeout-1 Operation timed out")]
65
+
Timeout,
66
+
}
67
+
68
+
impl TapError {
69
+
/// Returns true if this error indicates a connection issue that may be recoverable.
70
+
pub fn is_connection_error(&self) -> bool {
71
+
matches!(
72
+
self,
73
+
TapError::ConnectionFailed(_)
74
+
| TapError::ConnectionClosed
75
+
| TapError::IoError(_)
76
+
| TapError::Timeout
77
+
)
78
+
}
79
+
80
+
/// Returns true if this error is a parse error that doesn't affect connection state.
81
+
pub fn is_parse_error(&self) -> bool {
82
+
matches!(self, TapError::ParseError(_) | TapError::JsonError(_))
83
+
}
84
+
85
+
/// Returns true if this error is fatal and the stream should not attempt recovery.
86
+
pub fn is_fatal(&self) -> bool {
87
+
matches!(
88
+
self,
89
+
TapError::MaxReconnectAttemptsExceeded(_)
90
+
| TapError::AuthenticationFailed(_)
91
+
| TapError::StreamClosed
92
+
)
93
+
}
94
+
}
95
+
96
+
impl From<reqwest::Error> for TapError {
97
+
fn from(err: reqwest::Error) -> Self {
98
+
if err.is_timeout() {
99
+
TapError::Timeout
100
+
} else if err.is_connect() {
101
+
TapError::ConnectionFailed(err.to_string())
102
+
} else {
103
+
TapError::HttpError(err.to_string())
104
+
}
105
+
}
106
+
}
107
+
108
+
#[cfg(test)]
109
+
mod tests {
110
+
use super::*;
111
+
112
+
#[test]
113
+
fn test_error_classification() {
114
+
assert!(TapError::ConnectionFailed("test".into()).is_connection_error());
115
+
assert!(TapError::ConnectionClosed.is_connection_error());
116
+
assert!(TapError::Timeout.is_connection_error());
117
+
118
+
assert!(TapError::ParseError("test".into()).is_parse_error());
119
+
assert!(TapError::JsonError(serde_json::from_str::<()>("invalid").unwrap_err()).is_parse_error());
120
+
121
+
assert!(TapError::MaxReconnectAttemptsExceeded(5).is_fatal());
122
+
assert!(TapError::AuthenticationFailed("test".into()).is_fatal());
123
+
assert!(TapError::StreamClosed.is_fatal());
124
+
125
+
// Non-fatal errors
126
+
assert!(!TapError::ConnectionFailed("test".into()).is_fatal());
127
+
assert!(!TapError::ParseError("test".into()).is_fatal());
128
+
}
129
+
130
+
#[test]
131
+
fn test_error_display() {
132
+
let err = TapError::ConnectionFailed("refused".to_string());
133
+
assert!(err.to_string().contains("error-atproto-tap-connection-1"));
134
+
assert!(err.to_string().contains("refused"));
135
+
136
+
let err = TapError::HttpResponseError {
137
+
status: 404,
138
+
message: "Not Found".to_string(),
139
+
};
140
+
assert!(err.to_string().contains("404"));
141
+
assert!(err.to_string().contains("Not Found"));
142
+
}
143
+
}
+488
crates/atproto-tap/src/events.rs
+488
crates/atproto-tap/src/events.rs
···
1
+
//! TAP event types for AT Protocol record and identity events.
2
+
//!
3
+
//! This module defines the event structures received from a TAP service.
4
+
//! Events are optimized for memory efficiency using:
5
+
//! - `CompactString` for small strings (SSO for โค24 bytes)
6
+
//! - `Box<str>` for immutable strings (no capacity overhead)
7
+
//! - `serde_json::Value` for record payloads (allows lazy access)
8
+
9
+
use compact_str::CompactString;
10
+
use serde::de::{self, Deserializer, IgnoredAny, MapAccess, Visitor};
11
+
use serde::{Deserialize, Serialize, de::DeserializeOwned};
12
+
use std::fmt;
13
+
14
+
/// A TAP event received from the stream.
15
+
///
16
+
/// TAP delivers two types of events:
17
+
/// - `Record`: Repository record changes (create, update, delete)
18
+
/// - `Identity`: Identity/handle changes for accounts
19
+
#[derive(Debug, Clone, Serialize, Deserialize)]
20
+
#[serde(tag = "type", rename_all = "lowercase")]
21
+
pub enum TapEvent {
22
+
/// A repository record event (create, update, or delete).
23
+
Record {
24
+
/// Sequential event identifier.
25
+
id: u64,
26
+
/// The record event data.
27
+
record: RecordEvent,
28
+
},
29
+
/// An identity change event.
30
+
Identity {
31
+
/// Sequential event identifier.
32
+
id: u64,
33
+
/// The identity event data.
34
+
identity: IdentityEvent,
35
+
},
36
+
}
37
+
38
+
impl TapEvent {
39
+
/// Returns the event ID.
40
+
pub fn id(&self) -> u64 {
41
+
match self {
42
+
TapEvent::Record { id, .. } => *id,
43
+
TapEvent::Identity { id, .. } => *id,
44
+
}
45
+
}
46
+
}
47
+
48
+
/// Extract only the event ID from a JSON string without fully parsing it.
49
+
///
50
+
/// This is a fallback parser used when full `TapEvent` parsing fails (e.g., due to
51
+
/// deeply nested records hitting serde_json's recursion limit). It uses `IgnoredAny`
52
+
/// to efficiently skip over nested content without building data structures, allowing
53
+
/// us to extract the ID for acknowledgment even when full parsing fails.
54
+
///
55
+
/// # Example
56
+
///
57
+
/// ```
58
+
/// use atproto_tap::extract_event_id;
59
+
///
60
+
/// let json = r#"{"type":"record","id":12345,"record":{"deeply":"nested"}}"#;
61
+
/// assert_eq!(extract_event_id(json), Some(12345));
62
+
/// ```
63
+
pub fn extract_event_id(json: &str) -> Option<u64> {
64
+
let mut deserializer = serde_json::Deserializer::from_str(json);
65
+
deserializer.disable_recursion_limit();
66
+
EventIdOnly::deserialize(&mut deserializer).ok().map(|e| e.id)
67
+
}
68
+
69
+
/// Internal struct for extracting only the "id" field from a TAP event.
70
+
#[derive(Debug)]
71
+
struct EventIdOnly {
72
+
id: u64,
73
+
}
74
+
75
+
impl<'de> Deserialize<'de> for EventIdOnly {
76
+
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
77
+
where
78
+
D: Deserializer<'de>,
79
+
{
80
+
deserializer.deserialize_map(EventIdOnlyVisitor)
81
+
}
82
+
}
83
+
84
+
struct EventIdOnlyVisitor;
85
+
86
+
impl<'de> Visitor<'de> for EventIdOnlyVisitor {
87
+
type Value = EventIdOnly;
88
+
89
+
fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
90
+
formatter.write_str("a map with an 'id' field")
91
+
}
92
+
93
+
fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
94
+
where
95
+
M: MapAccess<'de>,
96
+
{
97
+
let mut id: Option<u64> = None;
98
+
99
+
while let Some(key) = map.next_key::<&str>()? {
100
+
if key == "id" {
101
+
id = Some(map.next_value()?);
102
+
// Found what we need - skip the rest efficiently using IgnoredAny
103
+
// which handles deeply nested structures without recursion issues
104
+
while map.next_entry::<IgnoredAny, IgnoredAny>()?.is_some() {}
105
+
break;
106
+
} else {
107
+
// Skip this value without fully parsing it
108
+
map.next_value::<IgnoredAny>()?;
109
+
}
110
+
}
111
+
112
+
id.map(|id| EventIdOnly { id })
113
+
.ok_or_else(|| de::Error::missing_field("id"))
114
+
}
115
+
}
116
+
117
+
/// A repository record event from TAP.
118
+
///
119
+
/// Contains information about a record change in a user's repository,
120
+
/// including the action taken and the record data (for creates/updates).
121
+
#[derive(Debug, Clone, Serialize, Deserialize)]
122
+
pub struct RecordEvent {
123
+
/// True if from live firehose, false if from backfill/resync.
124
+
///
125
+
/// During initial sync or recovery, TAP delivers historical events
126
+
/// with `live: false`. Once caught up, live events have `live: true`.
127
+
pub live: bool,
128
+
129
+
/// Repository revision identifier.
130
+
///
131
+
/// Typically 13 characters, stored inline via CompactString SSO.
132
+
pub rev: CompactString,
133
+
134
+
/// Actor DID (e.g., "did:plc:xyz123").
135
+
pub did: Box<str>,
136
+
137
+
/// Collection NSID (e.g., "app.bsky.feed.post").
138
+
pub collection: Box<str>,
139
+
140
+
/// Record key within the collection.
141
+
///
142
+
/// Typically a TID (13 characters), stored inline via CompactString SSO.
143
+
pub rkey: CompactString,
144
+
145
+
/// The action performed on the record.
146
+
pub action: RecordAction,
147
+
148
+
/// Content identifier (CID) of the record.
149
+
///
150
+
/// Present for create and update actions, absent for delete.
151
+
#[serde(skip_serializing_if = "Option::is_none")]
152
+
pub cid: Option<CompactString>,
153
+
154
+
/// Record data as JSON value.
155
+
///
156
+
/// Present for create and update actions, absent for delete.
157
+
/// Use [`parse_record`](Self::parse_record) to deserialize on demand.
158
+
#[serde(skip_serializing_if = "Option::is_none")]
159
+
pub record: Option<serde_json::Value>,
160
+
}
161
+
162
+
impl RecordEvent {
163
+
/// Parse the record payload into a typed structure.
164
+
///
165
+
/// This method deserializes the raw JSON on demand, avoiding
166
+
/// unnecessary allocations when the record data isn't needed.
167
+
///
168
+
/// # Errors
169
+
///
170
+
/// Returns an error if the record is absent (delete events) or
171
+
/// if deserialization fails.
172
+
///
173
+
/// # Example
174
+
///
175
+
/// ```ignore
176
+
/// use serde::Deserialize;
177
+
///
178
+
/// #[derive(Deserialize)]
179
+
/// struct Post {
180
+
/// text: String,
181
+
/// #[serde(rename = "createdAt")]
182
+
/// created_at: String,
183
+
/// }
184
+
///
185
+
/// let post: Post = record_event.parse_record()?;
186
+
/// println!("Post text: {}", post.text);
187
+
/// ```
188
+
pub fn parse_record<T: DeserializeOwned>(&self) -> Result<T, serde_json::Error> {
189
+
match &self.record {
190
+
Some(value) => serde_json::from_value(value.clone()),
191
+
None => Err(serde::de::Error::custom("no record data (delete event)")),
192
+
}
193
+
}
194
+
195
+
/// Returns the record as a JSON Value reference, if present.
196
+
pub fn record_value(&self) -> Option<&serde_json::Value> {
197
+
self.record.as_ref()
198
+
}
199
+
200
+
/// Returns true if this is a delete event.
201
+
pub fn is_delete(&self) -> bool {
202
+
self.action == RecordAction::Delete
203
+
}
204
+
205
+
/// Returns the AT-URI for this record.
206
+
///
207
+
/// Format: `at://{did}/{collection}/{rkey}`
208
+
pub fn at_uri(&self) -> String {
209
+
format!("at://{}/{}/{}", self.did, self.collection, self.rkey)
210
+
}
211
+
}
212
+
213
+
/// The action performed on a record.
214
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
215
+
#[serde(rename_all = "lowercase")]
216
+
pub enum RecordAction {
217
+
/// A new record was created.
218
+
Create,
219
+
/// An existing record was updated.
220
+
Update,
221
+
/// A record was deleted.
222
+
Delete,
223
+
}
224
+
225
+
impl std::fmt::Display for RecordAction {
226
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
227
+
match self {
228
+
RecordAction::Create => write!(f, "create"),
229
+
RecordAction::Update => write!(f, "update"),
230
+
RecordAction::Delete => write!(f, "delete"),
231
+
}
232
+
}
233
+
}
234
+
235
+
/// An identity change event from TAP.
236
+
///
237
+
/// Contains information about handle or account status changes.
238
+
#[derive(Debug, Clone, Serialize, Deserialize)]
239
+
pub struct IdentityEvent {
240
+
/// Actor DID.
241
+
pub did: Box<str>,
242
+
243
+
/// Current handle for the account.
244
+
pub handle: Box<str>,
245
+
246
+
/// Whether the account is currently active.
247
+
#[serde(default)]
248
+
pub is_active: bool,
249
+
250
+
/// Account status.
251
+
#[serde(default)]
252
+
pub status: IdentityStatus,
253
+
}
254
+
255
+
/// Account status in an identity event.
256
+
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
257
+
#[serde(rename_all = "lowercase")]
258
+
pub enum IdentityStatus {
259
+
/// Account is active and in good standing.
260
+
#[default]
261
+
Active,
262
+
/// Account has been deactivated by the user.
263
+
Deactivated,
264
+
/// Account has been suspended.
265
+
Suspended,
266
+
/// Account has been deleted.
267
+
Deleted,
268
+
/// Account has been taken down.
269
+
Takendown,
270
+
}
271
+
272
+
impl std::fmt::Display for IdentityStatus {
273
+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
274
+
match self {
275
+
IdentityStatus::Active => write!(f, "active"),
276
+
IdentityStatus::Deactivated => write!(f, "deactivated"),
277
+
IdentityStatus::Suspended => write!(f, "suspended"),
278
+
IdentityStatus::Deleted => write!(f, "deleted"),
279
+
IdentityStatus::Takendown => write!(f, "takendown"),
280
+
}
281
+
}
282
+
}
283
+
284
+
#[cfg(test)]
285
+
mod tests {
286
+
use super::*;
287
+
288
+
#[test]
289
+
fn test_parse_record_event() {
290
+
let json = r#"{
291
+
"id": 12345,
292
+
"type": "record",
293
+
"record": {
294
+
"live": true,
295
+
"rev": "3lyileto4q52k",
296
+
"did": "did:plc:z72i7hdynmk6r22z27h6tvur",
297
+
"collection": "app.bsky.feed.post",
298
+
"rkey": "3lyiletddxt2c",
299
+
"action": "create",
300
+
"cid": "bafyreigroo6vhxt62ufcndhaxzas6btq4jmniuz4egszbwuqgiyisqwqoy",
301
+
"record": {"$type": "app.bsky.feed.post", "text": "Hello world!", "createdAt": "2025-01-01T00:00:00Z"}
302
+
}
303
+
}"#;
304
+
305
+
let event: TapEvent = serde_json::from_str(json).expect("Failed to parse");
306
+
307
+
match event {
308
+
TapEvent::Record { id, record } => {
309
+
assert_eq!(id, 12345);
310
+
assert!(record.live);
311
+
assert_eq!(record.rev.as_str(), "3lyileto4q52k");
312
+
assert_eq!(&*record.did, "did:plc:z72i7hdynmk6r22z27h6tvur");
313
+
assert_eq!(&*record.collection, "app.bsky.feed.post");
314
+
assert_eq!(record.rkey.as_str(), "3lyiletddxt2c");
315
+
assert_eq!(record.action, RecordAction::Create);
316
+
assert!(record.cid.is_some());
317
+
assert!(record.record.is_some());
318
+
319
+
// Test lazy parsing
320
+
#[derive(Deserialize)]
321
+
struct Post {
322
+
text: String,
323
+
}
324
+
let post: Post = record.parse_record().expect("Failed to parse record");
325
+
assert_eq!(post.text, "Hello world!");
326
+
}
327
+
_ => panic!("Expected Record event"),
328
+
}
329
+
}
330
+
331
+
#[test]
332
+
fn test_parse_delete_event() {
333
+
let json = r#"{
334
+
"id": 12346,
335
+
"type": "record",
336
+
"record": {
337
+
"live": true,
338
+
"rev": "3lyileto4q52k",
339
+
"did": "did:plc:z72i7hdynmk6r22z27h6tvur",
340
+
"collection": "app.bsky.feed.post",
341
+
"rkey": "3lyiletddxt2c",
342
+
"action": "delete"
343
+
}
344
+
}"#;
345
+
346
+
let event: TapEvent = serde_json::from_str(json).expect("Failed to parse");
347
+
348
+
match event {
349
+
TapEvent::Record { id, record } => {
350
+
assert_eq!(id, 12346);
351
+
assert_eq!(record.action, RecordAction::Delete);
352
+
assert!(record.is_delete());
353
+
assert!(record.cid.is_none());
354
+
assert!(record.record.is_none());
355
+
}
356
+
_ => panic!("Expected Record event"),
357
+
}
358
+
}
359
+
360
+
#[test]
361
+
fn test_parse_identity_event() {
362
+
let json = r#"{
363
+
"id": 12347,
364
+
"type": "identity",
365
+
"identity": {
366
+
"did": "did:plc:z72i7hdynmk6r22z27h6tvur",
367
+
"handle": "user.bsky.social",
368
+
"is_active": true,
369
+
"status": "active"
370
+
}
371
+
}"#;
372
+
373
+
let event: TapEvent = serde_json::from_str(json).expect("Failed to parse");
374
+
375
+
match event {
376
+
TapEvent::Identity { id, identity } => {
377
+
assert_eq!(id, 12347);
378
+
assert_eq!(&*identity.did, "did:plc:z72i7hdynmk6r22z27h6tvur");
379
+
assert_eq!(&*identity.handle, "user.bsky.social");
380
+
assert!(identity.is_active);
381
+
assert_eq!(identity.status, IdentityStatus::Active);
382
+
}
383
+
_ => panic!("Expected Identity event"),
384
+
}
385
+
}
386
+
387
+
#[test]
388
+
fn test_record_action_display() {
389
+
assert_eq!(RecordAction::Create.to_string(), "create");
390
+
assert_eq!(RecordAction::Update.to_string(), "update");
391
+
assert_eq!(RecordAction::Delete.to_string(), "delete");
392
+
}
393
+
394
+
#[test]
395
+
fn test_identity_status_display() {
396
+
assert_eq!(IdentityStatus::Active.to_string(), "active");
397
+
assert_eq!(IdentityStatus::Deactivated.to_string(), "deactivated");
398
+
assert_eq!(IdentityStatus::Suspended.to_string(), "suspended");
399
+
assert_eq!(IdentityStatus::Deleted.to_string(), "deleted");
400
+
assert_eq!(IdentityStatus::Takendown.to_string(), "takendown");
401
+
}
402
+
403
+
#[test]
404
+
fn test_at_uri() {
405
+
let record = RecordEvent {
406
+
live: true,
407
+
rev: "3lyileto4q52k".into(),
408
+
did: "did:plc:xyz".into(),
409
+
collection: "app.bsky.feed.post".into(),
410
+
rkey: "abc123".into(),
411
+
action: RecordAction::Create,
412
+
cid: None,
413
+
record: None,
414
+
};
415
+
416
+
assert_eq!(record.at_uri(), "at://did:plc:xyz/app.bsky.feed.post/abc123");
417
+
}
418
+
419
+
#[test]
420
+
fn test_event_id() {
421
+
let record_event = TapEvent::Record {
422
+
id: 100,
423
+
record: RecordEvent {
424
+
live: true,
425
+
rev: "rev".into(),
426
+
did: "did".into(),
427
+
collection: "col".into(),
428
+
rkey: "rkey".into(),
429
+
action: RecordAction::Create,
430
+
cid: None,
431
+
record: None,
432
+
},
433
+
};
434
+
assert_eq!(record_event.id(), 100);
435
+
436
+
let identity_event = TapEvent::Identity {
437
+
id: 200,
438
+
identity: IdentityEvent {
439
+
did: "did".into(),
440
+
handle: "handle".into(),
441
+
is_active: true,
442
+
status: IdentityStatus::Active,
443
+
},
444
+
};
445
+
assert_eq!(identity_event.id(), 200);
446
+
}
447
+
448
+
#[test]
449
+
fn test_extract_event_id_simple() {
450
+
let json = r#"{"type":"record","id":12345,"record":{"deeply":"nested"}}"#;
451
+
assert_eq!(extract_event_id(json), Some(12345));
452
+
}
453
+
454
+
#[test]
455
+
fn test_extract_event_id_at_end() {
456
+
let json = r#"{"type":"record","record":{"deeply":"nested"},"id":99999}"#;
457
+
assert_eq!(extract_event_id(json), Some(99999));
458
+
}
459
+
460
+
#[test]
461
+
fn test_extract_event_id_missing() {
462
+
let json = r#"{"type":"record","record":{"deeply":"nested"}}"#;
463
+
assert_eq!(extract_event_id(json), None);
464
+
}
465
+
466
+
#[test]
467
+
fn test_extract_event_id_invalid_json() {
468
+
let json = r#"{"type":"record","id":123"#; // Truncated JSON
469
+
assert_eq!(extract_event_id(json), None);
470
+
}
471
+
472
+
#[test]
473
+
fn test_extract_event_id_deeply_nested() {
474
+
// Create a deeply nested JSON that would exceed serde_json's default recursion limit
475
+
let mut json = String::from(r#"{"id":42,"record":{"nested":"#);
476
+
for _ in 0..200 {
477
+
json.push_str("[");
478
+
}
479
+
json.push_str("1");
480
+
for _ in 0..200 {
481
+
json.push_str("]");
482
+
}
483
+
json.push_str("}}");
484
+
485
+
// extract_event_id should still work because it uses IgnoredAny with disabled recursion limit
486
+
assert_eq!(extract_event_id(&json), Some(42));
487
+
}
488
+
}
+119
crates/atproto-tap/src/lib.rs
+119
crates/atproto-tap/src/lib.rs
···
1
+
//! TAP (Trusted Attestation Protocol) service consumer for AT Protocol.
2
+
//!
3
+
//! This crate provides a client for consuming events from a TAP service,
4
+
//! which delivers filtered, verified AT Protocol repository events.
5
+
//!
6
+
//! # Overview
7
+
//!
8
+
//! TAP is a single-tenant service that subscribes to an AT Protocol Relay and
9
+
//! outputs filtered, verified events. Key features include:
10
+
//!
11
+
//! - **Verified Events**: MST integrity checks and signature verification
12
+
//! - **Automatic Backfill**: Historical events delivered with `live: false`
13
+
//! - **Repository Filtering**: Track specific DIDs or collections
14
+
//! - **Acknowledgment Protocol**: At-least-once delivery semantics
15
+
//!
16
+
//! # Quick Start
17
+
//!
18
+
//! ```ignore
19
+
//! use atproto_tap::{connect_to, TapEvent};
20
+
//! use tokio_stream::StreamExt;
21
+
//!
22
+
//! #[tokio::main]
23
+
//! async fn main() {
24
+
//! let mut stream = connect_to("localhost:2480");
25
+
//!
26
+
//! while let Some(result) = stream.next().await {
27
+
//! match result {
28
+
//! Ok(event) => match event.as_ref() {
29
+
//! TapEvent::Record { record, .. } => {
30
+
//! println!("{} {} {}", record.action, record.collection, record.did);
31
+
//! }
32
+
//! TapEvent::Identity { identity, .. } => {
33
+
//! println!("Identity: {} = {}", identity.did, identity.handle);
34
+
//! }
35
+
//! },
36
+
//! Err(e) => eprintln!("Error: {}", e),
37
+
//! }
38
+
//! }
39
+
//! }
40
+
//! ```
41
+
//!
42
+
//! # Using with `tokio::select!`
43
+
//!
44
+
//! The stream integrates naturally with Tokio's select macro:
45
+
//!
46
+
//! ```ignore
47
+
//! use atproto_tap::{connect, TapConfig};
48
+
//! use tokio_stream::StreamExt;
49
+
//! use tokio::signal;
50
+
//!
51
+
//! #[tokio::main]
52
+
//! async fn main() {
53
+
//! let config = TapConfig::builder()
54
+
//! .hostname("localhost:2480")
55
+
//! .admin_password("secret")
56
+
//! .build();
57
+
//!
58
+
//! let mut stream = connect(config);
59
+
//!
60
+
//! loop {
61
+
//! tokio::select! {
62
+
//! Some(result) = stream.next() => {
63
+
//! // Process event
64
+
//! }
65
+
//! _ = signal::ctrl_c() => {
66
+
//! break;
67
+
//! }
68
+
//! }
69
+
//! }
70
+
//! }
71
+
//! ```
72
+
//!
73
+
//! # Management API
74
+
//!
75
+
//! Use [`TapClient`] to manage tracked repositories:
76
+
//!
77
+
//! ```ignore
78
+
//! use atproto_tap::TapClient;
79
+
//!
80
+
//! let client = TapClient::new("localhost:2480", Some("password".to_string()));
81
+
//!
82
+
//! // Add repositories to track
83
+
//! client.add_repos(&["did:plc:xyz123"]).await?;
84
+
//!
85
+
//! // Check service health
86
+
//! if client.health().await? {
87
+
//! println!("TAP service is healthy");
88
+
//! }
89
+
//! ```
90
+
//!
91
+
//! # Memory Efficiency
92
+
//!
93
+
//! This crate is optimized for high-throughput event processing:
94
+
//!
95
+
//! - **Arc-wrapped events**: Events are shared via `Arc` for zero-cost sharing
96
+
//! - **CompactString**: Small strings use inline storage (no heap allocation)
97
+
//! - **Box<str>**: Immutable strings without capacity overhead
98
+
//! - **RawValue**: Record payloads are lazily parsed on demand
99
+
//! - **Pre-allocated buffers**: Ack messages avoid per-message allocations
100
+
101
+
#![forbid(unsafe_code)]
102
+
#![warn(missing_docs)]
103
+
104
+
mod client;
105
+
mod config;
106
+
mod connection;
107
+
mod errors;
108
+
mod events;
109
+
mod stream;
110
+
111
+
// Re-export public types
112
+
pub use atproto_identity::model::{Document, Service, VerificationMethod};
113
+
pub use client::{RepoInfo, RepoState, TapClient};
114
+
#[allow(deprecated)]
115
+
pub use client::RepoStatus;
116
+
pub use config::{TapConfig, TapConfigBuilder};
117
+
pub use errors::TapError;
118
+
pub use events::{IdentityEvent, IdentityStatus, RecordAction, RecordEvent, TapEvent, extract_event_id};
119
+
pub use stream::{TapStream, connect, connect_to};
+330
crates/atproto-tap/src/stream.rs
+330
crates/atproto-tap/src/stream.rs
···
1
+
//! TAP event stream implementation.
2
+
//!
3
+
//! This module provides [`TapStream`], an async stream that yields TAP events
4
+
//! with automatic connection management and reconnection handling.
5
+
//!
6
+
//! # Design
7
+
//!
8
+
//! The stream encapsulates all connection logic, allowing consumers to simply
9
+
//! iterate over events using standard stream combinators or `tokio::select!`.
10
+
//!
11
+
//! Reconnection is handled automatically with exponential backoff. Parse errors
12
+
//! are yielded as `Err` items but don't affect connection state - only connection
13
+
//! errors trigger reconnection attempts.
14
+
15
+
use crate::config::TapConfig;
16
+
use crate::connection::TapConnection;
17
+
use crate::errors::TapError;
18
+
use crate::events::{TapEvent, extract_event_id};
19
+
use futures::Stream;
20
+
use std::pin::Pin;
21
+
use std::sync::Arc;
22
+
use std::task::{Context, Poll};
23
+
use std::time::Duration;
24
+
use tokio::sync::mpsc;
25
+
26
+
/// An async stream of TAP events with automatic reconnection.
27
+
///
28
+
/// `TapStream` implements [`Stream`] and yields `Result<Arc<TapEvent>, TapError>`.
29
+
/// Events are wrapped in `Arc` for efficient zero-cost sharing across consumers.
30
+
///
31
+
/// # Connection Management
32
+
///
33
+
/// The stream automatically:
34
+
/// - Connects on first poll
35
+
/// - Reconnects with exponential backoff on connection errors
36
+
/// - Sends acknowledgments after parsing each message (if enabled)
37
+
/// - Yields parse errors without affecting connection state
38
+
///
39
+
/// # Example
40
+
///
41
+
/// ```ignore
42
+
/// use atproto_tap::{TapConfig, TapStream};
43
+
/// use tokio_stream::StreamExt;
44
+
///
45
+
/// let config = TapConfig::builder()
46
+
/// .hostname("localhost:2480")
47
+
/// .build();
48
+
///
49
+
/// let mut stream = TapStream::new(config);
50
+
///
51
+
/// while let Some(result) = stream.next().await {
52
+
/// match result {
53
+
/// Ok(event) => println!("Event: {:?}", event),
54
+
/// Err(e) => eprintln!("Error: {}", e),
55
+
/// }
56
+
/// }
57
+
/// ```
58
+
pub struct TapStream {
59
+
/// Receiver for events from the background task.
60
+
receiver: mpsc::Receiver<Result<Arc<TapEvent>, TapError>>,
61
+
/// Handle to request stream closure.
62
+
close_sender: Option<mpsc::Sender<()>>,
63
+
/// Whether the stream has been closed.
64
+
closed: bool,
65
+
}
66
+
67
+
impl TapStream {
68
+
/// Create a new TAP stream with the given configuration.
69
+
///
70
+
/// The stream will start connecting immediately in a background task.
71
+
pub fn new(config: TapConfig) -> Self {
72
+
// Channel for events - buffer a few to handle bursts
73
+
let (event_tx, event_rx) = mpsc::channel(32);
74
+
// Channel for close signal
75
+
let (close_tx, close_rx) = mpsc::channel(1);
76
+
77
+
// Spawn background task to manage connection
78
+
tokio::spawn(connection_task(config, event_tx, close_rx));
79
+
80
+
Self {
81
+
receiver: event_rx,
82
+
close_sender: Some(close_tx),
83
+
closed: false,
84
+
}
85
+
}
86
+
87
+
/// Close the stream and release resources.
88
+
///
89
+
/// After calling this, the stream will yield `None` on the next poll.
90
+
pub async fn close(&mut self) {
91
+
if let Some(sender) = self.close_sender.take() {
92
+
// Signal the background task to close
93
+
let _ = sender.send(()).await;
94
+
}
95
+
self.closed = true;
96
+
}
97
+
98
+
/// Returns true if the stream is closed.
99
+
pub fn is_closed(&self) -> bool {
100
+
self.closed
101
+
}
102
+
}
103
+
104
+
impl Stream for TapStream {
105
+
type Item = Result<Arc<TapEvent>, TapError>;
106
+
107
+
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
108
+
if self.closed {
109
+
return Poll::Ready(None);
110
+
}
111
+
112
+
self.receiver.poll_recv(cx)
113
+
}
114
+
}
115
+
116
+
impl Drop for TapStream {
117
+
fn drop(&mut self) {
118
+
// Drop the close_sender to signal the background task
119
+
self.close_sender.take();
120
+
tracing::debug!("TapStream dropped");
121
+
}
122
+
}
123
+
124
+
/// Background task that manages the WebSocket connection.
125
+
async fn connection_task(
126
+
config: TapConfig,
127
+
event_tx: mpsc::Sender<Result<Arc<TapEvent>, TapError>>,
128
+
mut close_rx: mpsc::Receiver<()>,
129
+
) {
130
+
let mut current_reconnect_delay = config.initial_reconnect_delay;
131
+
let mut attempt: u32 = 0;
132
+
133
+
loop {
134
+
// Check for close signal
135
+
if close_rx.try_recv().is_ok() {
136
+
tracing::debug!("Connection task received close signal");
137
+
break;
138
+
}
139
+
140
+
// Try to connect
141
+
tracing::debug!(attempt, hostname = %config.hostname, "Connecting to TAP service");
142
+
let conn_result = TapConnection::connect(&config).await;
143
+
144
+
match conn_result {
145
+
Ok(mut conn) => {
146
+
tracing::info!(hostname = %config.hostname, "TAP stream connected");
147
+
// Reset reconnection state on successful connect
148
+
current_reconnect_delay = config.initial_reconnect_delay;
149
+
attempt = 0;
150
+
151
+
// Event loop for this connection
152
+
loop {
153
+
tokio::select! {
154
+
biased;
155
+
156
+
_ = close_rx.recv() => {
157
+
tracing::debug!("Connection task received close signal during receive");
158
+
let _ = conn.close().await;
159
+
return;
160
+
}
161
+
162
+
recv_result = conn.recv() => {
163
+
match recv_result {
164
+
Ok(Some(msg)) => {
165
+
// Parse the message
166
+
match serde_json::from_str::<TapEvent>(&msg) {
167
+
Ok(event) => {
168
+
let event_id = event.id();
169
+
170
+
// Send ack if enabled (before sending event to channel)
171
+
if config.send_acks
172
+
&& let Err(err) = conn.send_ack(event_id).await
173
+
{
174
+
tracing::warn!(error = %err, "Failed to send ack");
175
+
// Don't break connection for ack errors
176
+
}
177
+
178
+
// Send event to channel
179
+
let event = Arc::new(event);
180
+
if event_tx.send(Ok(event)).await.is_err() {
181
+
// Receiver dropped, exit task
182
+
tracing::debug!("Event receiver dropped, closing connection");
183
+
let _ = conn.close().await;
184
+
return;
185
+
}
186
+
}
187
+
Err(err) => {
188
+
// Parse errors don't affect connection
189
+
tracing::warn!(error = %err, "Failed to parse TAP message");
190
+
191
+
// Try to extract just the ID using fallback parser
192
+
// so we can still ack the message even if full parsing fails
193
+
if config.send_acks {
194
+
if let Some(event_id) = extract_event_id(&msg) {
195
+
tracing::debug!(event_id, "Extracted event ID via fallback parser");
196
+
if let Err(ack_err) = conn.send_ack(event_id).await {
197
+
tracing::warn!(error = %ack_err, "Failed to send ack for unparseable message");
198
+
}
199
+
} else {
200
+
tracing::warn!("Could not extract event ID from unparseable message");
201
+
}
202
+
}
203
+
204
+
if event_tx.send(Err(TapError::ParseError(err.to_string()))).await.is_err() {
205
+
tracing::debug!("Event receiver dropped, closing connection");
206
+
let _ = conn.close().await;
207
+
return;
208
+
}
209
+
}
210
+
}
211
+
}
212
+
Ok(None) => {
213
+
// Connection closed by server
214
+
tracing::debug!("TAP connection closed by server");
215
+
break;
216
+
}
217
+
Err(err) => {
218
+
// Connection error
219
+
tracing::warn!(error = %err, "TAP connection error");
220
+
break;
221
+
}
222
+
}
223
+
}
224
+
}
225
+
}
226
+
}
227
+
Err(err) => {
228
+
tracing::warn!(error = %err, attempt, "Failed to connect to TAP service");
229
+
}
230
+
}
231
+
232
+
// Increment attempt counter
233
+
attempt += 1;
234
+
235
+
// Check if we've exceeded max attempts
236
+
if let Some(max) = config.max_reconnect_attempts
237
+
&& attempt >= max
238
+
{
239
+
tracing::error!(attempts = attempt, "Max reconnection attempts exceeded");
240
+
let _ = event_tx
241
+
.send(Err(TapError::MaxReconnectAttemptsExceeded(attempt)))
242
+
.await;
243
+
break;
244
+
}
245
+
246
+
// Wait before reconnecting with exponential backoff
247
+
tracing::debug!(
248
+
delay_ms = current_reconnect_delay.as_millis(),
249
+
attempt,
250
+
"Waiting before reconnection"
251
+
);
252
+
253
+
tokio::select! {
254
+
_ = close_rx.recv() => {
255
+
tracing::debug!("Connection task received close signal during backoff");
256
+
return;
257
+
}
258
+
_ = tokio::time::sleep(current_reconnect_delay) => {
259
+
// Update delay for next attempt
260
+
current_reconnect_delay = Duration::from_secs_f64(
261
+
(current_reconnect_delay.as_secs_f64() * config.reconnect_backoff_multiplier)
262
+
.min(config.max_reconnect_delay.as_secs_f64()),
263
+
);
264
+
}
265
+
}
266
+
}
267
+
268
+
tracing::debug!("Connection task exiting");
269
+
}
270
+
271
+
/// Create a new TAP stream with the given configuration.
272
+
pub fn connect(config: TapConfig) -> TapStream {
273
+
TapStream::new(config)
274
+
}
275
+
276
+
/// Create a new TAP stream connected to the given hostname.
277
+
///
278
+
/// Uses default configuration values.
279
+
pub fn connect_to(hostname: &str) -> TapStream {
280
+
TapStream::new(TapConfig::new(hostname))
281
+
}
282
+
283
+
#[cfg(test)]
284
+
mod tests {
285
+
use super::*;
286
+
287
+
#[test]
288
+
fn test_stream_initial_state() {
289
+
// Note: This test doesn't actually poll the stream, just checks initial state
290
+
// Creating a TapStream requires a tokio runtime for the spawn
291
+
}
292
+
293
+
#[tokio::test]
294
+
async fn test_stream_close() {
295
+
let mut stream = TapStream::new(TapConfig::new("localhost:9999"));
296
+
assert!(!stream.is_closed());
297
+
stream.close().await;
298
+
assert!(stream.is_closed());
299
+
}
300
+
301
+
#[test]
302
+
fn test_connect_functions() {
303
+
// These just create configs, actual connection happens in background task
304
+
// We can't test without a runtime, so just verify the types compile
305
+
let _ = TapConfig::new("localhost:2480");
306
+
}
307
+
308
+
#[test]
309
+
fn test_reconnect_delay_calculation() {
310
+
// Test the delay calculation logic
311
+
let initial = Duration::from_secs(1);
312
+
let max = Duration::from_secs(10);
313
+
let multiplier = 2.0;
314
+
315
+
let mut delay = initial;
316
+
assert_eq!(delay, Duration::from_secs(1));
317
+
318
+
delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64()));
319
+
assert_eq!(delay, Duration::from_secs(2));
320
+
321
+
delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64()));
322
+
assert_eq!(delay, Duration::from_secs(4));
323
+
324
+
delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64()));
325
+
assert_eq!(delay, Duration::from_secs(8));
326
+
327
+
delay = Duration::from_secs_f64((delay.as_secs_f64() * multiplier).min(max.as_secs_f64()));
328
+
assert_eq!(delay, Duration::from_secs(10)); // Capped at max
329
+
}
330
+
}
+13
-13
crates/atproto-xrpcs/README.md
+13
-13
crates/atproto-xrpcs/README.md
···
23
23
### Basic XRPC Service
24
24
25
25
```rust
26
-
use atproto_xrpcs::authorization::ResolvingAuthorization;
26
+
use atproto_xrpcs::authorization::Authorization;
27
27
use axum::{Json, Router, extract::Query, routing::get};
28
28
use serde::Deserialize;
29
29
use serde_json::json;
···
35
35
36
36
async fn handle_hello(
37
37
params: Query<HelloParams>,
38
-
authorization: Option<ResolvingAuthorization>,
38
+
authorization: Option<Authorization>,
39
39
) -> Json<serde_json::Value> {
40
40
let name = params.name.as_deref().unwrap_or("World");
41
-
41
+
42
42
let message = if authorization.is_some() {
43
43
format!("Hello, authenticated {}!", name)
44
44
} else {
45
45
format!("Hello, {}!", name)
46
46
};
47
-
47
+
48
48
Json(json!({ "message": message }))
49
49
}
50
50
···
56
56
### JWT Authorization
57
57
58
58
```rust
59
-
use atproto_xrpcs::authorization::ResolvingAuthorization;
59
+
use atproto_xrpcs::authorization::Authorization;
60
60
61
61
async fn handle_secure_endpoint(
62
-
authorization: ResolvingAuthorization, // Required authorization
62
+
authorization: Authorization, // Required authorization
63
63
) -> Json<serde_json::Value> {
64
-
// The ResolvingAuthorization extractor automatically:
64
+
// The Authorization extractor automatically:
65
65
// 1. Validates the JWT token
66
-
// 2. Resolves the caller's DID document
66
+
// 2. Resolves the caller's DID document
67
67
// 3. Verifies the signature against the DID document
68
68
// 4. Provides access to caller identity information
69
-
69
+
70
70
let caller_did = authorization.subject();
71
71
Json(json!({"caller": caller_did, "status": "authenticated"}))
72
72
}
···
79
79
use axum::{response::IntoResponse, http::StatusCode};
80
80
81
81
async fn protected_handler(
82
-
authorization: Result<ResolvingAuthorization, AuthorizationError>,
82
+
authorization: Result<Authorization, AuthorizationError>,
83
83
) -> impl IntoResponse {
84
84
match authorization {
85
85
Ok(auth) => (StatusCode::OK, "Access granted").into_response(),
86
-
Err(AuthorizationError::InvalidJWTToken { .. }) => {
86
+
Err(AuthorizationError::InvalidJWTFormat) => {
87
87
(StatusCode::UNAUTHORIZED, "Invalid token").into_response()
88
88
}
89
-
Err(AuthorizationError::DIDDocumentResolutionFailed { .. }) => {
89
+
Err(AuthorizationError::SubjectResolutionFailed { .. }) => {
90
90
(StatusCode::FORBIDDEN, "Identity verification failed").into_response()
91
91
}
92
92
Err(_) => {
···
98
98
99
99
## Authorization Flow
100
100
101
-
The `ResolvingAuthorization` extractor implements:
101
+
The `Authorization` extractor implements:
102
102
103
103
1. JWT extraction from HTTP Authorization headers
104
104
2. Token validation (signature and claims structure)
+5
-49
crates/atproto-xrpcs/src/errors.rs
+5
-49
crates/atproto-xrpcs/src/errors.rs
···
42
42
#[error("error-atproto-xrpcs-authorization-4 No issuer found in JWT claims")]
43
43
NoIssuerInClaims,
44
44
45
-
/// Occurs when DID document is not found for the issuer
46
-
#[error("error-atproto-xrpcs-authorization-5 DID document not found for issuer: {issuer}")]
47
-
DIDDocumentNotFound {
48
-
/// The issuer DID that was not found
49
-
issuer: String,
50
-
},
51
-
52
45
/// Occurs when no verification keys are found in DID document
53
-
#[error("error-atproto-xrpcs-authorization-6 No verification keys found in DID document")]
46
+
#[error("error-atproto-xrpcs-authorization-5 No verification keys found in DID document")]
54
47
NoVerificationKeys,
55
48
56
49
/// Occurs when JWT header cannot be base64 decoded
57
-
#[error("error-atproto-xrpcs-authorization-7 Failed to decode JWT header: {error}")]
50
+
#[error("error-atproto-xrpcs-authorization-6 Failed to decode JWT header: {error}")]
58
51
HeaderDecodeError {
59
52
/// The underlying base64 decode error
60
53
error: base64::DecodeError,
61
54
},
62
55
63
56
/// Occurs when JWT header cannot be parsed as JSON
64
-
#[error("error-atproto-xrpcs-authorization-8 Failed to parse JWT header: {error}")]
57
+
#[error("error-atproto-xrpcs-authorization-7 Failed to parse JWT header: {error}")]
65
58
HeaderParseError {
66
59
/// The underlying JSON parse error
67
60
error: serde_json::Error,
68
61
},
69
62
70
63
/// Occurs when JWT validation fails with all available keys
71
-
#[error("error-atproto-xrpcs-authorization-9 JWT validation failed with all available keys")]
64
+
#[error("error-atproto-xrpcs-authorization-8 JWT validation failed with all available keys")]
72
65
ValidationFailedAllKeys,
73
66
74
67
/// Occurs when subject resolution fails during DID document lookup
75
-
#[error("error-atproto-xrpcs-authorization-10 Subject resolution failed: {issuer} {error}")]
68
+
#[error("error-atproto-xrpcs-authorization-9 Subject resolution failed: {issuer} {error}")]
76
69
SubjectResolutionFailed {
77
70
/// The issuer that failed to resolve
78
71
issuer: String,
79
72
/// The underlying resolution error
80
-
error: anyhow::Error,
81
-
},
82
-
83
-
/// Occurs when DID document lookup fails after successful resolution
84
-
#[error(
85
-
"error-atproto-xrpcs-authorization-11 DID document not found for resolved issuer: {resolved_did}"
86
-
)]
87
-
ResolvedDIDDocumentNotFound {
88
-
/// The resolved DID that was not found in storage
89
-
resolved_did: String,
90
-
},
91
-
92
-
/// Occurs when PLC directory query fails
93
-
#[error("error-atproto-xrpcs-authorization-12 PLC directory query failed: {error}")]
94
-
PLCQueryFailed {
95
-
/// The underlying PLC query error
96
-
error: anyhow::Error,
97
-
},
98
-
99
-
/// Occurs when web DID query fails
100
-
#[error("error-atproto-xrpcs-authorization-13 Web DID query failed: {error}")]
101
-
WebDIDQueryFailed {
102
-
/// The underlying web DID query error
103
-
error: anyhow::Error,
104
-
},
105
-
106
-
/// Occurs when DID document storage operation fails
107
-
#[error("error-atproto-xrpcs-authorization-14 DID document storage failed: {error}")]
108
-
DocumentStorageFailed {
109
-
/// The underlying storage error
110
-
error: anyhow::Error,
111
-
},
112
-
113
-
/// Occurs when input parsing fails for resolved DID
114
-
#[error("error-atproto-xrpcs-authorization-15 Input parsing failed for resolved DID: {error}")]
115
-
InputParsingFailed {
116
-
/// The underlying parsing error
117
73
error: anyhow::Error,
118
74
},
119
75
}
+3
-13
crates/atproto-xrpcs-helloworld/src/main.rs
+3
-13
crates/atproto-xrpcs-helloworld/src/main.rs
···
7
7
config::{CertificateBundles, DnsNameservers, default_env, optional_env, require_env, version},
8
8
key::{KeyData, KeyResolver, identify_key, to_public},
9
9
resolve::{HickoryDnsResolver, IdentityResolver, InnerIdentityResolver},
10
-
storage_lru::LruDidDocumentStorage,
11
-
traits::DidDocumentStorage,
12
10
};
13
-
use atproto_xrpcs::authorization::ResolvingAuthorization;
11
+
use atproto_xrpcs::authorization::Authorization;
14
12
use axum::{
15
13
Json, Router,
16
14
extract::{FromRef, Query, State},
···
21
19
use http::{HeaderMap, StatusCode};
22
20
use serde::Deserialize;
23
21
use serde_json::json;
24
-
use std::{collections::HashMap, num::NonZeroUsize, ops::Deref, sync::Arc};
22
+
use std::{collections::HashMap, ops::Deref, sync::Arc};
25
23
26
24
#[derive(Clone)]
27
25
pub struct SimpleKeyResolver {
···
61
59
62
60
pub struct InnerWebContext {
63
61
pub http_client: reqwest::Client,
64
-
pub document_storage: Arc<dyn DidDocumentStorage>,
65
62
pub key_resolver: Arc<dyn KeyResolver>,
66
63
pub service_document: ServiceDocument,
67
64
pub service_did: ServiceDID,
···
97
94
}
98
95
}
99
96
100
-
impl FromRef<WebContext> for Arc<dyn DidDocumentStorage> {
101
-
fn from_ref(context: &WebContext) -> Self {
102
-
context.0.document_storage.clone()
103
-
}
104
-
}
105
-
106
97
impl FromRef<WebContext> for Arc<dyn KeyResolver> {
107
98
fn from_ref(context: &WebContext) -> Self {
108
99
context.0.key_resolver.clone()
···
216
207
217
208
let web_context = WebContext(Arc::new(InnerWebContext {
218
209
http_client: http_client.clone(),
219
-
document_storage: Arc::new(LruDidDocumentStorage::new(NonZeroUsize::new(255).unwrap())),
220
210
key_resolver: Arc::new(SimpleKeyResolver {
221
211
keys: signing_key_storage,
222
212
}),
···
284
274
async fn handle_xrpc_hello_world(
285
275
parameters: Query<HelloParameters>,
286
276
headers: HeaderMap,
287
-
authorization: Option<ResolvingAuthorization>,
277
+
authorization: Option<Authorization>,
288
278
) -> Json<serde_json::Value> {
289
279
println!("headers {headers:?}");
290
280
let subject = parameters.subject.as_deref().unwrap_or("World");