+28
-20
src/auth.rs
+28
-20
src/auth.rs
···
5
5
};
6
6
use axum::{extract::FromRequestParts, http::StatusCode};
7
7
use base64::Engine as _;
8
-
use sha2::{Digest, Sha256};
8
+
use sha2::{Digest as _, Sha256};
9
9
10
10
use crate::{AppState, Error, error::ErrorMessage};
11
11
···
39
39
Error::with_status(StatusCode::UNAUTHORIZED, anyhow!("no authorization header"))
40
40
})?;
41
41
42
-
let auth_str = auth_header.to_str().map_err(|_| {
42
+
let auth_str = auth_header.to_str().map_err(|e| {
43
43
Error::with_status(
44
44
StatusCode::UNAUTHORIZED,
45
-
anyhow!("authorization header should be valid utf-8"),
45
+
anyhow!("authorization header should be valid utf-8").context(e),
46
46
)
47
47
})?;
48
48
···
52
52
53
53
// Handle different token types
54
54
if auth_str.starts_with("Bearer ") || auth_str.starts_with("DPoP ") {
55
-
let token = auth_str.splitn(2, ' ').nth(1).unwrap();
55
+
let token = auth_str
56
+
.split_once(' ')
57
+
.expect("Auth string should have a space")
58
+
.1;
59
+
56
60
if has_dpop {
57
61
// Process DPoP token - the Authorization header contains the access token
58
62
// and the DPoP header contains the proof
59
-
let dpop_token = dpop_header.unwrap().to_str().map_err(|_| {
60
-
Error::with_status(
61
-
StatusCode::UNAUTHORIZED,
62
-
anyhow!("DPoP header should be valid utf-8"),
63
-
)
64
-
})?;
63
+
let dpop_token = dpop_header
64
+
.expect("DPoP header should exist")
65
+
.to_str()
66
+
.map_err(|e| {
67
+
Error::with_status(
68
+
StatusCode::UNAUTHORIZED,
69
+
anyhow!("DPoP header should be valid utf-8").context(e),
70
+
)
71
+
})?;
65
72
66
73
return validate_dpop_token(token, dpop_token, parts, state).await;
67
-
} else {
68
-
// Standard Bearer token
69
-
return validate_bearer_token(token, state).await;
70
74
}
75
+
76
+
// Standard Bearer token
77
+
return validate_bearer_token(token, state).await;
71
78
}
72
79
73
80
// If we reach here, no valid authorization method was found
···
139
146
}
140
147
}
141
148
149
+
#[expect(clippy::too_many_lines, reason = "validating dpop has many loc")]
142
150
/// Validate a DPoP token and proof
143
151
async fn validate_dpop_token(
144
152
access_token: &str,
···
165
173
}
166
174
167
175
// Check token expiration
168
-
if let Some(exp) = claims.get("exp").and_then(|v| v.as_i64()) {
176
+
if let Some(exp) = claims.get("exp").and_then(serde_json::Value::as_i64) {
169
177
let now = chrono::Utc::now().timestamp();
170
178
if now >= exp {
171
179
return Err(Error::with_message(
···
187
195
188
196
// Decode header
189
197
let dpop_header_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
190
-
.decode(dpop_parts[0])
198
+
.decode(dpop_parts.first().context("header part missing")?)
191
199
.context("failed to decode DPoP header")?;
192
200
193
201
let dpop_header: serde_json::Value =
···
203
211
204
212
// Decode claims
205
213
let dpop_claims_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
206
-
.decode(dpop_parts[1])
214
+
.decode(dpop_parts.get(1).context("claims part missing")?)
207
215
.context("failed to decode DPoP claims")?;
208
216
209
217
let dpop_claims: serde_json::Value =
···
280
288
for prop in required_props {
281
289
let value = jwk
282
290
.get(prop)
283
-
.context(format!("JWK missing required property: {}", prop))?;
284
-
drop(canonical_jwk.insert(prop.to_string(), value.clone()));
291
+
.context(format!("JWK missing required property: {prop}"))?;
292
+
drop(canonical_jwk.insert((*prop).to_owned(), value.clone()));
285
293
}
286
294
287
295
// Serialize with no whitespace
···
336
344
// Get expiry from token or default to 60 seconds
337
345
let exp = dpop_claims
338
346
.get("exp")
339
-
.and_then(|e| e.as_i64())
340
-
.unwrap_or_else(|| timestamp + 60);
347
+
.and_then(serde_json::Value::as_i64)
348
+
.unwrap_or_else(|| timestamp.checked_add(60).unwrap_or(timestamp));
341
349
342
350
_ = sqlx::query!(
343
351
r#"
+5
-2
src/endpoints/sync.rs
+5
-2
src/endpoints/sync.rs
···
234
234
active,
235
235
status,
236
236
did: input.did.clone(),
237
-
rev: Some(atrium_api::types::string::Tid::new(r.rev).unwrap()),
237
+
rev: Some(
238
+
atrium_api::types::string::Tid::new(r.rev).expect("should be able to convert Tid"),
239
+
),
238
240
}
239
241
.into(),
240
242
))
···
371
373
head: atrium_api::types::string::Cid::new(
372
374
Cid::from_str(&r.root).expect("should be a valid CID"),
373
375
),
374
-
rev: atrium_api::types::string::Tid::new(r.rev).unwrap(),
376
+
rev: atrium_api::types::string::Tid::new(r.rev)
377
+
.expect("should be able to convert Tid"),
375
378
status: None,
376
379
}
377
380
.into()
+4
-4
src/firehose.rs
+4
-4
src/firehose.rs
···
81
81
},
82
82
}
83
83
84
-
impl Into<sync::subscribe_repos::RepoOp> for RepoOp {
85
-
fn into(self) -> sync::subscribe_repos::RepoOp {
86
-
let (action, cid, prev, path) = match self {
84
+
impl From<RepoOp> for sync::subscribe_repos::RepoOp {
85
+
fn from(val: RepoOp) -> Self {
86
+
let (action, cid, prev, path) = match val {
87
87
RepoOp::Create { cid, path } => ("create", Some(cid), None, path),
88
88
RepoOp::Update { cid, path, prev } => ("update", Some(cid), Some(prev), path),
89
89
RepoOp::Delete { path, prev } => ("delete", None, Some(prev), path),
···
131
131
prev_data: val.pcid.map(atrium_api::types::CidLink),
132
132
rebase: false,
133
133
repo: val.did,
134
-
rev: Tid::new(val.rev).unwrap(),
134
+
rev: Tid::new(val.rev).expect("should be valid revision"),
135
135
seq: 0,
136
136
since: None,
137
137
time: Datetime::now(),
+104
-95
src/oauth.rs
+104
-95
src/oauth.rs
···
3
3
use crate::metrics::AUTH_FAILED;
4
4
use crate::{AppConfig, AppState, Client, Db, Error, Result, SigningKey};
5
5
use anyhow::{Context as _, anyhow};
6
-
use argon2::{Argon2, PasswordHash, PasswordVerifier};
7
-
use atrium_crypto::keypair::Did;
6
+
use argon2::{Argon2, PasswordHash, PasswordVerifier as _};
7
+
use atrium_crypto::keypair::Did as _;
8
8
use axum::response::Redirect;
9
9
use axum::{
10
10
Json, Router, extract,
···
13
13
response::IntoResponse,
14
14
routing::{get, post},
15
15
};
16
-
use base64::Engine;
16
+
use base64::Engine as _;
17
17
use metrics::counter;
18
18
use rand::distributions::Alphanumeric;
19
-
use rand::{Rng, thread_rng};
19
+
use rand::{Rng as _, thread_rng};
20
20
use serde::{Deserialize, Serialize};
21
21
use serde_json::{Value, json};
22
-
use sha2::Digest;
22
+
use sha2::Digest as _;
23
23
use std::collections::{HashMap, HashSet};
24
24
25
25
/// JWK thumbprint required properties for each key type (RFC7638)
···
31
31
("RSA", &["e", "kty", "n"]),
32
32
];
33
33
34
+
/// JWT ID used record for tracking used JTIs to prevent replay attacks
35
+
#[derive(Debug, Serialize, Deserialize)]
36
+
struct JtiRecord {
37
+
expires_at: i64,
38
+
issuer: String,
39
+
jti: String,
40
+
}
41
+
34
42
/// Parses a JWT without validation and returns header and claims
35
43
fn parse_jwt(token: &str) -> Result<(Value, Value)> {
36
44
let parts: Vec<&str> = token.split('.').collect();
···
42
50
}
43
51
44
52
let header_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
45
-
.decode(parts[0])
53
+
.decode(parts.first().expect("should have JWT header"))
46
54
.context("Failed to decode JWT header")?;
47
55
48
56
let header: Value =
49
57
serde_json::from_slice(&header_bytes).context("Failed to parse JWT header as JSON")?;
50
58
51
59
let claims_bytes = base64::engine::general_purpose::URL_SAFE_NO_PAD
52
-
.decode(parts[1])
60
+
.decode(parts.get(1).expect("should have JWT claims"))
53
61
.context("Failed to decode JWT claims")?;
54
62
55
63
let claims: Value =
···
75
83
// Find required properties for this key type
76
84
let required_props = JWK_REQUIRED_PROPS
77
85
.iter()
78
-
.find(|(kty, _)| *kty == key_type)
79
-
.map(|(_, props)| *props)
80
-
.context(format!("Unsupported key type: {}", key_type))?;
86
+
.find(|&&(kty, _)| kty == key_type)
87
+
.map(|&(_, props)| props)
88
+
.context(anyhow!("Unsupported key type: {key_type}"))?;
81
89
82
90
// Build a new JWK with only the required properties
83
91
let mut canonical_jwk = serde_json::Map::new();
84
92
85
-
for prop in required_props {
93
+
for &prop in required_props {
86
94
let value = jwk
87
95
.get(prop)
88
-
.context(format!("JWK missing required property: {}", prop))?;
89
-
drop(canonical_jwk.insert(prop.to_string(), value.clone()));
96
+
.context(anyhow!("JWK missing required property: {prop}"))?;
97
+
drop(canonical_jwk.insert((*prop).to_string(), value.clone()));
90
98
}
91
99
92
100
// Serialize with no whitespace
···
149
157
})))
150
158
}
151
159
152
-
/// Fetch and validate client metadata from client_id URL
160
+
/// Fetch and validate client metadata from `client_id` URL
153
161
async fn fetch_client_metadata(client: &Client, client_id: &str) -> Result<Value> {
154
162
// Handle localhost development
155
163
if client_id.starts_with("http://localhost") {
···
167
175
});
168
176
169
177
// Extract redirect_uri from query params if available
170
-
let redirect_uris = if let Some(query) = client_url.query() {
171
-
let pairs: HashMap<_, _> = url::form_urlencoded::parse(query.as_bytes()).collect();
172
-
if let Some(uri) = pairs.get("redirect_uri") {
173
-
vec![json!(uri)]
174
-
} else {
178
+
let redirect_uris = client_url.query().map_or_else(
179
+
|| {
175
180
vec![
176
181
json!("http://127.0.0.1/callback"),
177
182
json!("http://[::1]/callback"),
178
183
]
179
-
}
180
-
} else {
181
-
vec![
182
-
json!("http://127.0.0.1/callback"),
183
-
json!("http://[::1]/callback"),
184
-
]
185
-
};
184
+
},
185
+
|query| {
186
+
let pairs: HashMap<_, _> = url::form_urlencoded::parse(query.as_bytes()).collect();
187
+
pairs.get("redirect_uri").map_or_else(
188
+
|| {
189
+
vec![
190
+
json!("http://127.0.0.1/callback"),
191
+
json!("http://[::1]/callback"),
192
+
]
193
+
},
194
+
|uri| vec![json!(uri)],
195
+
)
196
+
},
197
+
);
186
198
187
-
metadata["redirect_uris"] = json!(redirect_uris);
199
+
if let Some(redirect_uris_value) = metadata.as_object_mut() {
200
+
drop(redirect_uris_value.insert("redirect_uris".to_owned(), json!(redirect_uris)));
201
+
}
202
+
188
203
return Ok(metadata);
189
204
}
190
205
···
221
236
// Validate DPoP tokens requirement
222
237
if !metadata
223
238
.get("dpop_bound_access_tokens")
224
-
.and_then(|v| v.as_bool())
239
+
.and_then(Value::as_bool)
225
240
.unwrap_or(false)
226
241
{
227
242
return Err(Error::with_status(
···
233
248
Ok(metadata)
234
249
}
235
250
236
-
/// JWT ID used record for tracking used JTIs to prevent replay attacks
237
-
#[derive(Debug, Serialize, Deserialize)]
238
-
struct JtiRecord {
239
-
jti: String,
240
-
issuer: String,
241
-
expires_at: i64,
242
-
}
243
-
244
251
/// Pushed Authorization Request endpoint
245
252
/// POST `/oauth/par`
253
+
#[expect(clippy::too_many_lines)]
246
254
async fn par(
247
255
State(db): State<Db>,
248
256
State(client): State<Client>,
···
296
304
.and_then(|uris| uris.as_array())
297
305
.context("client metadata missing redirect_uris")?;
298
306
299
-
let uri_valid = allowed_uris.iter().any(|uri| {
300
-
uri.as_str()
301
-
.map_or(false, |uri_str| uri_str == provided_uri)
302
-
});
307
+
let uri_valid = allowed_uris
308
+
.iter()
309
+
.any(|uri| uri.as_str().is_some_and(|uri_str| uri_str == provided_uri));
303
310
304
311
if !uri_valid {
305
312
return Err(Error::with_status(
···
310
317
} else if client_metadata
311
318
.get("redirect_uris")
312
319
.and_then(|uris| uris.as_array())
313
-
.map_or(0, |uris| uris.len())
320
+
.map_or(0, Vec::len)
314
321
!= 1
315
322
{
316
323
return Err(Error::with_status(
···
340
347
.take(32)
341
348
.map(char::from)
342
349
.collect::<String>();
343
-
let request_uri = format!("urn:ietf:params:oauth:request_uri:req-{}", request_id);
350
+
let request_uri = format!("urn:ietf:params:oauth:request_uri:req-{request_id}");
344
351
345
352
// Store request data in the database
346
353
let now = chrono::Utc::now();
···
378
385
379
386
Ok(Json(json!({
380
387
"request_uri": request_uri,
381
-
"expires_in": 300 // 5 minutes
388
+
"expires_in": 300_i32 // 5 minutes
382
389
})))
383
390
}
384
391
···
482
489
483
490
/// OAuth Authorization Sign-in endpoint
484
491
/// POST `/oauth/authorize/sign-in`
492
+
#[expect(clippy::too_many_lines)]
485
493
async fn authorize_signin(
486
494
State(db): State<Db>,
487
495
State(config): State<AppConfig>,
488
496
State(client): State<Client>,
489
497
extract::Form(form_data): extract::Form<HashMap<String, String>>,
490
498
) -> Result<impl IntoResponse> {
499
+
use std::fmt::Write as _;
500
+
491
501
// Extract form data
492
502
let username = form_data.get("username").context("username is required")?;
493
503
let password = form_data.get("password").context("password is required")?;
···
539
549
.context("failed to query database")?
540
550
.context("user not found")?;
541
551
542
-
// Verify password
543
-
match Argon2::default().verify_password(
552
+
// Verify password - fixed to use equality check instead of pattern matching
553
+
if Argon2::default().verify_password(
544
554
password.as_bytes(),
545
555
&PasswordHash::new(account.password.as_str()).context("invalid password hash in db")?,
546
-
) {
547
-
Ok(()) => {}
548
-
Err(_) => {
549
-
counter!(AUTH_FAILED).increment(1);
550
-
return Err(Error::with_status(
551
-
StatusCode::UNAUTHORIZED,
552
-
anyhow!("invalid credentials"),
553
-
));
554
-
}
556
+
) == Ok(())
557
+
{
558
+
} else {
559
+
counter!(AUTH_FAILED).increment(1);
560
+
return Err(Error::with_status(
561
+
StatusCode::UNAUTHORIZED,
562
+
anyhow!("invalid credentials"),
563
+
));
555
564
}
556
565
557
566
// Generate authorization code
···
562
571
.collect::<String>();
563
572
564
573
// Determine redirect URI to use
565
-
let redirect_uri = if let Some(uri) = &par_request.redirect_uri {
574
+
let redirect_uri = if let Some(uri) = par_request.redirect_uri.as_ref() {
566
575
uri.clone()
567
576
} else {
568
577
let client_metadata = fetch_client_metadata(&client, client_id).await?;
···
572
581
.and_then(|uris| uris.first())
573
582
.and_then(|uri| uri.as_str())
574
583
.context("No redirect_uri available")?
575
-
.to_string()
584
+
.to_owned()
576
585
};
577
586
578
587
// Store the authorization code
···
615
624
});
616
625
617
626
// Build redirect URL
618
-
let mut redirect_url = redirect_uri;
619
-
match par_request.response_mode {
620
-
None => redirect_url.push_str("?"), // Default to query
621
-
Some(response_mode) => match response_mode.as_str() {
622
-
"query" => redirect_url.push_str("?"),
623
-
"fragment" => redirect_url.push_str("#"),
624
-
_ => redirect_url.push_str("?"), // Default to query
625
-
},
626
-
};
627
-
redirect_url.push_str(&format!("state={}", urlencoding::encode(&state)));
627
+
let mut redirect_target = redirect_uri;
628
+
match par_request.response_mode.as_deref() {
629
+
Some("fragment") => redirect_target.push('#'),
630
+
_ => redirect_target.push('?'),
631
+
}
632
+
633
+
write!(redirect_target, "state={}", urlencoding::encode(&state)).unwrap();
628
634
let host_name = format!("https://{}", &config.host_name);
629
-
redirect_url.push_str(&format!("&iss={}", urlencoding::encode(&host_name)));
630
-
redirect_url.push_str(&format!("&code={}", urlencoding::encode(&code)));
631
-
Ok(Redirect::to(&redirect_url))
635
+
write!(redirect_target, "&iss={}", urlencoding::encode(&host_name)).unwrap();
636
+
write!(redirect_target, "&code={}", urlencoding::encode(&code)).unwrap();
637
+
Ok(Redirect::to(&redirect_target))
632
638
}
633
639
634
640
/// Verify a DPoP proof and return the JWK thumbprint
···
650
656
let (header, claims) = parse_jwt(dpop_token)?;
651
657
652
658
// Verify "typ" is "dpop+jwt"
653
-
if header.get("typ").and_then(|t| t.as_str()) != Some("dpop+jwt") {
659
+
if header.get("typ").and_then(Value::as_str) != Some("dpop+jwt") {
654
660
return Err(Error::with_status(
655
661
StatusCode::BAD_REQUEST,
656
662
anyhow!("Invalid DPoP token type"),
···
660
666
// Verify required claims
661
667
let jti = claims
662
668
.get("jti")
663
-
.and_then(|j| j.as_str())
669
+
.and_then(Value::as_str)
664
670
.context("Missing jti claim in DPoP token")?;
665
671
666
672
// Check for token expiration
673
+
#[expect(clippy::arithmetic_side_effects)]
667
674
let exp = claims
668
675
.get("exp")
669
-
.and_then(|e| e.as_i64())
676
+
.and_then(Value::as_i64)
670
677
.unwrap_or_else(|| chrono::Utc::now().timestamp() + 60); // Default 60s if not specified
671
678
672
679
let now = chrono::Utc::now().timestamp();
···
678
685
}
679
686
680
687
// Check htm (HTTP method) claim
681
-
if claims.get("htm").and_then(|m| m.as_str()) != Some(http_method) {
688
+
if claims.get("htm").and_then(Value::as_str) != Some(http_method) {
682
689
return Err(Error::with_status(
683
690
StatusCode::BAD_REQUEST,
684
691
anyhow!("Invalid htm claim in DPoP token"),
···
686
693
}
687
694
688
695
// Check htu (HTTP URI) claim
689
-
if claims.get("htu").and_then(|u| u.as_str()) != Some(http_uri) {
696
+
if claims.get("htu").and_then(Value::as_str) != Some(http_uri) {
690
697
return Err(Error::with_status(
691
698
StatusCode::BAD_REQUEST,
692
699
anyhow!(
693
700
"Invalid htu claim in DPoP token: expected {}, got {}",
694
701
http_uri,
695
-
claims.get("htu").and_then(|u| u.as_str()).unwrap_or("None")
702
+
claims.get("htu").and_then(Value::as_str).unwrap_or("None")
696
703
),
697
704
));
698
705
}
···
739
746
.context("failed to store JTI")?;
740
747
741
748
// Cleanup expired JTIs periodically (1% chance on each request)
742
-
if thread_rng().gen_range(0..100) == 0 {
749
+
if thread_rng().gen_range(0_i32..100_i32) == 0_i32 {
743
750
_ = sqlx::query!(r#"DELETE FROM oauth_used_jtis WHERE expires_at < ?"#, now)
744
751
.execute(db)
745
752
.await
···
749
756
Ok(thumbprint)
750
757
}
751
758
752
-
/// Verify a code_verifier against stored code_challenge
759
+
/// Verify a `code_verifier` against stored `code_challenge`
753
760
fn verify_pkce(code_verifier: &str, stored_challenge: &str, method: &str) -> Result<()> {
754
761
// Only S256 is supported currently
755
762
if method != "S256" {
···
778
785
/// OAuth token endpoint
779
786
/// - POST `/oauth/token`
780
787
///
781
-
/// Handles both authorization_code and refresh_token grants
788
+
/// Handles both `authorization_code` and `refresh_token` grants
789
+
#[expect(clippy::too_many_lines)]
782
790
async fn token(
783
791
State(db): State<Db>,
784
792
State(skey): State<SigningKey>,
···
857
865
// Generate tokens
858
866
let now = chrono::Utc::now().timestamp();
859
867
let access_token_expires_in = 3600; // 1 hour
868
+
#[expect(clippy::arithmetic_side_effects)]
860
869
let access_token_expires_at = now + access_token_expires_in;
861
-
let refresh_token_expires_at = now + 2592000; // 30 days
870
+
#[expect(clippy::arithmetic_side_effects)]
871
+
let refresh_token_expires_at = now + 2_592_000; // 30 days
862
872
863
873
// Create access token
864
874
let access_token_claims = json!({
···
961
971
// Generate new tokens
962
972
let now = chrono::Utc::now().timestamp();
963
973
let access_token_expires_in = 3600; // 1 hour
974
+
#[expect(clippy::arithmetic_side_effects)]
964
975
let access_token_expires_at = now + access_token_expires_in;
965
-
let refresh_token_expires_at = now + 2592000; // 30 days
976
+
#[expect(clippy::arithmetic_side_effects)]
977
+
let refresh_token_expires_at = now + 2_592_000; // 30 days
966
978
967
979
// Create access token
968
980
let access_token_claims = json!({
···
1043
1055
// For a real implementation, you would construct a proper JWK
1044
1056
// with all the required fields based on the key type
1045
1057
1046
-
let key_did = skey.did();
1058
+
let did_string = skey.did();
1047
1059
1048
1060
// Extract the key ID from the DID string
1049
1061
// did:key:z... format, where z... is the multibase-encoded public key
1050
-
let key_id = key_did.strip_prefix("did:key:").unwrap_or(&key_did);
1062
+
let key_id = did_string.strip_prefix("did:key:").unwrap_or(&did_string);
1051
1063
1052
1064
let jwk = json!({
1053
1065
"kty": "EC",
···
1077
1089
) -> Result<Json<Value>> {
1078
1090
// Extract required parameters
1079
1091
let token = form_data.get("token").context("token is required")?;
1080
-
let refresh_token_string = "refresh_token".to_string();
1092
+
let refresh_token_string = "refresh_token".to_owned();
1081
1093
let token_type_hint = form_data
1082
1094
.get("token_type_hint")
1083
1095
.unwrap_or(&refresh_token_string);
···
1117
1129
let token_type_hint = form_data.get("token_type_hint");
1118
1130
1119
1131
// Parse the token
1120
-
let (typ, claims) = match crate::auth::verify(&skey.did(), token) {
1121
-
Ok(result) => result,
1122
-
Err(_) => {
1123
-
// Per RFC7662, invalid tokens return { "active": false }
1124
-
return Ok(Json(json!({"active": false})));
1125
-
}
1132
+
let Ok((typ, claims)) = crate::auth::verify(&skey.did(), token) else {
1133
+
// Per RFC7662, invalid tokens return { "active": false }
1134
+
return Ok(Json(json!({"active": false})));
1126
1135
};
1127
1136
1128
1137
// Check token type
···
1143
1152
}
1144
1153
1145
1154
// Check expiration
1146
-
if let Some(exp) = claims.get("exp").and_then(|v| v.as_i64()) {
1155
+
if let Some(exp) = claims.get("exp").and_then(Value::as_i64) {
1147
1156
let now = chrono::Utc::now().timestamp();
1148
1157
if now >= exp {
1149
1158
return Ok(Json(json!({"active": false})));
···
1169
1178
}
1170
1179
1171
1180
// Token is valid, return introspection info
1172
-
let subject = claims.get("sub").and_then(|v| v.as_str());
1173
-
let client_id = claims.get("aud").and_then(|v| v.as_str());
1174
-
let scope = claims.get("scope").and_then(|v| v.as_str());
1175
-
let expiration = claims.get("exp").and_then(|v| v.as_i64());
1176
-
let issued_at = claims.get("iat").and_then(|v| v.as_i64());
1181
+
let subject = claims.get("sub").and_then(Value::as_str);
1182
+
let client_id = claims.get("aud").and_then(Value::as_str);
1183
+
let scope = claims.get("scope").and_then(Value::as_str);
1184
+
let expiration = claims.get("exp").and_then(Value::as_i64);
1185
+
let issued_at = claims.get("iat").and_then(Value::as_i64);
1177
1186
1178
1187
Ok(Json(json!({
1179
1188
"active": true,
+135
-125
src/tests.rs
+135
-125
src/tests.rs
···
1
1
//! Testing utilities for the PDS.
2
-
2
+
#![expect(clippy::arbitrary_source_item_ordering)]
3
3
use std::{
4
4
net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener},
5
5
path::PathBuf,
6
-
sync::Arc,
7
6
time::{Duration, Instant},
8
7
};
9
8
10
-
use anyhow::{Context as _, Result};
9
+
use anyhow::Result;
11
10
use atrium_api::{
12
11
com::atproto::server,
13
-
types::{
14
-
Unknown,
15
-
string::{AtIdentifier, Did, Handle, Nsid, RecordKey},
16
-
},
12
+
types::string::{AtIdentifier, Did, Handle, Nsid, RecordKey},
17
13
};
18
14
use figment::{Figment, providers::Format as _};
19
15
use futures::future::join_all;
20
-
use rand::{Rng, thread_rng};
21
16
use serde::{Deserialize, Serialize};
22
17
use tokio::sync::OnceCell;
23
18
use uuid::Uuid;
24
19
25
-
use crate::{AppState, auth::AuthenticatedUser, config::AppConfig};
20
+
use crate::config::AppConfig;
26
21
27
22
/// Global test state, created once for all tests.
28
23
pub(crate) static TEST_STATE: OnceCell<TestState> = OnceCell::const_new();
···
49
44
50
45
impl Drop for TempDir {
51
46
fn drop(&mut self) {
52
-
let _ = std::fs::remove_dir_all(&self.path);
47
+
drop(std::fs::remove_dir_all(&self.path));
53
48
}
54
49
}
55
50
56
51
/// Test state for the application.
57
52
pub(crate) struct TestState {
58
-
/// The temporary directory for test data.
59
-
temp_dir: TempDir,
60
53
/// The address the test server is listening on.
61
54
address: SocketAddr,
55
+
/// The HTTP client.
56
+
client: reqwest::Client,
62
57
/// The application configuration.
63
58
config: AppConfig,
64
-
/// The HTTP client.
65
-
client: reqwest::Client,
59
+
/// The temporary directory for test data.
60
+
#[expect(dead_code)]
61
+
temp_dir: TempDir,
66
62
}
67
63
68
64
impl TestState {
69
-
/// Create a new test state.
70
-
async fn new() -> Result<Self> {
71
-
// Create a temporary directory for test data
72
-
let temp_dir = TempDir::new()?;
65
+
/// Get a base URL for the test server.
66
+
pub(crate) fn base_url(&self) -> String {
67
+
format!("http://{}", self.address)
68
+
}
69
+
70
+
/// Create a test account.
71
+
pub(crate) async fn create_test_account(&self) -> Result<TestAccount> {
72
+
// Create the account
73
+
let handle = "test.handle";
74
+
let response = self
75
+
.client
76
+
.post(format!(
77
+
"http://{}/xrpc/com.atproto.server.createAccount",
78
+
self.address
79
+
))
80
+
.json(&server::create_account::InputData {
81
+
did: None,
82
+
verification_code: None,
83
+
verification_phone: None,
84
+
email: Some(format!("{}@example.com", &handle)),
85
+
handle: Handle::new(handle.to_owned()).expect("should be able to create handle"),
86
+
password: Some("password123".to_owned()),
87
+
invite_code: None,
88
+
recovery_key: None,
89
+
plc_op: None,
90
+
})
91
+
.send()
92
+
.await?;
73
93
74
-
// Find a free port
75
-
let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?;
76
-
let address = listener.local_addr()?;
77
-
drop(listener);
94
+
let account: server::create_account::Output = response.json().await?;
78
95
96
+
Ok(TestAccount {
97
+
handle: handle.to_owned(),
98
+
did: account.did.to_string(),
99
+
access_token: account.access_jwt.clone(),
100
+
refresh_token: account.refresh_jwt.clone(),
101
+
})
102
+
}
103
+
104
+
/// Create a new test state.
105
+
#[expect(clippy::unused_async)]
106
+
async fn new() -> Result<Self> {
79
107
// Configure the test app
80
108
#[derive(Serialize, Deserialize)]
81
109
struct TestConfigInput {
110
+
db: Option<String>,
82
111
host_name: Option<String>,
83
-
db: Option<String>,
112
+
key: Option<PathBuf>,
84
113
listen_address: Option<SocketAddr>,
85
-
key: Option<PathBuf>,
86
114
test: Option<bool>,
87
115
}
116
+
// Create a temporary directory for test data
117
+
let temp_dir = TempDir::new()?;
118
+
119
+
// Find a free port
120
+
let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?;
121
+
let address = listener.local_addr()?;
122
+
drop(listener);
88
123
89
124
let test_config = TestConfigInput {
125
+
db: Some(format!("sqlite://{}/test.db", temp_dir.path().display())),
90
126
host_name: Some(format!("localhost:{}", address.port())),
91
-
db: Some(format!("sqlite://{}/test.db", temp_dir.path().display())),
127
+
key: Some(temp_dir.path().join("test.key")),
92
128
listen_address: Some(address),
93
-
key: Some(temp_dir.path().join("test.key")),
94
129
test: Some(true),
95
130
};
96
131
···
130
165
.build()?;
131
166
132
167
Ok(Self {
133
-
temp_dir,
134
168
address,
135
-
config,
136
169
client,
170
+
config,
171
+
temp_dir,
137
172
})
138
173
}
139
174
···
144
179
let address = self.address;
145
180
146
181
// Start the application in a background task
147
-
tokio::spawn(async move {
182
+
let _handle = tokio::spawn(async move {
148
183
// Set up the application
149
184
use crate::*;
150
185
151
186
// Initialize metrics (noop in test mode)
152
-
let _ = metrics::setup(None);
187
+
drop(metrics::setup(None));
153
188
154
189
// Create client
155
190
let simple_client = reqwest::Client::builder()
···
158
193
.context("failed to build requester client")?;
159
194
let client = reqwest_middleware::ClientBuilder::new(simple_client.clone())
160
195
.with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache {
161
-
mode: http_cache_reqwest::CacheMode::Default,
162
-
manager: http_cache_reqwest::MokaManager::default(),
163
-
options: http_cache_reqwest::HttpCacheOptions::default(),
196
+
mode: CacheMode::Default,
197
+
manager: MokaManager::default(),
198
+
options: HttpCacheOptions::default(),
164
199
}))
165
200
.build();
166
201
167
202
// Create a test keypair
168
-
std::fs::create_dir_all(&config.key.parent().context("should have parent")?)?;
203
+
std::fs::create_dir_all(config.key.parent().context("should have parent")?)?;
169
204
let (skey, rkey) = {
170
-
let skey =
171
-
atrium_crypto::keypair::Secp256k1Keypair::create(&mut rand::thread_rng());
172
-
let rkey =
173
-
atrium_crypto::keypair::Secp256k1Keypair::create(&mut rand::thread_rng());
205
+
let skey = Secp256k1Keypair::create(&mut rand::thread_rng());
206
+
let rkey = Secp256k1Keypair::create(&mut rand::thread_rng());
174
207
175
208
let keys = KeyData {
176
209
skey: skey.export(),
···
186
219
};
187
220
188
221
// Set up database
189
-
let opts = sqlx::sqlite::SqliteConnectOptions::from_str(&config.db)
222
+
let opts = SqliteConnectOptions::from_str(&config.db)
190
223
.context("failed to parse database options")?
191
224
.create_if_missing(true);
192
-
let db = sqlx::SqlitePool::connect_with(opts).await?;
225
+
let db = SqlitePool::connect_with(opts).await?;
193
226
194
227
sqlx::migrate!()
195
228
.run(&db)
···
212
245
};
213
246
214
247
// Create the router
215
-
let app = axum::Router::new()
216
-
.route("/", axum::routing::get(crate::index))
217
-
.merge(crate::oauth::routes())
248
+
let app = Router::new()
249
+
.route("/", get(index))
250
+
.merge(oauth::routes())
218
251
.nest(
219
252
"/xrpc",
220
-
crate::endpoints::routes()
221
-
.merge(crate::actor_endpoints::routes())
222
-
.fallback(crate::service_proxy),
253
+
endpoints::routes()
254
+
.merge(actor_endpoints::routes())
255
+
.fallback(service_proxy),
223
256
)
224
-
.layer(tower_http::cors::CorsLayer::permissive())
225
-
.layer(tower_http::trace::TraceLayer::new_for_http())
257
+
.layer(CorsLayer::permissive())
258
+
.layer(TraceLayer::new_for_http())
226
259
.with_state(app_state);
227
260
228
-
println!("Test server listening on {address}");
229
-
230
261
// Listen for connections
231
-
let listener = tokio::net::TcpListener::bind(&address)
262
+
let listener = TcpListener::bind(&address)
232
263
.await
233
264
.context("failed to bind address")?;
234
265
···
242
273
243
274
Ok(())
244
275
}
245
-
246
-
/// Create a test account.
247
-
pub async fn create_test_account(&self) -> Result<TestAccount> {
248
-
let handle = "test.handle";
249
-
println!("Creating test account with handle: {}", handle);
250
-
251
-
// Create the account
252
-
let response = self
253
-
.client
254
-
.post(&format!(
255
-
"http://{}/xrpc/com.atproto.server.createAccount",
256
-
self.address
257
-
))
258
-
.json(&server::create_account::InputData {
259
-
did: None,
260
-
verification_code: None,
261
-
verification_phone: None,
262
-
email: Some(format!("{}@example.com", &handle)),
263
-
handle: Handle::new(handle.to_owned()).expect("should be able to create handle"),
264
-
password: Some("password123".to_string()),
265
-
invite_code: None,
266
-
recovery_key: None,
267
-
plc_op: None,
268
-
})
269
-
.send()
270
-
.await?;
271
-
272
-
let account: server::create_account::Output = response.json().await?;
273
-
274
-
Ok(TestAccount {
275
-
handle: handle.to_owned(),
276
-
did: account.did.to_string(),
277
-
access_token: account.access_jwt.clone(),
278
-
refresh_token: account.refresh_jwt.clone(),
279
-
})
280
-
}
281
-
282
-
/// Get a base URL for the test server.
283
-
pub fn base_url(&self) -> String {
284
-
format!("http://{}", self.address)
285
-
}
286
276
}
287
277
288
278
/// A test account that can be used for testing.
289
-
pub struct TestAccount {
290
-
/// The account handle.
291
-
pub handle: String,
292
-
/// The account DID.
293
-
pub did: String,
279
+
pub(crate) struct TestAccount {
294
280
/// The access token for the account.
295
-
pub access_token: String,
281
+
pub(crate) access_token: String,
282
+
/// The account DID.
283
+
pub(crate) did: String,
284
+
/// The account handle.
285
+
pub(crate) handle: String,
296
286
/// The refresh token for the account.
297
-
pub refresh_token: String,
287
+
#[expect(dead_code)]
288
+
pub(crate) refresh_token: String,
298
289
}
299
290
300
291
/// Initialize the test state.
301
-
pub async fn init_test_state() -> Result<&'static TestState> {
302
-
TEST_STATE
303
-
.get_or_try_init(|| async {
304
-
let state = TestState::new().await?;
305
-
state.start_app().await?;
306
-
Ok(state)
307
-
})
308
-
.await
292
+
pub(crate) async fn init_test_state() -> Result<&'static TestState> {
293
+
async fn init_test_state() -> std::result::Result<TestState, anyhow::Error> {
294
+
let state = TestState::new().await?;
295
+
state.start_app().await?;
296
+
Ok(state)
297
+
}
298
+
TEST_STATE.get_or_try_init(init_test_state).await
309
299
}
310
300
311
301
/// Create a record benchmark that creates records and measures the time it takes.
312
-
pub async fn create_record_benchmark(count: usize, concurrent: usize) -> Result<Duration> {
302
+
#[expect(
303
+
clippy::arithmetic_side_effects,
304
+
clippy::integer_division,
305
+
clippy::integer_division_remainder_used,
306
+
clippy::use_debug,
307
+
clippy::print_stdout
308
+
)]
309
+
pub(crate) async fn create_record_benchmark(count: usize, concurrent: usize) -> Result<Duration> {
313
310
// Initialize the test state
314
311
let state = init_test_state().await?;
315
312
···
341
338
let record_idx = batch_idx * batch_size + i;
342
339
343
340
let result = client
344
-
.post(&format!("{}/xrpc/com.atproto.repo.createRecord", base_url))
345
-
.header("Authorization", format!("Bearer {}", access_token))
341
+
.post(format!("{base_url}/xrpc/com.atproto.repo.createRecord"))
342
+
.header("Authorization", format!("Bearer {access_token}"))
346
343
.json(&atrium_api::com::atproto::repo::create_record::InputData {
347
-
repo: AtIdentifier::Did(Did::new(account_did.clone()).unwrap()),
348
-
collection: Nsid::new("app.bsky.feed.post".to_string()).unwrap(),
349
-
rkey: Some(RecordKey::new(format!("test-{}", record_idx)).unwrap()),
344
+
repo: AtIdentifier::Did(Did::new(account_did.clone()).expect("valid DID")),
345
+
collection: Nsid::new("app.bsky.feed.post".to_owned()).expect("valid NSID"),
346
+
rkey: Some(
347
+
RecordKey::new(format!("test-{record_idx}")).expect("valid record key"),
348
+
),
350
349
validate: None,
351
350
record: serde_json::from_str(
352
351
&serde_json::json!({
353
352
"$type": "app.bsky.feed.post",
354
-
"text": format!("Test post {} from {}", record_idx, account_handle),
353
+
"text": format!("Test post {record_idx} from {account_handle}"),
355
354
"createdAt": chrono::Utc::now().to_rfc3339(),
356
355
})
357
356
.to_string(),
358
357
)
359
-
.unwrap(),
358
+
.expect("valid JSON record"),
360
359
swap_commit: None,
361
360
})
362
361
.send()
···
364
363
365
364
let request_duration = request_start.elapsed();
366
365
if record_idx % 10 == 0 {
367
-
println!("Created record {} in {:?}", record_idx, request_duration);
366
+
println!("Created record {record_idx} in {request_duration:?}");
368
367
}
369
368
results.push(result);
370
369
}
···
379
378
let results = join_all(handles).await;
380
379
381
380
// Check for errors
382
-
for result in results {
383
-
let batch_results = result?;
384
-
for result in batch_results {
385
-
match result {
381
+
for batch_result in results {
382
+
let batch_responses = batch_result?;
383
+
for response_result in batch_responses {
384
+
match response_result {
386
385
Ok(response) => {
387
386
if !response.status().is_success() {
388
387
return Err(anyhow::anyhow!(
···
403
402
}
404
403
405
404
#[cfg(test)]
405
+
#[expect(clippy::module_inception, clippy::use_debug, clippy::print_stdout)]
406
406
mod tests {
407
407
use super::*;
408
+
use anyhow::anyhow;
408
409
409
410
#[tokio::test]
410
411
async fn test_create_account() -> Result<()> {
···
412
413
let account = state.create_test_account().await?;
413
414
414
415
println!("Created test account: {}", account.handle);
415
-
assert!(!account.handle.is_empty());
416
-
assert!(!account.did.is_empty());
417
-
assert!(!account.access_token.is_empty());
416
+
if account.handle.is_empty() {
417
+
return Err(anyhow::anyhow!("Account handle is empty"));
418
+
}
419
+
if account.did.is_empty() {
420
+
return Err(anyhow::anyhow!("Account DID is empty"));
421
+
}
422
+
if account.access_token.is_empty() {
423
+
return Err(anyhow::anyhow!("Account access token is empty"));
424
+
}
418
425
419
426
Ok(())
420
427
}
···
423
430
async fn test_create_record_benchmark() -> Result<()> {
424
431
let duration = create_record_benchmark(100, 1).await?;
425
432
426
-
println!("Created 100 records in {:?}", duration);
427
-
assert!(duration.as_secs() < 100, "Benchmark took too long");
433
+
println!("Created 100 records in {duration:?}");
434
+
435
+
if duration.as_secs() >= 100 {
436
+
return Err(anyhow!("Benchmark took too long"));
437
+
}
428
438
429
439
Ok(())
430
440
}