tangled
alpha
login
or
join now
quilling.dev
/
parakeet
forked from
parakeet.at/parakeet
1
fork
atom
Rust AppView - highly experimental!
1
fork
atom
overview
issues
pulls
pipelines
init: parakeet-appview
quilling.dev
3 months ago
4910dffa
ec1a2a2e
+2018
14 changed files
expand all
collapse all
unified
split
parakeet-appview
Cargo.toml
src
entity
implementations.rs
store.rs
traits.rs
handlers
actor.rs
lib.rs
macros.rs
xrpc
auth.rs
context.rs
error.rs
helpers.rs
rate_limiter.rs
parakeet-appview-macros
Cargo.toml
src
lib.rs
+16
parakeet-appview-macros/Cargo.toml
reviewed
···
1
1
+
[package]
2
2
+
name = "parakeet-appview-macros"
3
3
+
version = "0.1.0"
4
4
+
edition = "2021"
5
5
+
6
6
+
[lib]
7
7
+
proc-macro = true
8
8
+
9
9
+
[dependencies]
10
10
+
syn = { version = "2.0", features = ["full", "extra-traits"] }
11
11
+
quote = "1.0"
12
12
+
proc-macro2 = "1.0"
13
13
+
14
14
+
[lints.rust]
15
15
+
warnings = { level = "warn", priority = -1 }
16
16
+
deprecated-safe = { level = "warn", priority = -1 }
+331
parakeet-appview-macros/src/lib.rs
reviewed
···
1
1
+
//! Procedural macros for parakeet-appview.
2
2
+
//!
3
3
+
//! This crate provides the `#[xrpc_handler]` attribute macro that generates
4
4
+
//! XRPC handler boilerplate, working with Jacquard types.
5
5
+
6
6
+
use proc_macro::TokenStream;
7
7
+
use quote::quote;
8
8
+
use syn::{parse_macro_input, parse::{Parse, ParseStream}, ItemFn, Token, LitStr};
9
9
+
10
10
+
/// Procedural macro for XRPC handlers that generates boilerplate code.
11
11
+
///
12
12
+
/// This macro:
13
13
+
/// - Generates proper Axum handler signatures
14
14
+
/// - Handles authentication requirements
15
15
+
/// - Manages identifier resolution
16
16
+
/// - Provides pagination support
17
17
+
/// - Maps errors appropriately
18
18
+
///
19
19
+
/// # Attributes
20
20
+
///
21
21
+
/// - `auth`: Authentication requirement (required/optional/none)
22
22
+
/// - `resolve`: Identifier resolution mapping
23
23
+
/// - `paginate`: Enable pagination with default settings
24
24
+
/// - `rate_limit`: Rate limiting configuration
25
25
+
///
26
26
+
/// # Example
27
27
+
///
28
28
+
/// ```rust
29
29
+
/// #[xrpc_handler(
30
30
+
/// auth = "optional",
31
31
+
/// paginate = true
32
32
+
/// )]
33
33
+
/// async fn get_profile(req: GetProfileRequest) -> Result<GetProfileOutput, XrpcError> {
34
34
+
/// // Handler implementation
35
35
+
/// }
36
36
+
/// ```
37
37
+
///
38
38
+
/// This generates an Axum handler with:
39
39
+
/// - Optional authentication extraction
40
40
+
/// - Pagination parameters
41
41
+
/// - Proper error handling
42
42
+
/// - Context building
43
43
+
#[proc_macro_attribute]
44
44
+
pub fn xrpc_handler(args: TokenStream, input: TokenStream) -> TokenStream {
45
45
+
let config = parse_macro_input!(args as HandlerConfig);
46
46
+
let input_fn = parse_macro_input!(input as ItemFn);
47
47
+
48
48
+
// Generate the expanded handler
49
49
+
let expanded = generate_handler(&input_fn, &config);
50
50
+
51
51
+
TokenStream::from(expanded)
52
52
+
}
53
53
+
54
54
+
/// Configuration parsed from the macro attributes
55
55
+
struct HandlerConfig {
56
56
+
auth: AuthRequirement,
57
57
+
paginate: bool,
58
58
+
resolve: Vec<ResolveConfig>,
59
59
+
rate_limit: Option<String>,
60
60
+
}
61
61
+
62
62
+
#[derive(Clone)]
63
63
+
enum AuthRequirement {
64
64
+
Required,
65
65
+
Optional,
66
66
+
None,
67
67
+
}
68
68
+
69
69
+
struct ResolveConfig {
70
70
+
field: syn::Ident,
71
71
+
target: syn::Ident,
72
72
+
}
73
73
+
74
74
+
impl Parse for HandlerConfig {
75
75
+
fn parse(input: ParseStream) -> syn::Result<Self> {
76
76
+
let mut auth = AuthRequirement::None;
77
77
+
let mut paginate = false;
78
78
+
let mut resolve = Vec::new();
79
79
+
let mut rate_limit = None;
80
80
+
81
81
+
while !input.is_empty() {
82
82
+
let ident: syn::Ident = input.parse()?;
83
83
+
input.parse::<Token![=]>()?;
84
84
+
85
85
+
match ident.to_string().as_str() {
86
86
+
"auth" => {
87
87
+
let value: LitStr = input.parse()?;
88
88
+
auth = match value.value().as_str() {
89
89
+
"required" => AuthRequirement::Required,
90
90
+
"optional" => AuthRequirement::Optional,
91
91
+
"none" => AuthRequirement::None,
92
92
+
_ => {
93
93
+
return Err(syn::Error::new_spanned(
94
94
+
value,
95
95
+
"auth must be 'required', 'optional', or 'none'",
96
96
+
))
97
97
+
}
98
98
+
};
99
99
+
}
100
100
+
"paginate" => {
101
101
+
let value: syn::LitBool = input.parse()?;
102
102
+
paginate = value.value();
103
103
+
}
104
104
+
"resolve" => {
105
105
+
// Parse resolve mapping like: req.actor -> actor_id
106
106
+
let content;
107
107
+
syn::braced!(content in input);
108
108
+
109
109
+
let field_path: syn::ExprField = content.parse()?;
110
110
+
content.parse::<Token![->]>()?;
111
111
+
let target: syn::Ident = content.parse()?;
112
112
+
113
113
+
if let syn::Expr::Field(field) = field_path.base.as_ref() {
114
114
+
if let syn::Member::Named(field_name) = &field.member {
115
115
+
resolve.push(ResolveConfig {
116
116
+
field: field_name.clone(),
117
117
+
target,
118
118
+
});
119
119
+
}
120
120
+
}
121
121
+
}
122
122
+
"rate_limit" => {
123
123
+
let value: LitStr = input.parse()?;
124
124
+
rate_limit = Some(value.value());
125
125
+
}
126
126
+
_ => {
127
127
+
return Err(syn::Error::new_spanned(
128
128
+
&ident,
129
129
+
format!("Unknown attribute: {}", ident),
130
130
+
))
131
131
+
}
132
132
+
}
133
133
+
134
134
+
if !input.is_empty() {
135
135
+
input.parse::<Token![,]>()?;
136
136
+
}
137
137
+
}
138
138
+
139
139
+
Ok(HandlerConfig {
140
140
+
auth,
141
141
+
paginate,
142
142
+
resolve,
143
143
+
rate_limit,
144
144
+
})
145
145
+
}
146
146
+
}
147
147
+
148
148
+
fn generate_handler(input_fn: &ItemFn, config: &HandlerConfig) -> proc_macro2::TokenStream {
149
149
+
let fn_name = &input_fn.sig.ident;
150
150
+
let fn_body = &input_fn.block;
151
151
+
let vis = &input_fn.vis;
152
152
+
let _asyncness = &input_fn.sig.asyncness;
153
153
+
154
154
+
// Extract request and response types from the function signature
155
155
+
let (req_type, resp_type) = extract_types(&input_fn.sig);
156
156
+
157
157
+
// Generate auth parameter based on requirement
158
158
+
let auth_param = match config.auth {
159
159
+
AuthRequirement::Required => quote! {
160
160
+
auth: crate::xrpc::AtpAuth,
161
161
+
},
162
162
+
AuthRequirement::Optional => quote! {
163
163
+
maybe_auth: ::core::option::Option<crate::xrpc::AtpAuth>,
164
164
+
},
165
165
+
AuthRequirement::None => quote! {},
166
166
+
};
167
167
+
168
168
+
// Generate pagination parameter if enabled
169
169
+
let pagination_param = if config.paginate {
170
170
+
quote! {
171
171
+
pagination: ::axum::extract::Query<crate::xrpc::PaginationParams>,
172
172
+
}
173
173
+
} else {
174
174
+
quote! {}
175
175
+
};
176
176
+
177
177
+
// Generate context building based on auth
178
178
+
let build_context = match config.auth {
179
179
+
AuthRequirement::Required => quote! {
180
180
+
let ctx = crate::xrpc::XrpcContextBuilder::new(state.clone())
181
181
+
.auth(auth.0.clone(), None)
182
182
+
.build();
183
183
+
},
184
184
+
AuthRequirement::Optional => quote! {
185
185
+
let ctx = match maybe_auth {
186
186
+
Some(auth) => crate::xrpc::XrpcContextBuilder::new(state.clone())
187
187
+
.auth(auth.0.clone(), None)
188
188
+
.build(),
189
189
+
None => crate::xrpc::XrpcContext {
190
190
+
state: state.clone(),
191
191
+
auth_did: None,
192
192
+
viewer_id: None,
193
193
+
labelers: labelers.0.clone(),
194
194
+
},
195
195
+
};
196
196
+
},
197
197
+
AuthRequirement::None => quote! {
198
198
+
let ctx = crate::xrpc::XrpcContext {
199
199
+
state: state.clone(),
200
200
+
auth_did: None,
201
201
+
viewer_id: None,
202
202
+
labelers: labelers.0.clone(),
203
203
+
};
204
204
+
},
205
205
+
};
206
206
+
207
207
+
// Generate resolver calls for identifier resolution
208
208
+
let resolver_calls = config.resolve.iter().map(|r| {
209
209
+
let field = &r.field;
210
210
+
let target = &r.target;
211
211
+
quote! {
212
212
+
let #target = if let Some(identifier) = req.#field.as_ref() {
213
213
+
ctx.state.profile_entity
214
214
+
.resolve_identifier(identifier)
215
215
+
.await
216
216
+
.map_err(|e| crate::xrpc::XrpcError::BadRequest(e.to_string()))?
217
217
+
.ok_or_else(|| crate::xrpc::XrpcError::ActorNotFound(
218
218
+
format!("Actor not found: {}", identifier)
219
219
+
))?
220
220
+
} else {
221
221
+
return Err(crate::xrpc::XrpcError::BadRequest(
222
222
+
format!("Missing required field: {}", stringify!(#field))
223
223
+
));
224
224
+
};
225
225
+
}
226
226
+
});
227
227
+
228
228
+
// Handle pagination extraction
229
229
+
let pagination_extract = if config.paginate {
230
230
+
quote! {
231
231
+
let pagination_params = pagination.0;
232
232
+
}
233
233
+
} else {
234
234
+
quote! {}
235
235
+
};
236
236
+
237
237
+
238
238
+
// Generate rate limit check if configured
239
239
+
let rate_limit_check = if let Some(limit_config) = &config.rate_limit {
240
240
+
match config.auth {
241
241
+
AuthRequirement::Required => quote! {
242
242
+
// Rate limit by authenticated DID
243
243
+
let rate_limit_key = auth.0.clone();
244
244
+
if !state.rate_limiter.check(&rate_limit_key, Some(#limit_config)) {
245
245
+
return Err(crate::xrpc::XrpcError::RateLimitExceeded);
246
246
+
}
247
247
+
},
248
248
+
AuthRequirement::Optional => quote! {
249
249
+
// Rate limit by DID if authenticated, otherwise by a default key
250
250
+
let rate_limit_key = maybe_auth
251
251
+
.as_ref()
252
252
+
.map(|a| a.0.clone())
253
253
+
.unwrap_or_else(|| "anonymous".to_string());
254
254
+
if !state.rate_limiter.check(&rate_limit_key, Some(#limit_config)) {
255
255
+
return Err(crate::xrpc::XrpcError::RateLimitExceeded);
256
256
+
}
257
257
+
},
258
258
+
AuthRequirement::None => quote! {
259
259
+
// Rate limit by a default key for unauthenticated endpoints
260
260
+
if !state.rate_limiter.check("anonymous", Some(#limit_config)) {
261
261
+
return Err(crate::xrpc::XrpcError::RateLimitExceeded);
262
262
+
}
263
263
+
},
264
264
+
}
265
265
+
} else {
266
266
+
quote! {}
267
267
+
};
268
268
+
269
269
+
// Generate the expanded handler
270
270
+
quote! {
271
271
+
// Public wrapper function with Axum signature
272
272
+
#vis async fn #fn_name(
273
273
+
::axum::extract::State(state): ::axum::extract::State<::std::sync::Arc<crate::xrpc::GlobalState>>,
274
274
+
#pagination_param
275
275
+
labelers: crate::xrpc::AtpAcceptLabelers,
276
276
+
#auth_param
277
277
+
crate::xrpc::ExtractXrpc(req): crate::xrpc::ExtractXrpc<#req_type>,
278
278
+
) -> crate::xrpc::XrpcResult<::axum::Json<#resp_type>> {
279
279
+
#rate_limit_check
280
280
+
#build_context
281
281
+
#pagination_extract
282
282
+
#(#resolver_calls)*
283
283
+
284
284
+
// Original handler body with injected context
285
285
+
// The body is expected to return Result<Response, XrpcError>
286
286
+
let result_res: Result<#resp_type, crate::xrpc::XrpcError> = {
287
287
+
#fn_body
288
288
+
};
289
289
+
let result = result_res?;
290
290
+
291
291
+
Ok(::axum::Json(result))
292
292
+
}
293
293
+
}
294
294
+
}
295
295
+
296
296
+
fn extract_types(sig: &syn::Signature) -> (proc_macro2::TokenStream, proc_macro2::TokenStream) {
297
297
+
// Extract request type from first parameter
298
298
+
let req_type = if let Some(arg) = sig.inputs.first() {
299
299
+
if let syn::FnArg::Typed(pat_type) = arg {
300
300
+
let ty = &pat_type.ty;
301
301
+
quote! { #ty }
302
302
+
} else {
303
303
+
// Default to unit type if no request parameter
304
304
+
quote! { () }
305
305
+
}
306
306
+
} else {
307
307
+
quote! { () }
308
308
+
};
309
309
+
310
310
+
// Extract response type from return type
311
311
+
let resp_type = match &sig.output {
312
312
+
syn::ReturnType::Type(_, ty) => {
313
313
+
// Handle Result<T, E> by extracting T
314
314
+
if let syn::Type::Path(type_path) = ty.as_ref() {
315
315
+
if let Some(segment) = type_path.path.segments.last() {
316
316
+
if segment.ident == "Result" {
317
317
+
if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
318
318
+
if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
319
319
+
return (req_type, quote! { #inner_ty });
320
320
+
}
321
321
+
}
322
322
+
}
323
323
+
}
324
324
+
}
325
325
+
quote! { #ty }
326
326
+
}
327
327
+
syn::ReturnType::Default => quote! { () },
328
328
+
};
329
329
+
330
330
+
(req_type, resp_type)
331
331
+
}
+60
parakeet-appview/Cargo.toml
reviewed
···
1
1
+
[package]
2
2
+
name = "parakeet-appview"
3
3
+
version = "0.1.0"
4
4
+
edition = "2021"
5
5
+
6
6
+
[dependencies]
7
7
+
# Jacquard dependencies for AT Protocol types
8
8
+
jacquard = { workspace = true }
9
9
+
jacquard-api = { workspace = true }
10
10
+
jacquard-axum = { workspace = true }
11
11
+
jacquard-common = { workspace = true }
12
12
+
13
13
+
# Web framework
14
14
+
axum = { version = "0.8", features = ["json", "macros"] }
15
15
+
axum-extra = { version = "0.10", features = ["query", "typed-header"] }
16
16
+
17
17
+
# Database
18
18
+
diesel = { version = "2.2", features = ["chrono", "serde_json"] }
19
19
+
diesel-async = { version = "0.5", features = ["deadpool", "postgres"] }
20
20
+
deadpool = "0.12"
21
21
+
22
22
+
# Caching
23
23
+
moka = { version = "0.12", features = ["future"] }
24
24
+
dashmap = "6.0"
25
25
+
26
26
+
# Async runtime
27
27
+
tokio = { version = "1.42", features = ["full"] }
28
28
+
futures = "0.3"
29
29
+
async-trait = "0.1"
30
30
+
31
31
+
# Serialization
32
32
+
serde = { version = "1.0", features = ["derive"] }
33
33
+
serde_json = "1.0"
34
34
+
serde_urlencoded = "0.7"
35
35
+
36
36
+
# Error handling
37
37
+
eyre = "0.6"
38
38
+
color-eyre = "0.6"
39
39
+
thiserror = "2.0"
40
40
+
41
41
+
# Utilities
42
42
+
chrono = { version = "0.4", features = ["serde"] }
43
43
+
tracing = "0.1"
44
44
+
itertools = "0.14"
45
45
+
reqwest = { version = "0.12", features = ["json"] }
46
46
+
47
47
+
# Internal dependencies
48
48
+
parakeet-db = { path = "../parakeet-db" }
49
49
+
parakeet-appview-macros = { path = "../parakeet-appview-macros", optional = true }
50
50
+
51
51
+
[features]
52
52
+
default = ["macros"]
53
53
+
macros = ["dep:parakeet-appview-macros"]
54
54
+
55
55
+
[dev-dependencies]
56
56
+
tokio-test = "0.4"
57
57
+
58
58
+
[lints.rust]
59
59
+
warnings = { level = "warn", priority = -1 }
60
60
+
deprecated-safe = { level = "warn", priority = -1 }
+126
parakeet-appview/src/entity/implementations.rs
reviewed
···
1
1
+
//! Entity implementations for parakeet-db models.
2
2
+
3
3
+
use std::time::Duration;
4
4
+
5
5
+
use async_trait::async_trait;
6
6
+
use diesel_async::AsyncPgConnection;
7
7
+
use eyre::Result;
8
8
+
9
9
+
use parakeet_db::models::{Actor, Post};
10
10
+
11
11
+
use super::traits::{CachedEntity, ResolvableEntity, InvalidatableEntity};
12
12
+
13
13
+
/// Profile entity for actor caching.
14
14
+
pub struct ProfileEntity;
15
15
+
16
16
+
#[async_trait]
17
17
+
impl CachedEntity for ProfileEntity {
18
18
+
type Id = i32;
19
19
+
type Model = Actor;
20
20
+
21
21
+
async fn fetch_one(
22
22
+
&self,
23
23
+
_conn: &mut AsyncPgConnection,
24
24
+
_id: &Self::Id,
25
25
+
) -> Result<Option<Self::Model>> {
26
26
+
// TODO: Implement
27
27
+
Ok(None)
28
28
+
}
29
29
+
30
30
+
async fn fetch_batch(
31
31
+
&self,
32
32
+
_conn: &mut AsyncPgConnection,
33
33
+
_ids: &[Self::Id],
34
34
+
) -> Result<Vec<(Self::Id, Self::Model)>> {
35
35
+
// TODO: Implement
36
36
+
Ok(Vec::new())
37
37
+
}
38
38
+
39
39
+
fn cache_ttl(&self) -> Duration {
40
40
+
Duration::from_secs(3600) // 1 hour
41
41
+
}
42
42
+
43
43
+
fn cache_key(&self, id: &Self::Id) -> String {
44
44
+
format!("profile:{}", id)
45
45
+
}
46
46
+
}
47
47
+
48
48
+
#[async_trait]
49
49
+
impl ResolvableEntity for ProfileEntity {
50
50
+
async fn resolve_identifier(
51
51
+
&self,
52
52
+
_conn: &mut AsyncPgConnection,
53
53
+
_identifier: &str,
54
54
+
) -> Result<Option<Self::Id>> {
55
55
+
// TODO: Implement DID/handle resolution
56
56
+
Ok(None)
57
57
+
}
58
58
+
}
59
59
+
60
60
+
#[async_trait]
61
61
+
impl InvalidatableEntity for ProfileEntity {
62
62
+
async fn invalidate(&self, _payload: &str) -> Result<()> {
63
63
+
// Payload format: "profile:123"
64
64
+
Ok(())
65
65
+
}
66
66
+
67
67
+
fn parse_invalidation(&self, payload: &str) -> Option<Self::Id> {
68
68
+
payload
69
69
+
.strip_prefix("profile:")
70
70
+
.and_then(|id_str| id_str.parse::<i32>().ok())
71
71
+
}
72
72
+
}
73
73
+
74
74
+
/// Post entity for caching posts.
75
75
+
pub struct PostEntity;
76
76
+
77
77
+
#[async_trait]
78
78
+
impl CachedEntity for PostEntity {
79
79
+
type Id = (i32, i64); // (actor_id, rkey)
80
80
+
type Model = Post;
81
81
+
82
82
+
async fn fetch_one(
83
83
+
&self,
84
84
+
_conn: &mut AsyncPgConnection,
85
85
+
_id: &Self::Id,
86
86
+
) -> Result<Option<Self::Model>> {
87
87
+
// TODO: Implement
88
88
+
Ok(None)
89
89
+
}
90
90
+
91
91
+
async fn fetch_batch(
92
92
+
&self,
93
93
+
_conn: &mut AsyncPgConnection,
94
94
+
_ids: &[Self::Id],
95
95
+
) -> Result<Vec<(Self::Id, Self::Model)>> {
96
96
+
// TODO: Implement
97
97
+
Ok(Vec::new())
98
98
+
}
99
99
+
100
100
+
fn cache_ttl(&self) -> Duration {
101
101
+
Duration::from_secs(1800) // 30 minutes
102
102
+
}
103
103
+
104
104
+
fn cache_key(&self, id: &Self::Id) -> String {
105
105
+
format!("post:{}:{}", id.0, id.1)
106
106
+
}
107
107
+
}
108
108
+
109
109
+
#[async_trait]
110
110
+
impl InvalidatableEntity for PostEntity {
111
111
+
async fn invalidate(&self, _payload: &str) -> Result<()> {
112
112
+
// Payload format: "post:123:456789"
113
113
+
Ok(())
114
114
+
}
115
115
+
116
116
+
fn parse_invalidation(&self, payload: &str) -> Option<Self::Id> {
117
117
+
let parts: Vec<&str> = payload.strip_prefix("post:")?.split(':').collect();
118
118
+
if parts.len() == 2 {
119
119
+
let actor_id = parts[0].parse::<i32>().ok()?;
120
120
+
let rkey = parts[1].parse::<i64>().ok()?;
121
121
+
Some((actor_id, rkey))
122
122
+
} else {
123
123
+
None
124
124
+
}
125
125
+
}
126
126
+
}
+260
parakeet-appview/src/entity/store.rs
reviewed
···
1
1
+
use std::collections::{HashMap, HashSet};
2
2
+
use std::sync::Arc;
3
3
+
use std::time::Duration;
4
4
+
5
5
+
use deadpool::managed::Pool;
6
6
+
use diesel_async::AsyncPgConnection;
7
7
+
use diesel_async::pooled_connection::AsyncDieselConnectionManager;
8
8
+
use eyre::Result;
9
9
+
use moka::future::Cache;
10
10
+
use tracing::{debug, trace, warn};
11
11
+
12
12
+
use super::traits::{BatchResult, CachedEntity, EntityConfig, InvalidatableEntity, ResolvableEntity};
13
13
+
14
14
+
/// Generic entity store that provides caching and batching for any entity type.
15
15
+
///
16
16
+
/// This store handles:
17
17
+
/// - Multi-level caching (entity cache + optional identifier resolution cache)
18
18
+
/// - Batch fetching with deduplication
19
19
+
/// - Cache invalidation
20
20
+
/// - Automatic TTL management via Moka
21
21
+
pub struct EntityStore<E: CachedEntity> {
22
22
+
/// The entity implementation
23
23
+
entity: Arc<E>,
24
24
+
25
25
+
/// Main cache for entities (ID -> Model)
26
26
+
cache: Cache<E::Id, E::Model>,
27
27
+
28
28
+
/// Optional cache for identifier resolution (e.g., DID -> ID)
29
29
+
identifier_cache: Option<Cache<String, E::Id>>,
30
30
+
31
31
+
/// Database connection pool
32
32
+
db_pool: Arc<Pool<AsyncDieselConnectionManager<AsyncPgConnection>>>,
33
33
+
34
34
+
/// Configuration
35
35
+
config: EntityConfig,
36
36
+
}
37
37
+
38
38
+
impl<E: CachedEntity> EntityStore<E> {
39
39
+
/// Create a new entity store
40
40
+
pub fn new(
41
41
+
entity: E,
42
42
+
db_pool: Arc<Pool<AsyncDieselConnectionManager<AsyncPgConnection>>>,
43
43
+
config: EntityConfig,
44
44
+
) -> Self {
45
45
+
// Build the main entity cache with Moka
46
46
+
let mut cache_builder = Cache::builder()
47
47
+
.max_capacity(config.max_capacity)
48
48
+
.time_to_live(config.cache_ttl);
49
49
+
50
50
+
if let Some(idle) = config.idle_timeout {
51
51
+
cache_builder = cache_builder.time_to_idle(idle);
52
52
+
}
53
53
+
54
54
+
let cache = cache_builder.build();
55
55
+
56
56
+
Self {
57
57
+
entity: Arc::new(entity),
58
58
+
cache,
59
59
+
identifier_cache: None,
60
60
+
db_pool,
61
61
+
config,
62
62
+
}
63
63
+
}
64
64
+
65
65
+
/// Enable identifier resolution caching
66
66
+
pub fn with_identifier_cache(mut self, ttl: Duration) -> Self
67
67
+
where
68
68
+
E: ResolvableEntity,
69
69
+
{
70
70
+
self.identifier_cache = Some(
71
71
+
Cache::builder()
72
72
+
.max_capacity(self.config.max_capacity * 2) // More identifiers than entities
73
73
+
.time_to_live(ttl)
74
74
+
.build(),
75
75
+
);
76
76
+
self
77
77
+
}
78
78
+
79
79
+
/// Get a single entity by ID
80
80
+
pub async fn get(&self, id: &E::Id) -> Result<Option<E::Model>> {
81
81
+
// Check cache first
82
82
+
if let Some(cached) = self.cache.get(id).await {
83
83
+
trace!("Cache hit for entity {:?}", self.entity.cache_key(id));
84
84
+
return Ok(Some(cached));
85
85
+
}
86
86
+
87
87
+
// Cache miss - fetch from database
88
88
+
trace!("Cache miss for entity {:?}", self.entity.cache_key(id));
89
89
+
90
90
+
let mut conn = self.db_pool.get().await?;
91
91
+
let result = self.entity.fetch_one(&mut conn, id).await?;
92
92
+
93
93
+
// Cache the result if found
94
94
+
if let Some(ref model) = result {
95
95
+
self.cache.insert(id.clone(), model.clone()).await;
96
96
+
debug!("Cached entity {:?}", self.entity.cache_key(id));
97
97
+
}
98
98
+
99
99
+
Ok(result)
100
100
+
}
101
101
+
102
102
+
/// Get multiple entities by IDs with batching and caching
103
103
+
pub async fn get_many(&self, ids: &[E::Id]) -> Result<BatchResult<E::Id, E::Model>> {
104
104
+
if ids.is_empty() {
105
105
+
return Ok(BatchResult::new(vec![], vec![]));
106
106
+
}
107
107
+
108
108
+
let mut results = Vec::with_capacity(ids.len());
109
109
+
let mut missing_ids = Vec::new();
110
110
+
let mut ids_to_fetch = HashSet::new();
111
111
+
112
112
+
// First pass: check cache
113
113
+
for id in ids {
114
114
+
if let Some(cached) = self.cache.get(id).await {
115
115
+
results.push((id.clone(), cached));
116
116
+
} else {
117
117
+
ids_to_fetch.insert(id.clone());
118
118
+
}
119
119
+
}
120
120
+
121
121
+
// Fetch missing items from database
122
122
+
if !ids_to_fetch.is_empty() {
123
123
+
debug!("Fetching {} entities from database", ids_to_fetch.len());
124
124
+
125
125
+
let fetch_vec: Vec<E::Id> = ids_to_fetch.iter().cloned().collect();
126
126
+
let mut conn = self.db_pool.get().await?;
127
127
+
let fetched = self.entity.fetch_batch(&mut conn, &fetch_vec).await?;
128
128
+
129
129
+
// Cache fetched items and add to results
130
130
+
let mut fetched_map = HashMap::new();
131
131
+
for (id, model) in fetched {
132
132
+
self.cache.insert(id.clone(), model.clone()).await;
133
133
+
fetched_map.insert(id.clone(), model);
134
134
+
}
135
135
+
136
136
+
// Add fetched items to results, track missing
137
137
+
for id in ids_to_fetch {
138
138
+
if let Some(model) = fetched_map.remove(&id) {
139
139
+
results.push((id, model));
140
140
+
} else {
141
141
+
missing_ids.push(id);
142
142
+
}
143
143
+
}
144
144
+
}
145
145
+
146
146
+
// Preserve original order
147
147
+
let result_map: HashMap<E::Id, E::Model> = results.into_iter().collect();
148
148
+
let mut ordered_results = Vec::new();
149
149
+
let mut ordered_missing = Vec::new();
150
150
+
151
151
+
for id in ids {
152
152
+
if let Some(model) = result_map.get(id) {
153
153
+
ordered_results.push((id.clone(), model.clone()));
154
154
+
} else if !result_map.contains_key(id) {
155
155
+
ordered_missing.push(id.clone());
156
156
+
}
157
157
+
}
158
158
+
159
159
+
Ok(BatchResult::new(ordered_results, ordered_missing))
160
160
+
}
161
161
+
162
162
+
// Note: Jacquard view conversion is now handled directly by parakeet-db models
163
163
+
// Models have methods like to_profile_view(), to_post_view(), etc.
164
164
+
165
165
+
/// Clear all cached entries
166
166
+
pub async fn clear_cache(&self) {
167
167
+
self.cache.invalidate_all();
168
168
+
self.cache.run_pending_tasks().await;
169
169
+
170
170
+
if let Some(ref id_cache) = self.identifier_cache {
171
171
+
id_cache.invalidate_all();
172
172
+
id_cache.run_pending_tasks().await;
173
173
+
}
174
174
+
}
175
175
+
176
176
+
/// Get cache statistics
177
177
+
pub fn cache_stats(&self) -> CacheStats {
178
178
+
CacheStats {
179
179
+
entry_count: self.cache.entry_count(),
180
180
+
weighted_size: self.cache.weighted_size(),
181
181
+
}
182
182
+
}
183
183
+
}
184
184
+
185
185
+
/// Extended methods for entities that support identifier resolution
186
186
+
impl<E> EntityStore<E>
187
187
+
where
188
188
+
E: CachedEntity + ResolvableEntity,
189
189
+
{
190
190
+
/// Resolve an identifier to an ID
191
191
+
pub async fn resolve_identifier(&self, identifier: &str) -> Result<Option<E::Id>> {
192
192
+
// Check identifier cache if enabled
193
193
+
if let Some(ref id_cache) = self.identifier_cache {
194
194
+
if let Some(cached_id) = id_cache.get(identifier).await {
195
195
+
trace!("Identifier cache hit for {}", identifier);
196
196
+
return Ok(Some(cached_id));
197
197
+
}
198
198
+
}
199
199
+
200
200
+
// Resolve from database
201
201
+
let mut conn = self.db_pool.get().await?;
202
202
+
let result = self.entity.resolve_identifier(&mut conn, identifier).await?;
203
203
+
204
204
+
// Cache the result if found
205
205
+
if let (Some(ref id_cache), Some(ref id)) = (&self.identifier_cache, &result) {
206
206
+
id_cache.insert(identifier.to_string(), id.clone()).await;
207
207
+
debug!("Cached identifier mapping {} -> {:?}", identifier, self.entity.cache_key(id));
208
208
+
}
209
209
+
210
210
+
Ok(result)
211
211
+
}
212
212
+
213
213
+
/// Resolve and get an entity by identifier
214
214
+
pub async fn get_by_identifier(&self, identifier: &str) -> Result<Option<E::Model>> {
215
215
+
let id = self.resolve_identifier(identifier).await?;
216
216
+
match id {
217
217
+
Some(id) => self.get(&id).await,
218
218
+
None => Ok(None),
219
219
+
}
220
220
+
}
221
221
+
}
222
222
+
223
223
+
/// Extended methods for entities that support invalidation
224
224
+
impl<E> EntityStore<E>
225
225
+
where
226
226
+
E: CachedEntity + InvalidatableEntity,
227
227
+
{
228
228
+
/// Invalidate cache based on a notification payload
229
229
+
pub async fn invalidate(&self, payload: &str) -> Result<()> {
230
230
+
if let Some(id) = self.entity.parse_invalidation(payload) {
231
231
+
self.cache.invalidate(&id).await;
232
232
+
debug!("Invalidated cache for {:?}", self.entity.cache_key(&id));
233
233
+
} else {
234
234
+
warn!("Could not parse invalidation payload: {}", payload);
235
235
+
}
236
236
+
Ok(())
237
237
+
}
238
238
+
239
239
+
/// Invalidate a specific entity by ID
240
240
+
pub async fn invalidate_by_id(&self, id: &E::Id) -> Result<()> {
241
241
+
self.cache.invalidate(id).await;
242
242
+
debug!("Invalidated cache for {:?}", self.entity.cache_key(id));
243
243
+
Ok(())
244
244
+
}
245
245
+
}
246
246
+
247
247
+
/// Cache statistics
248
248
+
#[derive(Debug, Clone)]
249
249
+
pub struct CacheStats {
250
250
+
/// Number of entries currently in the cache
251
251
+
pub entry_count: u64,
252
252
+
253
253
+
/// Total weighted size of entries
254
254
+
pub weighted_size: u64,
255
255
+
}
256
256
+
257
257
+
#[cfg(test)]
258
258
+
mod tests {
259
259
+
// TODO: Add comprehensive tests for EntityStore
260
260
+
}
+124
parakeet-appview/src/entity/traits.rs
reviewed
···
1
1
+
use std::hash::Hash;
2
2
+
use std::time::Duration;
3
3
+
4
4
+
use async_trait::async_trait;
5
5
+
use diesel_async::AsyncPgConnection;
6
6
+
use eyre::Result;
7
7
+
8
8
+
/// Core trait for entities that can be cached.
9
9
+
///
10
10
+
/// This trait provides the foundation for all entity operations, including:
11
11
+
/// - Database fetching (single and batch)
12
12
+
/// - Cache management with TTL
13
13
+
#[async_trait]
14
14
+
pub trait CachedEntity: Send + Sync + 'static {
15
15
+
/// The type used as the unique identifier for this entity
16
16
+
type Id: Hash + Eq + Clone + Send + Sync + 'static;
17
17
+
18
18
+
/// The database model type (from parakeet-db)
19
19
+
/// Models should have methods like to_profile_view(), to_post_view(), etc.
20
20
+
type Model: Clone + Send + Sync + 'static;
21
21
+
22
22
+
/// Fetch a single entity from the database by ID
23
23
+
async fn fetch_one(
24
24
+
&self,
25
25
+
conn: &mut AsyncPgConnection,
26
26
+
id: &Self::Id,
27
27
+
) -> Result<Option<Self::Model>>;
28
28
+
29
29
+
/// Fetch multiple entities from the database in a single query
30
30
+
async fn fetch_batch(
31
31
+
&self,
32
32
+
conn: &mut AsyncPgConnection,
33
33
+
ids: &[Self::Id],
34
34
+
) -> Result<Vec<(Self::Id, Self::Model)>>;
35
35
+
36
36
+
/// Get the cache TTL for this entity type
37
37
+
fn cache_ttl(&self) -> Duration;
38
38
+
39
39
+
/// Generate a cache key for the given ID
40
40
+
fn cache_key(&self, id: &Self::Id) -> String;
41
41
+
42
42
+
/// Maximum number of items to keep in cache
43
43
+
fn cache_max_capacity(&self) -> u64 {
44
44
+
10_000 // Default, can be overridden
45
45
+
}
46
46
+
}
47
47
+
48
48
+
/// Extended trait for entities that support identifier resolution
49
49
+
/// (e.g., resolving DIDs or handles to internal IDs)
50
50
+
#[async_trait]
51
51
+
pub trait ResolvableEntity: CachedEntity {
52
52
+
/// Resolve an external identifier (DID, handle, URI) to an internal ID
53
53
+
async fn resolve_identifier(
54
54
+
&self,
55
55
+
conn: &mut AsyncPgConnection,
56
56
+
identifier: &str,
57
57
+
) -> Result<Option<Self::Id>>;
58
58
+
59
59
+
/// Get the cache TTL for identifier mappings (usually longer than entity TTL)
60
60
+
fn identifier_cache_ttl(&self) -> Duration {
61
61
+
Duration::from_secs(86400) // 24 hours default
62
62
+
}
63
63
+
}
64
64
+
65
65
+
/// Trait for entities that support invalidation via PostgreSQL NOTIFY
66
66
+
#[async_trait]
67
67
+
pub trait InvalidatableEntity: CachedEntity {
68
68
+
/// Invalidate cache entries based on a notification payload
69
69
+
async fn invalidate(&self, payload: &str) -> Result<()>;
70
70
+
71
71
+
/// Parse invalidation message to extract the ID
72
72
+
fn parse_invalidation(&self, payload: &str) -> Option<Self::Id>;
73
73
+
}
74
74
+
75
75
+
/// Configuration for entity caching behavior
76
76
+
#[derive(Debug, Clone)]
77
77
+
pub struct EntityConfig {
78
78
+
/// Time-to-live for cached entities
79
79
+
pub cache_ttl: Duration,
80
80
+
81
81
+
/// Maximum number of items in cache
82
82
+
pub max_capacity: u64,
83
83
+
84
84
+
/// Time-to-idle for cache entries
85
85
+
pub idle_timeout: Option<Duration>,
86
86
+
87
87
+
/// Whether to use write-through caching
88
88
+
pub write_through: bool,
89
89
+
}
90
90
+
91
91
+
impl Default for EntityConfig {
92
92
+
fn default() -> Self {
93
93
+
Self {
94
94
+
cache_ttl: Duration::from_secs(3600), // 1 hour
95
95
+
max_capacity: 10_000,
96
96
+
idle_timeout: Some(Duration::from_secs(1800)), // 30 minutes
97
97
+
write_through: true,
98
98
+
}
99
99
+
}
100
100
+
}
101
101
+
102
102
+
/// Batch fetch result with ordering preserved
103
103
+
pub struct BatchResult<Id, Model> {
104
104
+
/// The fetched items with their IDs
105
105
+
pub items: Vec<(Id, Model)>,
106
106
+
107
107
+
/// IDs that were requested but not found
108
108
+
pub missing: Vec<Id>,
109
109
+
}
110
110
+
111
111
+
impl<Id: Clone, Model> BatchResult<Id, Model> {
112
112
+
/// Create a new batch result
113
113
+
pub fn new(items: Vec<(Id, Model)>, missing: Vec<Id>) -> Self {
114
114
+
Self { items, missing }
115
115
+
}
116
116
+
117
117
+
/// Convert to a map for easy lookups
118
118
+
pub fn into_map(self) -> std::collections::HashMap<Id, Model>
119
119
+
where
120
120
+
Id: Hash + Eq,
121
121
+
{
122
122
+
self.items.into_iter().collect()
123
123
+
}
124
124
+
}
+15
parakeet-appview/src/handlers/actor.rs
reviewed
···
1
1
+
//! Actor/Profile XRPC handlers.
2
2
+
3
3
+
use std::sync::Arc;
4
4
+
use axum::{extract::State, Json};
5
5
+
use crate::xrpc::{AtpAcceptLabelers, AtpAuth, GlobalState, XrpcResult};
6
6
+
7
7
+
/// Example handler showing the pattern.
8
8
+
pub async fn example_get_profile(
9
9
+
State(_state): State<Arc<GlobalState>>,
10
10
+
_labelers: AtpAcceptLabelers,
11
11
+
_maybe_auth: Option<AtpAuth>,
12
12
+
) -> XrpcResult<Json<String>> {
13
13
+
// TODO: Implement actual profile fetching
14
14
+
Ok(Json("Profile endpoint placeholder".to_string()))
15
15
+
}
+88
parakeet-appview/src/lib.rs
reviewed
···
1
1
+
//! AppView implementation for AT Protocol.
2
2
+
3
3
+
pub mod entity {
4
4
+
//! Entity abstraction layer for caching and data access.
5
5
+
//!
6
6
+
//! This module provides a generic entity system that works with Jacquard AT Protocol types,
7
7
+
//! handling caching, batch fetching, and invalidation for all entity types.
8
8
+
9
9
+
mod implementations;
10
10
+
mod store;
11
11
+
mod traits;
12
12
+
13
13
+
pub use implementations::{PostEntity, ProfileEntity};
14
14
+
pub use moka::future::Cache;
15
15
+
pub use store::{CacheStats, EntityStore};
16
16
+
pub use traits::{
17
17
+
BatchResult, CachedEntity, EntityConfig, InvalidatableEntity, ResolvableEntity,
18
18
+
};
19
19
+
20
20
+
/// Prelude for entity implementations
21
21
+
pub mod prelude {
22
22
+
pub use super::{
23
23
+
BatchResult, CachedEntity, EntityConfig, EntityStore, InvalidatableEntity,
24
24
+
ResolvableEntity,
25
25
+
};
26
26
+
}
27
27
+
}
28
28
+
29
29
+
pub mod macros;
30
30
+
31
31
+
pub mod xrpc {
32
32
+
//! XRPC handler utilities.
33
33
+
mod auth;
34
34
+
mod context;
35
35
+
mod error;
36
36
+
mod helpers;
37
37
+
mod rate_limiter;
38
38
+
39
39
+
pub use auth::{AtpAcceptLabelers, AtpAuth, AuthError, ExtractXrpc};
40
40
+
pub use context::{GlobalState, XrpcContext, XrpcContextBuilder};
41
41
+
pub use error::{IntoXrpcError, XrpcError, XrpcResult};
42
42
+
pub use helpers::{
43
43
+
build_at_uri, decode_compound_cursor, encode_compound_cursor, encode_int_cursor,
44
44
+
encode_timestamp_cursor, is_did, is_handle, normalize_handle, paginate, parse_at_uri,
45
45
+
preserve_order, preserve_order_optional, PaginatedResponse, PaginationParams,
46
46
+
};
47
47
+
pub use rate_limiter::{RateLimiter, RateLimitConfig};
48
48
+
49
49
+
/// Prelude for handler implementations.
50
50
+
pub mod prelude {
51
51
+
pub use super::{
52
52
+
AtpAcceptLabelers, AtpAuth, ExtractXrpc, GlobalState, PaginatedResponse,
53
53
+
PaginationParams, XrpcContext, XrpcError, XrpcResult,
54
54
+
};
55
55
+
56
56
+
pub use jacquard_api::app_bsky;
57
57
+
pub use jacquard_axum::IntoRouter;
58
58
+
pub use jacquard_common::types::{
59
59
+
did::Did,
60
60
+
handle::Handle,
61
61
+
string::{AtUri, Datetime},
62
62
+
};
63
63
+
}
64
64
+
}
65
65
+
66
66
+
pub mod handlers {
67
67
+
//! Handler implementations.
68
68
+
pub mod actor;
69
69
+
}
70
70
+
71
71
+
#[cfg(feature = "macros")]
72
72
+
pub use parakeet_appview_macros::xrpc_handler;
73
73
+
74
74
+
pub use entity::{CachedEntity, EntityStore, InvalidatableEntity, ResolvableEntity};
75
75
+
pub use xrpc::{GlobalState, XrpcContext, XrpcError, XrpcResult};
76
76
+
77
77
+
/// Prelude module for convenient imports.
78
78
+
pub mod prelude {
79
79
+
pub use crate::batch_fetch;
80
80
+
pub use crate::define_entity;
81
81
+
pub use crate::entity::prelude::*;
82
82
+
pub use crate::xrpc::prelude::*;
83
83
+
#[cfg(feature = "macros")]
84
84
+
pub use crate::xrpc_handler;
85
85
+
86
86
+
pub use axum::{extract::State, Json};
87
87
+
pub use eyre::Result;
88
88
+
}
+209
parakeet-appview/src/macros.rs
reviewed
···
1
1
+
//! Declarative macros for entity definitions and handlers.
2
2
+
3
3
+
/// Define a cached entity with automatic implementation of common patterns.
4
4
+
///
5
5
+
/// This macro generates:
6
6
+
/// - Entity struct implementation
7
7
+
/// - CachedEntity trait implementation
8
8
+
/// - Optional ResolvableEntity implementation for identifier resolution
9
9
+
/// - Helper methods for common operations
10
10
+
///
11
11
+
/// # Example
12
12
+
///
13
13
+
/// ```rust
14
14
+
/// define_entity! {
15
15
+
/// ProfileEntity {
16
16
+
/// // Basic entity configuration
17
17
+
/// id: i32,
18
18
+
/// model: parakeet_db::models::Actor,
19
19
+
/// table: parakeet_db::schema::actors,
20
20
+
/// ttl: 3600, // 1 hour in seconds
21
21
+
///
22
22
+
/// // Optional: Define indexes for resolution
23
23
+
/// indexes: {
24
24
+
/// did: String => did_to_id,
25
25
+
/// handle: String => handle_to_id,
26
26
+
/// },
27
27
+
///
28
28
+
/// // Optional: Custom fetch query
29
29
+
/// fetch_one: |conn, id| {
30
30
+
/// use diesel::prelude::*;
31
31
+
/// actors::table
32
32
+
/// .filter(actors::id.eq(id))
33
33
+
/// .filter(actors::status.eq(ActorStatus::Active))
34
34
+
/// .first(conn)
35
35
+
/// },
36
36
+
/// }
37
37
+
/// }
38
38
+
/// ```
39
39
+
#[macro_export]
40
40
+
macro_rules! define_entity {
41
41
+
(
42
42
+
$name:ident {
43
43
+
id: $id_type:ty,
44
44
+
model: $model:ty,
45
45
+
table: $table:path,
46
46
+
ttl: $ttl:expr,
47
47
+
48
48
+
$(
49
49
+
// Optional indexes for resolution
50
50
+
indexes: {
51
51
+
$($index_name:ident : $index_type:ty => $index_field:ident),* $(,)?
52
52
+
}$(,)?
53
53
+
)?
54
54
+
55
55
+
$(
56
56
+
// Optional custom fetch implementation
57
57
+
fetch_one: |$conn_var:ident, $id_var:ident| $fetch_expr:expr
58
58
+
)?$(,)?
59
59
+
}
60
60
+
) => {
61
61
+
pub struct $name;
62
62
+
63
63
+
#[async_trait::async_trait]
64
64
+
impl $crate::entity::CachedEntity for $name {
65
65
+
type Id = $id_type;
66
66
+
type Model = $model;
67
67
+
68
68
+
async fn fetch_one(
69
69
+
&self,
70
70
+
conn: &mut diesel_async::AsyncPgConnection,
71
71
+
id: &Self::Id,
72
72
+
) -> ::eyre::Result<Option<Self::Model>> {
73
73
+
use diesel::prelude::*;
74
74
+
use diesel_async::RunQueryDsl;
75
75
+
76
76
+
$(
77
77
+
// Use custom fetch if provided
78
78
+
let $conn_var = conn;
79
79
+
let $id_var = id;
80
80
+
return Ok($fetch_expr.await.optional()?);
81
81
+
)?
82
82
+
83
83
+
// Default fetch implementation
84
84
+
#[allow(unreachable_code)]
85
85
+
{
86
86
+
Ok($table::table
87
87
+
.find(id.clone())
88
88
+
.first::<Self::Model>(conn)
89
89
+
.await
90
90
+
.optional()?)
91
91
+
}
92
92
+
}
93
93
+
94
94
+
async fn fetch_batch(
95
95
+
&self,
96
96
+
conn: &mut diesel_async::AsyncPgConnection,
97
97
+
ids: &[Self::Id],
98
98
+
) -> ::eyre::Result<Vec<(Self::Id, Self::Model)>> {
99
99
+
use diesel::prelude::*;
100
100
+
use diesel_async::RunQueryDsl;
101
101
+
102
102
+
// Default batch fetch implementation
103
103
+
let models: Vec<Self::Model> = $table::table
104
104
+
.filter($table::id.eq_any(ids))
105
105
+
.load(conn)
106
106
+
.await?;
107
107
+
108
108
+
// Map models back to their IDs (assuming model has id field)
109
109
+
Ok(models.into_iter()
110
110
+
.map(|model| {
111
111
+
let id = model.id.clone();
112
112
+
(id, model)
113
113
+
})
114
114
+
.collect())
115
115
+
}
116
116
+
117
117
+
118
118
+
fn cache_ttl(&self) -> ::std::time::Duration {
119
119
+
::std::time::Duration::from_secs($ttl)
120
120
+
}
121
121
+
122
122
+
fn cache_key(&self, id: &Self::Id) -> String {
123
123
+
format!("{}:{:?}", stringify!($name), id)
124
124
+
}
125
125
+
}
126
126
+
127
127
+
// Generate index methods if indexes are defined
128
128
+
$(
129
129
+
impl $name {
130
130
+
$(
131
131
+
pub async fn $index_name(
132
132
+
&self,
133
133
+
conn: &mut diesel_async::AsyncPgConnection,
134
134
+
value: &$index_type,
135
135
+
) -> ::eyre::Result<Option<$id_type>> {
136
136
+
use diesel::prelude::*;
137
137
+
use diesel_async::RunQueryDsl;
138
138
+
139
139
+
Ok($table::table
140
140
+
.filter($table::$index_field.eq(value))
141
141
+
.select($table::id)
142
142
+
.first::<$id_type>(conn)
143
143
+
.await
144
144
+
.optional()?)
145
145
+
}
146
146
+
)*
147
147
+
}
148
148
+
149
149
+
// Implement ResolvableEntity if indexes are defined
150
150
+
#[async_trait::async_trait]
151
151
+
impl $crate::entity::ResolvableEntity for $name {
152
152
+
async fn resolve_identifier(
153
153
+
&self,
154
154
+
conn: &mut diesel_async::AsyncPgConnection,
155
155
+
identifier: &str,
156
156
+
) -> ::eyre::Result<Option<Self::Id>> {
157
157
+
// Try each index method to resolve the identifier
158
158
+
$(
159
159
+
// Check if this index type can parse the identifier
160
160
+
if let Ok(value) = identifier.parse::<$index_type>() {
161
161
+
if let Some(id) = self.$index_name(conn, &value).await? {
162
162
+
return Ok(Some(id));
163
163
+
}
164
164
+
}
165
165
+
)*
166
166
+
167
167
+
Ok(None)
168
168
+
}
169
169
+
}
170
170
+
)?
171
171
+
};
172
172
+
}
173
173
+
174
174
+
/// Helper macro for batch operations on entities.
175
175
+
/// Macro for batch operations on entities.
176
176
+
///
177
177
+
/// Simplifies the pattern of fetching multiple entities and converting them to views.
178
178
+
#[macro_export]
179
179
+
macro_rules! batch_fetch {
180
180
+
(
181
181
+
$entity:expr,
182
182
+
ids: $ids:expr,
183
183
+
viewer: $viewer:expr
184
184
+
$(, preserve_order: $preserve:literal)?
185
185
+
) => {{
186
186
+
let batch_result = $entity.get_many($ids).await?;
187
187
+
let views = $entity.to_jacquard_views(&batch_result.items, $viewer).await?;
188
188
+
189
189
+
$(
190
190
+
if $preserve {
191
191
+
let view_map: ::std::collections::HashMap<_, _> = batch_result.items
192
192
+
.into_iter()
193
193
+
.zip(views)
194
194
+
.collect();
195
195
+
196
196
+
let ordered = $ids.iter()
197
197
+
.filter_map(|id| view_map.get(id).cloned())
198
198
+
.collect();
199
199
+
200
200
+
ordered
201
201
+
} else {
202
202
+
views
203
203
+
}
204
204
+
)?
205
205
+
206
206
+
#[allow(unreachable_code)]
207
207
+
views
208
208
+
}};
209
209
+
}
+102
parakeet-appview/src/xrpc/auth.rs
reviewed
···
1
1
+
//! Authentication types for XRPC handlers.
2
2
+
3
3
+
use axum::{
4
4
+
extract::FromRequestParts,
5
5
+
http::{header, request::Parts, StatusCode},
6
6
+
response::{IntoResponse, Response},
7
7
+
};
8
8
+
use serde::Deserialize;
9
9
+
10
10
+
/// Authentication token containing the authenticated user's DID.
11
11
+
#[derive(Debug, Clone)]
12
12
+
pub struct AtpAuth(pub String);
13
13
+
14
14
+
impl<S> FromRequestParts<S> for AtpAuth
15
15
+
where
16
16
+
S: Send + Sync,
17
17
+
{
18
18
+
type Rejection = AuthError;
19
19
+
20
20
+
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
21
21
+
let auth_header = parts
22
22
+
.headers
23
23
+
.get(header::AUTHORIZATION)
24
24
+
.and_then(|h| h.to_str().ok())
25
25
+
.ok_or(AuthError::MissingToken)?;
26
26
+
27
27
+
let token = auth_header
28
28
+
.strip_prefix("Bearer ")
29
29
+
.ok_or(AuthError::InvalidToken)?;
30
30
+
31
31
+
// TODO: Validate JWT and extract DID
32
32
+
Ok(AtpAuth(token.to_string()))
33
33
+
}
34
34
+
}
35
35
+
36
36
+
/// List of labeler DIDs accepted by the client.
37
37
+
#[derive(Debug, Clone, Default)]
38
38
+
pub struct AtpAcceptLabelers(pub Vec<String>);
39
39
+
40
40
+
impl<S> FromRequestParts<S> for AtpAcceptLabelers
41
41
+
where
42
42
+
S: Send + Sync,
43
43
+
{
44
44
+
type Rejection = std::convert::Infallible;
45
45
+
46
46
+
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
47
47
+
let labelers = parts
48
48
+
.headers
49
49
+
.get("atproto-accept-labelers")
50
50
+
.and_then(|h| h.to_str().ok())
51
51
+
.map(|s| s.split(',').map(|s| s.trim().to_string()).collect())
52
52
+
.unwrap_or_default();
53
53
+
54
54
+
Ok(AtpAcceptLabelers(labelers))
55
55
+
}
56
56
+
}
57
57
+
58
58
+
/// Authentication error.
59
59
+
#[derive(Debug)]
60
60
+
pub enum AuthError {
61
61
+
/// No Authorization header was provided
62
62
+
MissingToken,
63
63
+
/// The Authorization header was malformed
64
64
+
InvalidToken,
65
65
+
/// The token failed validation
66
66
+
ValidationFailed(String),
67
67
+
}
68
68
+
69
69
+
impl IntoResponse for AuthError {
70
70
+
fn into_response(self) -> Response {
71
71
+
let (status, message) = match self {
72
72
+
AuthError::MissingToken => (StatusCode::UNAUTHORIZED, "Missing authentication token"),
73
73
+
AuthError::InvalidToken => (StatusCode::UNAUTHORIZED, "Invalid authentication token"),
74
74
+
AuthError::ValidationFailed(_msg) => {
75
75
+
(StatusCode::UNAUTHORIZED, "Token validation failed")
76
76
+
}
77
77
+
};
78
78
+
79
79
+
(status, message).into_response()
80
80
+
}
81
81
+
}
82
82
+
83
83
+
/// Wrapper for extracting XRPC request types.
84
84
+
///
85
85
+
/// Placeholder until jacquard_axum integration is complete.
86
86
+
#[derive(Debug)]
87
87
+
pub struct ExtractXrpc<T>(pub T);
88
88
+
89
89
+
impl<S, T> FromRequestParts<S> for ExtractXrpc<T>
90
90
+
where
91
91
+
S: Send + Sync,
92
92
+
T: for<'de> Deserialize<'de> + Send,
93
93
+
{
94
94
+
type Rejection = StatusCode;
95
95
+
96
96
+
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
97
97
+
let query = parts.uri.query().unwrap_or("");
98
98
+
let value = serde_urlencoded::from_str(query).map_err(|_| StatusCode::BAD_REQUEST)?;
99
99
+
100
100
+
Ok(ExtractXrpc(value))
101
101
+
}
102
102
+
}
+104
parakeet-appview/src/xrpc/context.rs
reviewed
···
1
1
+
use std::sync::Arc;
2
2
+
3
3
+
use deadpool::managed::Pool;
4
4
+
use diesel_async::pooled_connection::AsyncDieselConnectionManager;
5
5
+
use diesel_async::AsyncPgConnection;
6
6
+
7
7
+
use crate::entity::{EntityStore, PostEntity, ProfileEntity};
8
8
+
use super::rate_limiter::RateLimiter;
9
9
+
10
10
+
/// Global application state.
11
11
+
pub struct GlobalState {
12
12
+
pub pool: Arc<Pool<AsyncDieselConnectionManager<AsyncPgConnection>>>,
13
13
+
pub profile_entity: Arc<EntityStore<ProfileEntity>>,
14
14
+
pub post_entity: Arc<EntityStore<PostEntity>>,
15
15
+
pub http_client: reqwest::Client,
16
16
+
pub rate_limiter: RateLimiter,
17
17
+
}
18
18
+
19
19
+
/// Context for XRPC handlers.
20
20
+
#[derive(Clone)]
21
21
+
pub struct XrpcContext {
22
22
+
pub state: Arc<GlobalState>,
23
23
+
pub auth_did: Option<String>,
24
24
+
pub viewer_id: Option<i32>,
25
25
+
pub labelers: Vec<String>,
26
26
+
}
27
27
+
28
28
+
impl XrpcContext {
29
29
+
pub fn new(state: Arc<GlobalState>) -> Self {
30
30
+
Self {
31
31
+
state,
32
32
+
auth_did: None,
33
33
+
viewer_id: None,
34
34
+
labelers: Vec::new(),
35
35
+
}
36
36
+
}
37
37
+
38
38
+
pub fn with_auth(mut self, did: String, viewer_id: Option<i32>) -> Self {
39
39
+
self.auth_did = Some(did);
40
40
+
self.viewer_id = viewer_id;
41
41
+
self
42
42
+
}
43
43
+
44
44
+
pub fn with_labelers(mut self, labelers: Vec<String>) -> Self {
45
45
+
self.labelers = labelers;
46
46
+
self
47
47
+
}
48
48
+
49
49
+
pub fn is_authenticated(&self) -> bool {
50
50
+
self.auth_did.is_some()
51
51
+
}
52
52
+
53
53
+
pub fn viewer_did(&self) -> Option<&str> {
54
54
+
self.auth_did.as_deref()
55
55
+
}
56
56
+
57
57
+
pub async fn db_conn(
58
58
+
&self,
59
59
+
) -> Result<
60
60
+
deadpool::managed::Object<AsyncDieselConnectionManager<AsyncPgConnection>>,
61
61
+
deadpool::managed::PoolError<diesel_async::pooled_connection::PoolError>,
62
62
+
> {
63
63
+
self.state.pool.get().await
64
64
+
}
65
65
+
}
66
66
+
67
67
+
/// Builder for XrpcContext.
68
68
+
pub struct XrpcContextBuilder {
69
69
+
state: Arc<GlobalState>,
70
70
+
auth_did: Option<String>,
71
71
+
viewer_id: Option<i32>,
72
72
+
labelers: Vec<String>,
73
73
+
}
74
74
+
75
75
+
impl XrpcContextBuilder {
76
76
+
pub fn new(state: Arc<GlobalState>) -> Self {
77
77
+
Self {
78
78
+
state,
79
79
+
auth_did: None,
80
80
+
viewer_id: None,
81
81
+
labelers: Vec::new(),
82
82
+
}
83
83
+
}
84
84
+
85
85
+
pub fn auth(mut self, did: String, viewer_id: Option<i32>) -> Self {
86
86
+
self.auth_did = Some(did);
87
87
+
self.viewer_id = viewer_id;
88
88
+
self
89
89
+
}
90
90
+
91
91
+
pub fn labelers(mut self, labelers: Vec<String>) -> Self {
92
92
+
self.labelers = labelers;
93
93
+
self
94
94
+
}
95
95
+
96
96
+
pub fn build(self) -> XrpcContext {
97
97
+
XrpcContext {
98
98
+
state: self.state,
99
99
+
auth_did: self.auth_did,
100
100
+
viewer_id: self.viewer_id,
101
101
+
labelers: self.labelers,
102
102
+
}
103
103
+
}
104
104
+
}
+130
parakeet-appview/src/xrpc/error.rs
reviewed
···
1
1
+
use axum::http::StatusCode;
2
2
+
use axum::response::{IntoResponse, Response};
3
3
+
use serde_json::json;
4
4
+
use thiserror::Error;
5
5
+
6
6
+
/// XRPC error type for the appview
7
7
+
///
8
8
+
/// This is a simplified error type that can be extended to integrate
9
9
+
/// with Jacquard error types when needed.
10
10
+
#[derive(Debug, Error)]
11
11
+
pub enum XrpcError {
12
12
+
/// Resource not found error (404)
13
13
+
#[error("Not found: {0}")]
14
14
+
NotFound(String),
15
15
+
16
16
+
/// Malformed or invalid request syntax (400)
17
17
+
#[error("Bad request: {0}")]
18
18
+
BadRequest(String),
19
19
+
20
20
+
/// Request validation failed (400)
21
21
+
#[error("Invalid request: {0}")]
22
22
+
InvalidRequest(String),
23
23
+
24
24
+
/// Authentication required or failed (401)
25
25
+
#[error("Unauthorized: {0}")]
26
26
+
Unauthorized(String),
27
27
+
28
28
+
/// User lacks permission for this operation (403)
29
29
+
#[error("Forbidden: {0}")]
30
30
+
Forbidden(String),
31
31
+
32
32
+
/// Requested actor/DID does not exist (404)
33
33
+
#[error("Actor not found: {0}")]
34
34
+
ActorNotFound(String),
35
35
+
36
36
+
/// Requested post does not exist (404)
37
37
+
#[error("Post not found: {0}")]
38
38
+
PostNotFound(String),
39
39
+
40
40
+
/// Requested feed does not exist (404)
41
41
+
#[error("Feed not found: {0}")]
42
42
+
FeedNotFound(String),
43
43
+
44
44
+
/// Request exceeded rate limits (429)
45
45
+
#[error("Rate limit exceeded")]
46
46
+
RateLimitExceeded,
47
47
+
48
48
+
/// Unexpected server error (500)
49
49
+
#[error("Internal server error: {0}")]
50
50
+
InternalServerError(String),
51
51
+
52
52
+
/// Database operation failed (500)
53
53
+
#[error("Database error: {0}")]
54
54
+
DatabaseError(String),
55
55
+
56
56
+
/// Wrapper for other error types
57
57
+
#[error(transparent)]
58
58
+
Other(#[from] eyre::Report),
59
59
+
}
60
60
+
61
61
+
/// Result type for XRPC handlers
62
62
+
pub type XrpcResult<T> = Result<T, XrpcError>;
63
63
+
64
64
+
impl XrpcError {
65
65
+
/// Create an actor not found error
66
66
+
pub fn actor_not_found(identifier: &str) -> Self {
67
67
+
Self::ActorNotFound(format!("Actor not found: {}", identifier))
68
68
+
}
69
69
+
70
70
+
/// Create a post not found error
71
71
+
pub fn post_not_found(uri: &str) -> Self {
72
72
+
Self::PostNotFound(format!("Post not found: {}", uri))
73
73
+
}
74
74
+
75
75
+
/// Create a feed not found error
76
76
+
pub fn feed_not_found(uri: &str) -> Self {
77
77
+
Self::FeedNotFound(format!("Feed not found: {}", uri))
78
78
+
}
79
79
+
80
80
+
/// Create a bad request error
81
81
+
pub fn bad_request(message: &str) -> Self {
82
82
+
Self::BadRequest(message.to_string())
83
83
+
}
84
84
+
85
85
+
/// Get the HTTP status code for this error
86
86
+
pub fn status_code(&self) -> StatusCode {
87
87
+
match self {
88
88
+
Self::NotFound(_) => StatusCode::NOT_FOUND,
89
89
+
Self::BadRequest(_) | Self::InvalidRequest(_) => StatusCode::BAD_REQUEST,
90
90
+
Self::Unauthorized(_) => StatusCode::UNAUTHORIZED,
91
91
+
Self::Forbidden(_) => StatusCode::FORBIDDEN,
92
92
+
Self::ActorNotFound(_) | Self::PostNotFound(_) | Self::FeedNotFound(_) => {
93
93
+
StatusCode::NOT_FOUND
94
94
+
}
95
95
+
Self::RateLimitExceeded => StatusCode::TOO_MANY_REQUESTS,
96
96
+
Self::InternalServerError(_) | Self::DatabaseError(_) | Self::Other(_) => {
97
97
+
StatusCode::INTERNAL_SERVER_ERROR
98
98
+
}
99
99
+
}
100
100
+
}
101
101
+
}
102
102
+
103
103
+
impl IntoResponse for XrpcError {
104
104
+
fn into_response(self) -> Response {
105
105
+
let status = self.status_code();
106
106
+
let body = json!({
107
107
+
"error": self.to_string(),
108
108
+
});
109
109
+
110
110
+
(status, axum::Json(body)).into_response()
111
111
+
}
112
112
+
}
113
113
+
114
114
+
/// Extension trait for converting errors into XrpcError
115
115
+
pub trait IntoXrpcError {
116
116
+
/// Convert this error type into an XrpcError
117
117
+
fn into_xrpc_error(self) -> XrpcError;
118
118
+
}
119
119
+
120
120
+
impl IntoXrpcError for eyre::Report {
121
121
+
fn into_xrpc_error(self) -> XrpcError {
122
122
+
XrpcError::Other(self)
123
123
+
}
124
124
+
}
125
125
+
126
126
+
impl IntoXrpcError for diesel::result::Error {
127
127
+
fn into_xrpc_error(self) -> XrpcError {
128
128
+
XrpcError::DatabaseError(self.to_string())
129
129
+
}
130
130
+
}
+245
parakeet-appview/src/xrpc/helpers.rs
reviewed
···
1
1
+
use chrono::{DateTime, Utc};
2
2
+
use eyre::{eyre, Result};
3
3
+
use serde::{Deserialize, Serialize};
4
4
+
use std::collections::HashMap;
5
5
+
6
6
+
/// Pagination helpers for XRPC endpoints
7
7
+
/// Standard pagination parameters used across XRPC endpoints
8
8
+
#[derive(Debug, Clone, Deserialize)]
9
9
+
#[serde(rename_all = "camelCase")]
10
10
+
pub struct PaginationParams {
11
11
+
/// Maximum number of items to return in a single page
12
12
+
pub limit: Option<i32>,
13
13
+
/// Opaque cursor for fetching the next page of results
14
14
+
pub cursor: Option<String>,
15
15
+
}
16
16
+
17
17
+
impl PaginationParams {
18
18
+
/// Get the limit value, clamped to a range with a default
19
19
+
pub fn get_limit(&self, default: usize, min: usize, max: usize) -> usize {
20
20
+
self.limit
21
21
+
.map(|l| l as usize)
22
22
+
.unwrap_or(default)
23
23
+
.clamp(min, max)
24
24
+
}
25
25
+
26
26
+
/// Parse a datetime cursor
27
27
+
pub fn parse_datetime_cursor(&self) -> Option<DateTime<Utc>> {
28
28
+
self.cursor
29
29
+
.as_ref()
30
30
+
.and_then(|c| c.parse::<i64>().ok())
31
31
+
.and_then(DateTime::from_timestamp_millis)
32
32
+
}
33
33
+
34
34
+
/// Parse an integer cursor
35
35
+
pub fn parse_int_cursor(&self) -> Option<i64> {
36
36
+
self.cursor.as_ref().and_then(|c| c.parse().ok())
37
37
+
}
38
38
+
}
39
39
+
40
40
+
/// Response with pagination support
41
41
+
#[derive(Debug, Clone, Serialize)]
42
42
+
#[serde(rename_all = "camelCase")]
43
43
+
pub struct PaginatedResponse<T> {
44
44
+
/// The list of items for the current page
45
45
+
pub items: Vec<T>,
46
46
+
/// Cursor for fetching the next page, if more items exist
47
47
+
#[serde(skip_serializing_if = "Option::is_none")]
48
48
+
pub cursor: Option<String>,
49
49
+
}
50
50
+
51
51
+
impl<T> PaginatedResponse<T> {
52
52
+
/// Create a paginated response from items
53
53
+
pub fn new(mut items: Vec<T>, limit: usize) -> Self {
54
54
+
let has_more = items.len() > limit;
55
55
+
if has_more {
56
56
+
items.truncate(limit);
57
57
+
}
58
58
+
59
59
+
Self {
60
60
+
items,
61
61
+
cursor: None, // Caller should set if has_more
62
62
+
}
63
63
+
}
64
64
+
65
65
+
/// Set the cursor if there are more items
66
66
+
pub fn with_cursor<F>(mut self, has_more: bool, cursor_fn: F) -> Self
67
67
+
where
68
68
+
F: FnOnce(&T) -> String,
69
69
+
{
70
70
+
if has_more && !self.items.is_empty() {
71
71
+
self.cursor = Some(cursor_fn(self.items.last().unwrap()));
72
72
+
}
73
73
+
self
74
74
+
}
75
75
+
}
76
76
+
77
77
+
/// Helper to paginate a vector of items
78
78
+
pub fn paginate<T>(
79
79
+
items: Vec<T>,
80
80
+
limit: usize,
81
81
+
cursor_fn: impl Fn(&T) -> String,
82
82
+
) -> PaginatedResponse<T> {
83
83
+
let has_more = items.len() > limit;
84
84
+
let mut items = items;
85
85
+
86
86
+
if has_more {
87
87
+
items.truncate(limit);
88
88
+
}
89
89
+
90
90
+
let cursor = if has_more && !items.is_empty() {
91
91
+
Some(cursor_fn(items.last().unwrap()))
92
92
+
} else {
93
93
+
None
94
94
+
};
95
95
+
96
96
+
PaginatedResponse { items, cursor }
97
97
+
}
98
98
+
99
99
+
/// AT URI parsing helpers
100
100
+
/// Parse an AT URI into its components
101
101
+
///
102
102
+
/// Note: This is a simplified implementation.
103
103
+
/// In production, use the proper AtUri parser from jacquard_common.
104
104
+
pub fn parse_at_uri(uri: &str) -> Result<(String, String, String)> {
105
105
+
// Simple parsing: at://did/collection/rkey
106
106
+
if !uri.starts_with("at://") {
107
107
+
return Err(eyre!("Not an AT URI"));
108
108
+
}
109
109
+
110
110
+
let without_scheme = &uri[5..]; // Remove "at://"
111
111
+
let parts: Vec<&str> = without_scheme.split('/').collect();
112
112
+
113
113
+
if parts.len() < 3 {
114
114
+
return Err(eyre!("AT URI path incomplete"));
115
115
+
}
116
116
+
117
117
+
let did = parts[0].to_string();
118
118
+
let collection = parts[1].to_string();
119
119
+
let rkey = parts[2].to_string();
120
120
+
121
121
+
Ok((did, collection, rkey))
122
122
+
}
123
123
+
124
124
+
/// Build an AT URI from components.
///
/// Produces `at://<did>/<collection>/<rkey>`; the inverse of
/// `parse_at_uri` for well-formed components.
pub fn build_at_uri(did: &str, collection: &str, rkey: &str) -> String {
    let mut uri = String::with_capacity(5 + did.len() + collection.len() + rkey.len() + 2);
    uri.push_str("at://");
    uri.push_str(did);
    uri.push('/');
    uri.push_str(collection);
    uri.push('/');
    uri.push_str(rkey);
    uri
}
128
128
+
129
129
+
/// Identifier resolution helpers.
///
/// Check whether a string looks like a DID, i.e. uses the `did:` scheme.
pub fn is_did(identifier: &str) -> bool {
    identifier.strip_prefix("did:").is_some()
}
134
134
+
135
135
+
/// Check if a string is a valid handle.
///
/// Basic validation only: a handle must contain at least one dot and must
/// not use the `did:` scheme (the same prefix test `is_did` performs,
/// inlined here).
pub fn is_handle(identifier: &str) -> bool {
    if identifier.starts_with("did:") {
        return false;
    }
    identifier.contains('.')
}
140
140
+
141
141
+
/// Normalize a handle by dropping a leading `@`, if one is present.
pub fn normalize_handle(handle: &str) -> &str {
    match handle.strip_prefix('@') {
        Some(bare) => bare,
        None => handle,
    }
}
145
145
+
146
146
+
/// Cursor encoding/decoding
147
147
+
/// Encode a timestamp as a cursor
148
148
+
pub fn encode_timestamp_cursor(ts: DateTime<Utc>) -> String {
149
149
+
ts.timestamp_millis().to_string()
150
150
+
}
151
151
+
152
152
+
/// Encode an integer as an opaque cursor string.
pub fn encode_int_cursor(value: i64) -> String {
    format!("{value}")
}
156
156
+
157
157
+
/// Encode a compound cursor with multiple values.
///
/// Produces `k1=v1&k2=v2` in the order given. Keys and values are joined
/// verbatim with no escaping — NOTE(review): values containing `&` (or
/// keys containing `=`) will not round-trip through
/// `decode_compound_cursor`; confirm callers never pass such values.
pub fn encode_compound_cursor(values: &[(&str, &str)]) -> String {
    let mut cursor = String::new();
    for (key, value) in values {
        if !cursor.is_empty() {
            cursor.push('&');
        }
        cursor.push_str(key);
        cursor.push('=');
        cursor.push_str(value);
    }
    cursor
}
165
165
+
166
166
+
/// Decode a compound cursor of the form `k1=v1&k2=v2` into a key/value map.
///
/// Each segment is split on the *first* `=` only, so values that themselves
/// contain `=` survive a round-trip through `encode_compound_cursor` (the
/// previous `split('=')` implementation truncated `"k=a=b"` to `"a"`).
/// Segments with no `=` at all are skipped.
pub fn decode_compound_cursor(cursor: &str) -> HashMap<String, String> {
    cursor
        .split('&')
        .filter_map(|part| {
            part.split_once('=')
                .map(|(k, v)| (k.to_string(), v.to_string()))
        })
        .collect()
}
179
179
+
180
180
+
/// Order preservation helpers.
///
/// Preserve the order of items based on the order of `ids`.
///
/// IDs with no matching item are silently skipped; if an ID occurs more
/// than once in `ids`, the matching item is cloned for each occurrence.
pub fn preserve_order<Id, Item>(ids: &[Id], items: Vec<(Id, Item)>) -> Vec<Item>
where
    Id: Eq + std::hash::Hash + Clone,
    Item: Clone,
{
    // Consume the Vec directly; the original `mut items` + `drain(..)` was
    // a needless extra step for a by-value parameter.
    let item_map: HashMap<Id, Item> = items.into_iter().collect();

    ids.iter()
        .filter_map(|id| item_map.get(id).cloned())
        .collect()
}
196
196
+
197
197
+
/// Preserve order with optional items.
///
/// Like `preserve_order`, but keeps a slot per requested ID: missing items
/// appear as `None`, so the output is always the same length as `ids`.
pub fn preserve_order_optional<Id, Item>(ids: &[Id], items: Vec<(Id, Item)>) -> Vec<Option<Item>>
where
    Id: Eq + std::hash::Hash + Clone,
    Item: Clone,
{
    // Consume the Vec directly; the original `mut items` + `drain(..)` was
    // a needless extra step for a by-value parameter.
    let item_map: HashMap<Id, Item> = items.into_iter().collect();

    ids.iter()
        .map(|id| item_map.get(id).cloned())
        .collect()
}
212
212
+
213
213
+
#[cfg(test)]
214
214
+
mod tests {
215
215
+
use super::*;
216
216
+
217
217
+
#[test]
218
218
+
fn test_parse_at_uri() {
219
219
+
let uri = "at://did:plc:example/app.bsky.feed.post/3k2yv5";
220
220
+
let (did, collection, rkey) = parse_at_uri(uri).unwrap();
221
221
+
222
222
+
assert_eq!(did, "did:plc:example");
223
223
+
assert_eq!(collection, "app.bsky.feed.post");
224
224
+
assert_eq!(rkey, "3k2yv5");
225
225
+
}
226
226
+
227
227
+
#[test]
228
228
+
fn test_pagination() {
229
229
+
let items = vec![1, 2, 3, 4, 5];
230
230
+
let response = paginate(items, 3, |i| i.to_string());
231
231
+
232
232
+
assert_eq!(response.items, vec![1, 2, 3]);
233
233
+
assert_eq!(response.cursor, Some("3".to_string()));
234
234
+
}
235
235
+
236
236
+
#[test]
237
237
+
fn test_compound_cursor() {
238
238
+
let cursor = encode_compound_cursor(&[("id", "123"), ("ts", "456")]);
239
239
+
assert_eq!(cursor, "id=123&ts=456");
240
240
+
241
241
+
let decoded = decode_compound_cursor(&cursor);
242
242
+
assert_eq!(decoded.get("id"), Some(&"123".to_string()));
243
243
+
assert_eq!(decoded.get("ts"), Some(&"456".to_string()));
244
244
+
}
245
245
+
}
+208
parakeet-appview/src/xrpc/rate_limiter.rs
reviewed
···
1
1
+
//! Rate limiting for XRPC endpoints.
2
2
+
3
3
+
use dashmap::DashMap;
4
4
+
use std::str::FromStr;
5
5
+
use std::sync::Arc;
6
6
+
use std::time::{Duration, Instant};
7
7
+
8
8
+
/// Rate limiter for tracking request rates per client.
9
9
+
#[derive(Clone)]
10
10
+
pub struct RateLimiter {
11
11
+
/// Stores rate limit state per key (e.g., DID, IP address)
12
12
+
buckets: Arc<DashMap<String, TokenBucket>>,
13
13
+
/// Default configuration for all rate limits
14
14
+
default_config: RateLimitConfig,
15
15
+
}
16
16
+
17
17
+
/// Configuration for a rate limit.
18
18
+
#[derive(Clone, Debug)]
19
19
+
pub struct RateLimitConfig {
20
20
+
/// Maximum number of requests
21
21
+
pub max_requests: u32,
22
22
+
/// Time window for the requests
23
23
+
pub window: Duration,
24
24
+
}
25
25
+
26
26
+
impl FromStr for RateLimitConfig {
27
27
+
type Err = String;
28
28
+
29
29
+
/// Parse from a string like "100/60s" or "1000/1h"
30
30
+
fn from_str(s: &str) -> Result<Self, Self::Err> {
31
31
+
let parts: Vec<&str> = s.split('/').collect();
32
32
+
if parts.len() != 2 {
33
33
+
return Err(format!("Invalid rate limit format: {}", s));
34
34
+
}
35
35
+
36
36
+
let max_requests = parts[0]
37
37
+
.parse::<u32>()
38
38
+
.map_err(|_| format!("Invalid request count: {}", parts[0]))?;
39
39
+
40
40
+
let window =
41
41
+
parse_duration(parts[1]).ok_or_else(|| format!("Invalid duration: {}", parts[1]))?;
42
42
+
43
43
+
Ok(RateLimitConfig {
44
44
+
max_requests,
45
45
+
window,
46
46
+
})
47
47
+
}
48
48
+
}
49
49
+
50
50
+
/// Parse duration strings like "60s", "5m", "1h".
///
/// Returns `None` for anything that does not match `<number><unit>` with a
/// unit of `s` (seconds), `m` (minutes) or `h` (hours), or when the result
/// would overflow `u64` seconds.
fn parse_duration(s: &str) -> Option<Duration> {
    // strip_suffix avoids the panics the previous `split_at(s.len() - 1)`
    // had: usize underflow on an empty string, and a non-char-boundary
    // split when the final character is multi-byte.
    let (num_str, factor) = if let Some(n) = s.strip_suffix('s') {
        (n, 1)
    } else if let Some(n) = s.strip_suffix('m') {
        (n, 60)
    } else if let Some(n) = s.strip_suffix('h') {
        (n, 3600)
    } else {
        return None;
    };

    let num = num_str.parse::<u64>().ok()?;
    Some(Duration::from_secs(num.checked_mul(factor)?))
}
62
62
+
63
63
+
/// Token bucket for rate limiting.
64
64
+
struct TokenBucket {
65
65
+
tokens: u32,
66
66
+
last_refill: Instant,
67
67
+
}
68
68
+
69
69
+
impl RateLimiter {
70
70
+
/// Create a new rate limiter with default configuration.
71
71
+
pub fn new() -> Self {
72
72
+
Self {
73
73
+
buckets: Arc::new(DashMap::new()),
74
74
+
default_config: RateLimitConfig {
75
75
+
max_requests: 100,
76
76
+
window: Duration::from_secs(60), // 100 requests per minute default
77
77
+
},
78
78
+
}
79
79
+
}
80
80
+
81
81
+
/// Check if a request is allowed for the given key and configuration.
82
82
+
pub fn check(&self, key: &str, config_str: Option<&str>) -> bool {
83
83
+
let config = if let Some(s) = config_str {
84
84
+
match s.parse::<RateLimitConfig>() {
85
85
+
Ok(c) => c,
86
86
+
Err(e) => {
87
87
+
tracing::warn!("Invalid rate limit config '{}': {}, using default", s, e);
88
88
+
self.default_config.clone()
89
89
+
}
90
90
+
}
91
91
+
} else {
92
92
+
self.default_config.clone()
93
93
+
};
94
94
+
95
95
+
let mut entry = self
96
96
+
.buckets
97
97
+
.entry(key.to_string())
98
98
+
.or_insert_with(|| TokenBucket {
99
99
+
tokens: config.max_requests,
100
100
+
last_refill: Instant::now(),
101
101
+
});
102
102
+
103
103
+
let now = Instant::now();
104
104
+
let time_passed = now.duration_since(entry.last_refill);
105
105
+
106
106
+
// Refill tokens based on time passed
107
107
+
if time_passed >= config.window {
108
108
+
// Full refill
109
109
+
entry.tokens = config.max_requests;
110
110
+
entry.last_refill = now;
111
111
+
} else {
112
112
+
// Partial refill
113
113
+
let refill_rate = config.max_requests as f64 / config.window.as_secs_f64();
114
114
+
let tokens_to_add = (time_passed.as_secs_f64() * refill_rate) as u32;
115
115
+
if tokens_to_add > 0 {
116
116
+
entry.tokens = (entry.tokens + tokens_to_add).min(config.max_requests);
117
117
+
entry.last_refill = now;
118
118
+
}
119
119
+
}
120
120
+
121
121
+
// Check if we have tokens available
122
122
+
if entry.tokens > 0 {
123
123
+
entry.tokens -= 1;
124
124
+
true
125
125
+
} else {
126
126
+
false
127
127
+
}
128
128
+
}
129
129
+
130
130
+
/// Clear rate limit data for a specific key.
131
131
+
pub fn clear(&self, key: &str) {
132
132
+
self.buckets.remove(key);
133
133
+
}
134
134
+
135
135
+
/// Clear all rate limit data.
136
136
+
pub fn clear_all(&self) {
137
137
+
self.buckets.clear();
138
138
+
}
139
139
+
}
140
140
+
141
141
+
impl Default for RateLimiter {
142
142
+
fn default() -> Self {
143
143
+
Self::new()
144
144
+
}
145
145
+
}
146
146
+
147
147
+
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_rate_limit_config_parsing() {
        // Valid "count/duration" strings parse into the expected config.
        let valid: [(&str, u32, u64); 3] = [
            ("100/60s", 100, 60),
            ("1000/5m", 1000, 300),
            ("5000/1h", 5000, 3600),
        ];
        for (input, max_requests, secs) in valid {
            let config: RateLimitConfig = input.parse().unwrap();
            assert_eq!(config.max_requests, max_requests);
            assert_eq!(config.window, Duration::from_secs(secs));
        }

        // Malformed strings (no slash, missing window, unknown unit) are rejected.
        for bad in ["invalid", "100", "100/60x"] {
            assert!(bad.parse::<RateLimitConfig>().is_err());
        }
    }

    #[test]
    fn test_rate_limiting() {
        let limiter = RateLimiter::new();
        let key = "test_user";
        let config = "2/1s"; // two requests per second

        // The bucket starts full: exactly two requests go through,
        // and a third within the window is rejected.
        assert!(limiter.check(key, Some(config)));
        assert!(limiter.check(key, Some(config)));
        assert!(!limiter.check(key, Some(config)));

        // After a full window has elapsed the bucket is replenished
        // and the same pattern repeats.
        thread::sleep(Duration::from_millis(1100));
        assert!(limiter.check(key, Some(config)));
        assert!(limiter.check(key, Some(config)));
        assert!(!limiter.check(key, Some(config)));
    }

    #[test]
    fn test_different_keys() {
        let limiter = RateLimiter::new();
        let config = "1/1s";

        // Each key owns an independent bucket, so one request per key succeeds...
        for key in ["user1", "user2", "user3"] {
            assert!(limiter.check(key, Some(config)));
        }
        // ...and each bucket is exhausted after its single allowed request.
        for key in ["user1", "user2", "user3"] {
            assert!(!limiter.check(key, Some(config)));
        }
    }
}