Cargo.toml  (+1 -1)
src/account_manager/helpers/account.rs  (+17 -6)
···
 use thiserror::Error;

 use diesel::dsl::{LeftJoinOn, exists, not};
-use diesel::helper_types::{Eq, IntoBoxed};
+use diesel::helper_types::Eq;

 #[derive(Error, Debug)]
 pub enum AccountHelperError {
···
            })
            .await
            .expect("Failed to delete actor")?;
-        let did = did.to_owned();
+        let did_clone = did.to_owned();
        _ = db
            .get()
            .await?
            .interact(move |conn| {
                _ = delete(EmailTokenSchema::email_token)
-                    .filter(EmailTokenSchema::did.eq(&did))
+                    .filter(EmailTokenSchema::did.eq(&did_clone))
                    .execute(conn)?;
                _ = delete(RefreshTokenSchema::refresh_token)
-                    .filter(RefreshTokenSchema::did.eq(&did))
+                    .filter(RefreshTokenSchema::did.eq(&did_clone))
                    .execute(conn)?;
                _ = delete(AccountSchema::account)
-                    .filter(AccountSchema::did.eq(&did))
+                    .filter(AccountSchema::did.eq(&did_clone))
                    .execute(conn)?;
                delete(ActorSchema::actor)
-                    .filter(ActorSchema::did.eq(&did))
+                    .filter(ActorSchema::did.eq(&did_clone))
                    .execute(conn)
            })
            .await
            .expect("Failed to delete account")?;
+
+        let data_repo_file = format!("data/repo/{}.db", did.to_owned());
+        let data_blob_path = format!("data/blob/{}", did);
+        let data_blob_path = std::path::Path::new(&data_blob_path);
+        let data_repo_file = std::path::Path::new(&data_repo_file);
+        if data_repo_file.exists() {
+            std::fs::remove_file(data_repo_file)?;
+        };
+        if data_blob_path.exists() {
+            std::fs::remove_dir_all(data_blob_path)?;
+        };
        Ok(())
    }

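Note: the rename to `did_clone` is what makes the `move` closure compile while `did` stays usable afterwards for the on-disk cleanup: the closure takes ownership of the clone, and the original `&str` is still live to build the `data/repo`/`data/blob` paths. A minimal sketch of that ownership pattern (illustrative names, not code from this repo):

fn cleanup(did: &str) -> std::io::Result<()> {
    // The `move` closure must own its capture, so it gets a clone...
    let did_clone = did.to_owned();
    let task = move || println!("deleting rows for {did_clone}");
    task();
    // ...while the original `did` remains available for path construction.
    let repo = std::path::PathBuf::from(format!("data/repo/{did}.db"));
    if repo.exists() {
        std::fs::remove_file(&repo)?;
    }
    Ok(())
}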
src/account_manager/mod.rs  (+18 -15)
···
 //! blacksky-algorithms/rsky is licensed under the Apache License 2.0
 //!
 //! Modified for SQLite backend
-use crate::ActorPools;
 use crate::account_manager::helpers::account::{
     AccountStatus, ActorAccount, AvailabilityFlags, GetAccountAdminStatusOutput,
 };
···
 use crate::account_manager::helpers::invite::CodeDetail;
 use crate::account_manager::helpers::password::UpdateUserPasswordOpts;
 use crate::models::pds::EmailTokenPurpose;
+use crate::serve::ActorStorage;
 use anyhow::Result;
-use axum::extract::FromRef;
 use chrono::DateTime;
 use chrono::offset::Utc as UtcOffset;
 use cidv10::Cid;
-use deadpool_diesel::sqlite::Pool;
 use diesel::*;
 use futures::try_join;
 use helpers::{account, auth, email_token, invite, password, repo};
···
    pub async fn create_account(
        &self,
        opts: CreateAccountOpts,
-        actor_pools: &mut std::collections::HashMap<String, ActorPools>,
+        actor_pools: &mut std::collections::HashMap<String, ActorStorage>,
    ) -> Result<(String, String)> {
        let CreateAccountOpts {
            did,
···
        let did_path = did
            .strip_prefix("did:plc:")
            .ok_or_else(|| anyhow::anyhow!("Invalid DID"))?;
-        let path_repo = format!("sqlite://{}", did_path);
+        let repo_path = format!("sqlite://data/repo/{}.db", did_path);
        let actor_repo_pool =
-            crate::establish_pool(path_repo.as_str()).expect("Failed to establish pool");
-        let path_blob = path_repo.replace("repo", "blob");
-        let actor_blob_pool = crate::establish_pool(&path_blob).expect("Failed to establish pool");
-        let actor_pool = ActorPools {
+            crate::db::establish_pool(repo_path.as_str()).expect("Failed to establish pool");
+        let blob_path = std::path::Path::new("data/blob").to_path_buf();
+        let actor_pool = ActorStorage {
            repo: actor_repo_pool,
-            blob: actor_blob_pool,
+            blob: blob_path.clone(),
        };
-        actor_pools
-            .insert(did.clone(), actor_pool)
-            .expect("Failed to insert actor pools");
+        let blob_path = blob_path.join(did_path);
+        tokio::fs::create_dir_all(&blob_path)
+            .await
+            .map_err(|_| anyhow::anyhow!("Failed to create blob path"))?;
+        drop(
+            actor_pools
+                .insert(did.clone(), actor_pool)
+                .expect("Failed to insert actor pools"),
+        );
        let db = actor_pools
            .get(&did)
            .ok_or_else(|| anyhow::anyhow!("Actor not found"))?
···
        did: String,
        cid: Cid,
        rev: String,
-        actor_pools: &std::collections::HashMap<String, ActorPools>,
+        actor_pools: &std::collections::HashMap<String, ActorStorage>,
    ) -> Result<()> {
        let db = actor_pools
            .get(&did)
···
    pub async fn delete_account(
        &self,
        did: &str,
-        actor_pools: &std::collections::HashMap<String, ActorPools>,
+        actor_pools: &std::collections::HashMap<String, ActorStorage>,
    ) -> Result<()> {
        let db = actor_pools
            .get(did)
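Note: across these hunks the per-actor handle loses its second SQLite pool; blobs now live under a plain directory. The shape of `ActorStorage` below is inferred from its call sites in this diff (the real definition is in `src/serve.rs`, whose tail is truncated at the end of this diff):

// Inferred sketch, not the verbatim definition.
#[derive(Clone)]
pub struct ActorStorage {
    /// Connection pool for the actor's per-DID SQLite repo database.
    pub repo: deadpool_diesel::sqlite::Pool,
    /// Root directory holding the actor's blobs on the local filesystem.
    pub blob: std::path::PathBuf,
}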
src/actor_endpoints.rs  (+20 -7)
···
 /// We shouldn't have to know about any bsky endpoints to store private user data.
 /// This will _very likely_ be changed in the future.
 use atrium_api::app::bsky::actor;
-use axum::{Json, routing::post};
+use axum::{
+    Json, Router,
+    extract::State,
+    routing::{get, post},
+};
 use constcat::concat;
 use diesel::prelude::*;

-use crate::actor_store::ActorStore;
+use crate::{actor_store::ActorStore, auth::AuthenticatedUser};

-use super::*;
+use super::serve::*;

 async fn put_preferences(
     user: AuthenticatedUser,
-    State(actor_pools): State<std::collections::HashMap<String, ActorPools>>,
+    State(actor_pools): State<std::collections::HashMap<String, ActorStorage>>,
     Json(input): Json<actor::put_preferences::Input>,
 ) -> Result<()> {
     let did = user.did();
-    let json_string =
-        serde_json::to_string(&input.preferences).context("failed to serialize preferences")?;
+    // let json_string =
+    //     serde_json::to_string(&input.preferences).context("failed to serialize preferences")?;

     // let conn = &mut actor_pools
     //     .get(&did)
···
     //         .context("failed to update user preferences")
     //     });
     todo!("Use actor_store's preferences writer instead");
+    // let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await;
+    // let values = actor::defs::Preferences {
+    //     private_prefs: Some(json_string),
+    //     ..Default::default()
+    // };
+    // let namespace = actor::defs::PreferencesNamespace::Private;
+    // let scope = actor::defs::PreferencesScope::User;
+    // actor_store.pref.put_preferences(values, namespace, scope);
+
     Ok(())
 }

 async fn get_preferences(
     user: AuthenticatedUser,
-    State(actor_pools): State<std::collections::HashMap<String, ActorPools>>,
+    State(actor_pools): State<std::collections::HashMap<String, ActorStorage>>,
 ) -> Result<Json<actor::get_preferences::Output>> {
     let did = user.did();
     // let conn = &mut actor_pools
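Note: extracting `State<HashMap<String, ActorStorage>>` directly only works because the app state derives axum's `FromRef`, which projects one field out of the shared state (the removed `AppState` in src/lib.rs later in this diff derived `Clone, FromRef`; presumably serve.rs keeps that). A self-contained sketch of the mechanism (requires axum's `macros` feature; `ActorStorage` stubbed here):

use axum::extract::{FromRef, State};
use std::collections::HashMap;

#[derive(Clone)]
struct ActorStorage; // stand-in for the real struct

#[derive(Clone, FromRef)]
struct AppState {
    // FromRef lets handlers request this field directly as State<...>.
    db_actors: HashMap<String, ActorStorage>,
}

async fn handler(State(actor_pools): State<HashMap<String, ActorStorage>>) {
    let _ = actor_pools.len();
}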
src/actor_store/blob.rs  (+3 -3)
···
 use rsky_repo::types::{PreparedBlobRef, PreparedWrite};
 use std::str::FromStr as _;

-use super::sql_blob::{BlobStoreSql, ByteStream};
+use super::blob_fs::{BlobStoreFs, ByteStream};

 pub struct GetBlobOutput {
     pub size: i32,
···
 /// Handles blob operations for an actor store
 pub struct BlobReader {
     /// SQL-based blob storage
-    pub blobstore: BlobStoreSql,
+    pub blobstore: BlobStoreFs,
     /// DID of the actor
     pub did: String,
     /// Database connection
···
 impl BlobReader {
     /// Create a new blob reader
     pub fn new(
-        blobstore: BlobStoreSql,
+        blobstore: BlobStoreFs,
         db: deadpool_diesel::Pool<
             deadpool_diesel::Manager<SqliteConnection>,
             deadpool_diesel::sqlite::Object,
src/actor_store/blob_fs.rs  (+4 -5)
···
        let first_level = if cid_str.len() >= 10 {
            &cid_str[0..10]
        } else {
-            &cid_str
+            "short"
        };

        let second_level = if cid_str.len() >= 20 {
            &cid_str[10..20]
        } else {
-            "default"
+            "short"
        };

        self.base_dir
···
            async_fs::create_dir_all(parent).await?;
        }

-        // Copy first, then delete source after success
-        _ = async_fs::copy(&mov.from, &mov.to).await?;
-        async_fs::remove_file(&mov.from).await?;
+        // Move the file
+        async_fs::rename(&mov.from, &mov.to).await?;

        debug!("Moved blob: {:?} -> {:?}", mov.from, mov.to);
        Ok(())
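Note: `rename` is atomic and closes the window where the copy succeeded but the source delete failed; the trade-off is that it only works when source and destination are on the same filesystem. A hedged sketch of a fallback that keeps the fast path (assuming the same `async_fs` alias this file uses):

async fn move_file(from: &std::path::Path, to: &std::path::Path) -> std::io::Result<()> {
    // Fast path: atomic rename within one filesystem.
    if async_fs::rename(from, to).await.is_ok() {
        return Ok(());
    }
    // Fallback for cross-device moves (EXDEV): copy, then delete the source.
    async_fs::copy(from, to).await?;
    async_fs::remove_file(from).await
}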
src/actor_store/mod.rs  (+6 -6)
···
 use tokio::sync::RwLock;

 use blob::BlobReader;
+use blob_fs::BlobStoreFs;
 use preference::PreferenceReader;
 use record::RecordReader;
-use sql_blob::BlobStoreSql;
 use sql_repo::SqlRepoReader;

-use crate::ActorPools;
+use crate::serve::ActorStorage;

 #[derive(Debug)]
 enum FormatCommitError {
···

 // Combination of RepoReader/Transactor, BlobReader/Transactor, SqlRepoReader/Transactor
 impl ActorStore {
-    /// Concrete reader of an individual repo (hence BlobStoreSql which takes `did` param)
+    /// Concrete reader of an individual repo (hence BlobStoreFs which takes `did` param)
     pub fn new(
         did: String,
-        blobstore: BlobStoreSql,
+        blobstore: BlobStoreFs,
         db: deadpool_diesel::Pool<
             deadpool_diesel::Manager<SqliteConnection>,
             deadpool_diesel::sqlite::Object,
···
     /// Create a new ActorStore taking ActorPools HashMap as input
     pub async fn from_actor_pools(
         did: &String,
-        hashmap_actor_pools: &std::collections::HashMap<String, ActorPools>,
+        hashmap_actor_pools: &std::collections::HashMap<String, ActorStorage>,
     ) -> Self {
         let actor_pool = hashmap_actor_pools
             .get(did)
             .expect("Actor pool not found")
             .clone();
-        let blobstore = BlobStoreSql::new(did.clone(), actor_pool.blob);
+        let blobstore = BlobStoreFs::new(did.clone(), actor_pool.blob);
         let conn = actor_pool
             .repo
             .clone()
src/apis/com/atproto/repo/apply_writes.rs  (+6 -17)
···
 //! Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS.
-use crate::SharedSequencer;
+use crate::account_manager::AccountManager;
 use crate::account_manager::helpers::account::AvailabilityFlags;
-use crate::account_manager::{AccountManager, AccountManagerCreator, SharedAccountManager};
 use crate::{
-    ActorPools, AppState, SigningKey,
-    actor_store::{ActorStore, sql_blob::BlobStoreSql},
+    actor_store::ActorStore,
     auth::AuthenticatedUser,
-    config::AppConfig,
-    error::{ApiError, ErrorMessage},
+    error::ApiError,
+    serve::{ActorStorage, AppState},
 };
 use anyhow::{Result, bail};
-use axum::{
-    Json, Router,
-    body::Body,
-    extract::{Query, Request, State},
-    http::{self, StatusCode},
-    routing::{get, post},
-};
+use axum::{Json, extract::State};
 use cidv10::Cid;
-use deadpool_diesel::sqlite::Pool;
 use futures::stream::{self, StreamExt};
 use rsky_lexicon::com::atproto::repo::{ApplyWritesInput, ApplyWritesInputRefWrite};
-use rsky_pds::auth_verifier::AccessStandardIncludeChecks;
 use rsky_pds::repo::prepare::{
     PrepareCreateOpts, PrepareDeleteOpts, PrepareUpdateOpts, prepare_create, prepare_delete,
     prepare_update,
···
 use rsky_pds::sequencer::Sequencer;
 use rsky_repo::types::PreparedWrite;
 use std::str::FromStr;
-use std::sync::Arc;
 use tokio::sync::RwLock;

 async fn inner_apply_writes(
     body: ApplyWritesInput,
     user: AuthenticatedUser,
     sequencer: &RwLock<Sequencer>,
-    actor_pools: std::collections::HashMap<String, ActorPools>,
+    actor_pools: std::collections::HashMap<String, ActorStorage>,
     account_manager: &RwLock<AccountManager>,
 ) -> Result<()> {
     let tx: ApplyWritesInput = body;
src/apis/com/atproto/repo/mod.rs  (+1 -1)
src/apis/mod.rs  (+1 -1)
···
 use axum::{Json, Router, routing::get};
 use serde_json::json;

-use crate::{AppState, Result};
+use crate::serve::{AppState, Result};

 /// Health check endpoint. Returns name and version of the service.
 pub(crate) async fn health() -> Result<Json<serde_json::Value>> {
src/auth.rs  (+4 -1)
···
 use diesel::prelude::*;
 use sha2::{Digest as _, Sha256};

-use crate::{AppState, Error, error::ErrorMessage};
+use crate::{
+    error::{Error, ErrorMessage},
+    serve::AppState,
+};

 /// Request extractor for authenticated users.
 /// If specified in an API endpoint, this guarantees the API can only be called
src/did.rs  (+1 -1)
src/error.rs  (+12 -11)
···

 impl ApiError {
     /// Get the appropriate HTTP status code for this error
-    fn status_code(&self) -> StatusCode {
+    const fn status_code(&self) -> StatusCode {
         match self {
             Self::RuntimeError => StatusCode::INTERNAL_SERVER_ERROR,
             Self::InvalidLogin
···
             Self::BadRequest(error, _) => error,
             Self::AuthRequiredError(_) => "AuthRequiredError",
         }
-        .to_string()
+        .to_owned()
     }

     /// Get the user-facing error message
···
             Self::BadRequest(_, msg) => msg,
             Self::AuthRequiredError(msg) => msg,
         }
-        .to_string()
+        .to_owned()
     }
 }

 impl From<Error> for ApiError {
     fn from(_value: Error) -> Self {
-        ApiError::RuntimeError
+        Self::RuntimeError
     }
 }

 impl From<handle::errors::Error> for ApiError {
     fn from(value: handle::errors::Error) -> Self {
         match value.kind {
-            ErrorKind::InvalidHandle => ApiError::InvalidHandle,
-            ErrorKind::HandleNotAvailable => ApiError::HandleNotAvailable,
-            ErrorKind::UnsupportedDomain => ApiError::UnsupportedDomain,
-            ErrorKind::InternalError => ApiError::RuntimeError,
+            ErrorKind::InvalidHandle => Self::InvalidHandle,
+            ErrorKind::HandleNotAvailable => Self::HandleNotAvailable,
+            ErrorKind::UnsupportedDomain => Self::UnsupportedDomain,
+            ErrorKind::InternalError => Self::RuntimeError,
         }
     }
 }
···
         let error_type = self.error_type();
         let message = self.message();

-        // Log the error for debugging
-        error!("API Error: {}: {}", error_type, message);
+        if cfg!(debug_assertions) {
+            error!("API Error: {}: {}", error_type, message);
+        }

         // Create the error message and serialize to JSON
         let error_message = ErrorMessage::new(error_type, message);
         let body = serde_json::to_string(&error_message).unwrap_or_else(|_| {
-            r#"{"error":"InternalServerError","message":"Error serializing response"}"#.to_string()
+            r#"{"error":"InternalServerError","message":"Error serializing response"}"#.to_owned()
         });

         // Build the response
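Note: `cfg!(debug_assertions)` expands to a compile-time boolean, so release builds const-fold the logging branch away entirely rather than testing it per request. A small standalone illustration (not code from this repository):

fn log_api_error(error_type: &str, message: &str) {
    // `cfg!` is a boolean literal at compile time: debug builds always log,
    // release builds compile the branch out.
    if cfg!(debug_assertions) {
        eprintln!("API Error: {error_type}: {message}");
    }
}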
src/firehose.rs  (-426, file deleted)
···
-//! The firehose module.
-use std::{collections::VecDeque, time::Duration};
-
-use anyhow::{Result, bail};
-use atrium_api::{
-    com::atproto::sync::{self},
-    types::string::{Datetime, Did, Tid},
-};
-use atrium_repo::Cid;
-use axum::extract::ws::{Message, WebSocket};
-use metrics::{counter, gauge};
-use rand::Rng as _;
-use serde::{Serialize, ser::SerializeMap as _};
-use tracing::{debug, error, info, warn};
-
-use crate::{
-    Client,
-    config::AppConfig,
-    metrics::{FIREHOSE_HISTORY, FIREHOSE_LISTENERS, FIREHOSE_MESSAGES, FIREHOSE_SEQUENCE},
-};
-
-enum FirehoseMessage {
-    Broadcast(sync::subscribe_repos::Message),
-    Connect(Box<(WebSocket, Option<i64>)>),
-}
-
-enum FrameHeader {
-    Error,
-    Message(String),
-}
-
-impl Serialize for FrameHeader {
-    #[expect(clippy::question_mark_used, reason = "returns a Result")]
-    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
-    where
-        S: serde::Serializer,
-    {
-        let mut map = serializer.serialize_map(None)?;
-
-        match *self {
-            Self::Message(ref s) => {
-                map.serialize_key("op")?;
-                map.serialize_value(&1_i32)?;
-                map.serialize_key("t")?;
-                map.serialize_value(s.as_str())?;
-            }
-            Self::Error => {
-                map.serialize_key("op")?;
-                map.serialize_value(&-1_i32)?;
-            }
-        }
-
-        map.end()
-    }
-}
-
-/// A repository operation.
-pub(crate) enum RepoOp {
-    /// Create a new record.
-    Create {
-        /// The CID of the record.
-        cid: Cid,
-        /// The path of the record.
-        path: String,
-    },
-    /// Delete an existing record.
-    Delete {
-        /// The path of the record.
-        path: String,
-        /// The previous CID of the record.
-        prev: Cid,
-    },
-    /// Update an existing record.
-    Update {
-        /// The CID of the record.
-        cid: Cid,
-        /// The path of the record.
-        path: String,
-        /// The previous CID of the record.
-        prev: Cid,
-    },
-}
-
-impl From<RepoOp> for sync::subscribe_repos::RepoOp {
-    fn from(val: RepoOp) -> Self {
-        let (action, cid, prev, path) = match val {
-            RepoOp::Create { cid, path } => ("create", Some(cid), None, path),
-            RepoOp::Update { cid, path, prev } => ("update", Some(cid), Some(prev), path),
-            RepoOp::Delete { path, prev } => ("delete", None, Some(prev), path),
-        };
-
-        sync::subscribe_repos::RepoOpData {
-            action: action.to_owned(),
-            cid: cid.map(atrium_api::types::CidLink),
-            prev: prev.map(atrium_api::types::CidLink),
-            path,
-        }
-        .into()
-    }
-}
-
-/// A commit to the repository.
-pub(crate) struct Commit {
-    /// Blobs that were created in this commit.
-    pub blobs: Vec<Cid>,
-    /// The car file containing the commit blocks.
-    pub car: Vec<u8>,
-    /// The CID of the commit.
-    pub cid: Cid,
-    /// The DID of the repository changed.
-    pub did: Did,
-    /// The operations performed in this commit.
-    pub ops: Vec<RepoOp>,
-    /// The previous commit's CID (if applicable).
-    pub pcid: Option<Cid>,
-    /// The revision of the commit.
-    pub rev: String,
-}
-
-impl From<Commit> for sync::subscribe_repos::Commit {
-    fn from(val: Commit) -> Self {
-        sync::subscribe_repos::CommitData {
-            blobs: val
-                .blobs
-                .into_iter()
-                .map(atrium_api::types::CidLink)
-                .collect::<Vec<_>>(),
-            blocks: val.car,
-            commit: atrium_api::types::CidLink(val.cid),
-            ops: val.ops.into_iter().map(Into::into).collect::<Vec<_>>(),
-            prev_data: val.pcid.map(atrium_api::types::CidLink),
-            rebase: false,
-            repo: val.did,
-            rev: Tid::new(val.rev).expect("should be valid revision"),
-            seq: 0,
-            since: None,
-            time: Datetime::now(),
-            too_big: false,
-        }
-        .into()
-    }
-}
-
-/// A firehose producer. This is used to transmit messages to the firehose for broadcast.
-#[derive(Clone, Debug)]
-pub(crate) struct FirehoseProducer {
-    /// The channel to send messages to the firehose.
-    tx: tokio::sync::mpsc::Sender<FirehoseMessage>,
-}
-
-impl FirehoseProducer {
-    /// Broadcast an `#account` event.
-    pub(crate) async fn account(&self, account: impl Into<sync::subscribe_repos::Account>) {
-        drop(
-            self.tx
-                .send(FirehoseMessage::Broadcast(
-                    sync::subscribe_repos::Message::Account(Box::new(account.into())),
-                ))
-                .await,
-        );
-    }
-    /// Handle client connection.
-    pub(crate) async fn client_connection(&self, ws: WebSocket, cursor: Option<i64>) {
-        drop(
-            self.tx
-                .send(FirehoseMessage::Connect(Box::new((ws, cursor))))
-                .await,
-        );
-    }
-    /// Broadcast a `#commit` event.
-    pub(crate) async fn commit(&self, commit: impl Into<sync::subscribe_repos::Commit>) {
-        drop(
-            self.tx
-                .send(FirehoseMessage::Broadcast(
-                    sync::subscribe_repos::Message::Commit(Box::new(commit.into())),
-                ))
-                .await,
-        );
-    }
-    /// Broadcast an `#identity` event.
-    pub(crate) async fn identity(&self, identity: impl Into<sync::subscribe_repos::Identity>) {
-        drop(
-            self.tx
-                .send(FirehoseMessage::Broadcast(
-                    sync::subscribe_repos::Message::Identity(Box::new(identity.into())),
-                ))
-                .await,
-        );
-    }
-}
-
-#[expect(
-    clippy::as_conversions,
-    clippy::cast_possible_truncation,
-    clippy::cast_sign_loss,
-    clippy::cast_precision_loss,
-    clippy::arithmetic_side_effects
-)]
-/// Convert a `usize` to a `f64`.
-const fn convert_usize_f64(x: usize) -> Result<f64, &'static str> {
-    let result = x as f64;
-    if result as usize - x > 0 {
-        return Err("cannot convert");
-    }
-    Ok(result)
-}
-
-/// Serialize a message.
-fn serialize_message(seq: u64, mut msg: sync::subscribe_repos::Message) -> (&'static str, Vec<u8>) {
-    let mut dummy_seq = 0_i64;
-    #[expect(clippy::pattern_type_mismatch)]
-    let (ty, nseq) = match &mut msg {
-        sync::subscribe_repos::Message::Account(m) => ("#account", &mut m.seq),
-        sync::subscribe_repos::Message::Commit(m) => ("#commit", &mut m.seq),
-        sync::subscribe_repos::Message::Identity(m) => ("#identity", &mut m.seq),
-        sync::subscribe_repos::Message::Sync(m) => ("#sync", &mut m.seq),
-        sync::subscribe_repos::Message::Info(_m) => ("#info", &mut dummy_seq),
-    };
-    // Set the sequence number.
-    *nseq = i64::try_from(seq).expect("should find seq");
-
-    let hdr = FrameHeader::Message(ty.to_owned());
-
-    let mut frame = Vec::new();
-    serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header");
-    serde_ipld_dagcbor::to_writer(&mut frame, &msg).expect("should serialize message");
-
-    (ty, frame)
-}
-
-/// Broadcast a message out to all clients.
-async fn broadcast_message(clients: &mut Vec<WebSocket>, msg: Message) -> Result<()> {
-    counter!(FIREHOSE_MESSAGES).increment(1);
-
-    for i in (0..clients.len()).rev() {
-        let client = clients.get_mut(i).expect("should find client");
-        if let Err(e) = client.send(msg.clone()).await {
-            debug!("Firehose client disconnected: {e}");
-            drop(clients.remove(i));
-        }
-    }
-
-    gauge!(FIREHOSE_LISTENERS)
-        .set(convert_usize_f64(clients.len()).expect("should find clients length"));
-    Ok(())
-}
-
-/// Handle a new connection from a websocket client created by subscribeRepos.
-async fn handle_connect(
-    mut ws: WebSocket,
-    seq: u64,
-    history: &VecDeque<(u64, &str, sync::subscribe_repos::Message)>,
-    cursor: Option<i64>,
-) -> Result<WebSocket> {
-    if let Some(cursor) = cursor {
-        let mut frame = Vec::new();
-        let cursor = u64::try_from(cursor);
-        if cursor.is_err() {
-            tracing::warn!("cursor is not a valid u64");
-            return Ok(ws);
-        }
-        let cursor = cursor.expect("should be valid u64");
-        // Cursor specified; attempt to backfill the consumer.
-        if cursor > seq {
-            let hdr = FrameHeader::Error;
-            let msg = sync::subscribe_repos::Error::FutureCursor(Some(format!(
-                "cursor {cursor} is greater than the current sequence number {seq}"
-            )));
-            serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header");
-            serde_ipld_dagcbor::to_writer(&mut frame, &msg).expect("should serialize message");
-            // Drop the connection.
-            drop(ws.send(Message::binary(frame)).await);
-            bail!(
-                "connection dropped: cursor {cursor} is greater than the current sequence number {seq}"
-            );
-        }
-
-        for &(historical_seq, ty, ref msg) in history {
-            if cursor > historical_seq {
-                continue;
-            }
-            let hdr = FrameHeader::Message(ty.to_owned());
-            serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header");
-            serde_ipld_dagcbor::to_writer(&mut frame, msg).expect("should serialize message");
-            if let Err(e) = ws.send(Message::binary(frame.clone())).await {
-                debug!("Firehose client disconnected during backfill: {e}");
-                break;
-            }
-            // Clear out the frame to begin a new one.
-            frame.clear();
-        }
-    }
-
-    Ok(ws)
-}
-
-/// Reconnect to upstream relays.
-pub(crate) async fn reconnect_relays(client: &Client, config: &AppConfig) {
-    // Avoid connecting to upstream relays in test mode.
-    if config.test {
-        return;
-    }
-
-    info!("attempting to reconnect to upstream relays");
-    for relay in &config.firehose.relays {
-        let Some(host) = relay.host_str() else {
-            warn!("relay {} has no host specified", relay);
-            continue;
-        };
-
-        let r = client
-            .post(format!("https://{host}/xrpc/com.atproto.sync.requestCrawl"))
-            .json(&serde_json::json!({
-                "hostname": format!("https://{}", config.host_name)
-            }))
-            .send()
-            .await;
-
-        let r = match r {
-            Ok(r) => r,
-            Err(e) => {
-                error!("failed to hit upstream relay {host}: {e}");
-                continue;
-            }
-        };
-
-        let s = r.status();
-        if let Err(e) = r.error_for_status_ref() {
-            error!("failed to hit upstream relay {host}: {e}");
-        }
-
-        let b = r.json::<serde_json::Value>().await;
-        if let Ok(b) = b {
-            info!("relay {host}: {} {}", s, b);
-        } else {
-            info!("relay {host}: {}", s);
-        }
-    }
-}
-
-/// The main entrypoint for the firehose.
-///
-/// This will broadcast all updates in this PDS out to anyone who is listening.
-///
-/// Reference: <https://atproto.com/specs/sync>
-pub(crate) fn spawn(
-    client: Client,
-    config: AppConfig,
-) -> (tokio::task::JoinHandle<()>, FirehoseProducer) {
-    let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
-    let handle = tokio::spawn(async move {
-        fn time_since_inception() -> u64 {
-            chrono::Utc::now()
-                .timestamp_micros()
-                .checked_sub(1_743_442_000_000_000)
-                .expect("should not wrap")
-                .unsigned_abs()
-        }
-        let mut clients: Vec<WebSocket> = Vec::new();
-        let mut history = VecDeque::with_capacity(1000);
-        let mut seq = time_since_inception();
-
-        loop {
-            if let Ok(msg) = tokio::time::timeout(Duration::from_secs(30), rx.recv()).await {
-                match msg {
-                    Some(FirehoseMessage::Broadcast(msg)) => {
-                        let (ty, by) = serialize_message(seq, msg.clone());
-
-                        history.push_back((seq, ty, msg));
-                        gauge!(FIREHOSE_HISTORY).set(
-                            convert_usize_f64(history.len()).expect("should find history length"),
-                        );
-
-                        info!(
-                            "Broadcasting message {} {} to {} clients",
-                            seq,
-                            ty,
-                            clients.len()
-                        );
-
-                        counter!(FIREHOSE_SEQUENCE).absolute(seq);
-                        let now = time_since_inception();
-                        if now > seq {
-                            seq = now;
-                        } else {
-                            seq = seq.checked_add(1).expect("should not wrap");
-                        }
-
-                        drop(broadcast_message(&mut clients, Message::binary(by)).await);
-                    }
-                    Some(FirehoseMessage::Connect(ws_cursor)) => {
-                        let (ws, cursor) = *ws_cursor;
-                        match handle_connect(ws, seq, &history, cursor).await {
-                            Ok(r) => {
-                                gauge!(FIREHOSE_LISTENERS).increment(1_i32);
-                                clients.push(r);
-                            }
-                            Err(e) => {
-                                error!("failed to connect new client: {e}");
-                            }
-                        }
-                    }
-                    // All producers have been destroyed.
-                    None => break,
-                }
-            } else {
-                if clients.is_empty() {
-                    reconnect_relays(&client, &config).await;
-                }
-
-                let contents = rand::thread_rng()
-                    .sample_iter(rand::distributions::Alphanumeric)
-                    .take(15)
-                    .map(char::from)
-                    .collect::<String>();
-
-                // Send a websocket ping message.
-                // Reference: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers#pings_and_pongs_the_heartbeat_of_websockets
-                let message = Message::Ping(axum::body::Bytes::from_owner(contents));
-                drop(broadcast_message(&mut clients, message).await);
-            }
-        }
-    });
-
-    (handle, FirehoseProducer { tx })
-}
src/lib.rs  (+3 -438)
···
 mod db;
 mod did;
 pub mod error;
-mod firehose;
 mod metrics;
-mod mmap;
 mod models;
 mod oauth;
-mod plc;
 mod schema;
+mod serve;
 mod service_proxy;
-#[cfg(test)]
-mod tests;

-use account_manager::{AccountManager, SharedAccountManager};
-use anyhow::{Context as _, anyhow};
-use atrium_api::types::string::Did;
-use atrium_crypto::keypair::{Export as _, Secp256k1Keypair};
-use auth::AuthenticatedUser;
-use axum::{
-    Router,
-    body::Body,
-    extract::{FromRef, Request, State},
-    http::{self, HeaderMap, Response, StatusCode, Uri},
-    response::IntoResponse,
-    routing::get,
-};
-use azure_core::credentials::TokenCredential;
-use clap::Parser;
-use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter};
-use config::AppConfig;
-use db::establish_pool;
-use deadpool_diesel::sqlite::Pool;
-use diesel::prelude::*;
-use diesel_migrations::{EmbeddedMigrations, embed_migrations};
-pub use error::Error;
-use figment::{Figment, providers::Format as _};
-use firehose::FirehoseProducer;
-use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager};
-use rand::Rng as _;
-use rsky_pds::{crawlers::Crawlers, sequencer::Sequencer};
-use serde::{Deserialize, Serialize};
-use service_proxy::service_proxy;
-use std::{
-    net::{IpAddr, Ipv4Addr, SocketAddr},
-    path::PathBuf,
-    str::FromStr as _,
-    sync::Arc,
-};
-use tokio::{net::TcpListener, sync::RwLock};
-use tower_http::{cors::CorsLayer, trace::TraceLayer};
-use tracing::{info, warn};
-use uuid::Uuid;
-
-/// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`.
-pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
-
-/// Embedded migrations
-pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
-
-/// The application-wide result type.
-pub type Result<T> = std::result::Result<T, Error>;
-/// The reqwest client type with middleware.
-pub type Client = reqwest_middleware::ClientWithMiddleware;
-
-/// The Shared Sequencer which requests crawls from upstream relays and emits events to the firehose.
-pub struct SharedSequencer {
-    /// The sequencer instance.
-    pub sequencer: RwLock<Sequencer>,
-}
-
-#[expect(
-    clippy::arbitrary_source_item_ordering,
-    reason = "serialized data might be structured"
-)]
-#[derive(Serialize, Deserialize, Debug, Clone)]
-/// The key data structure.
-struct KeyData {
-    /// Primary signing key for all repo operations.
-    skey: Vec<u8>,
-    /// Primary signing (rotation) key for all PLC operations.
-    rkey: Vec<u8>,
-}
-
-// FIXME: We should use P256Keypair instead. SecP256K1 is primarily used for cryptocurrencies,
-// and the implementations of this algorithm are much more limited as compared to P256.
-//
-// Reference: https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/
-#[derive(Clone)]
-/// The signing key for PLC/DID operations.
-pub struct SigningKey(Arc<Secp256k1Keypair>);
-#[derive(Clone)]
-/// The rotation key for PLC operations.
-pub struct RotationKey(Arc<Secp256k1Keypair>);
-
-impl std::ops::Deref for SigningKey {
-    type Target = Secp256k1Keypair;
-
-    fn deref(&self) -> &Self::Target {
-        &self.0
-    }
-}
-
-impl SigningKey {
-    /// Import from a private key.
-    pub fn import(key: &[u8]) -> Result<Self> {
-        let key = Secp256k1Keypair::import(key).context("failed to import signing key")?;
-        Ok(Self(Arc::new(key)))
-    }
-}
-
-impl std::ops::Deref for RotationKey {
-    type Target = Secp256k1Keypair;
-
-    fn deref(&self) -> &Self::Target {
-        &self.0
-    }
-}
-
-#[derive(Parser, Debug, Clone)]
-/// Command line arguments.
-pub struct Args {
-    /// Path to the configuration file
-    #[arg(short, long, default_value = "default.toml")]
-    pub config: PathBuf,
-    /// The verbosity level.
-    #[command(flatten)]
-    pub verbosity: Verbosity<InfoLevel>,
-}
-
-/// The actor pools for the database connections.
-pub struct ActorPools {
-    /// The database connection pool for the actor's repository.
-    pub repo: Pool,
-    /// The database connection pool for the actor's blobs.
-    pub blob: Pool,
-}
-
-impl Clone for ActorPools {
-    fn clone(&self) -> Self {
-        Self {
-            repo: self.repo.clone(),
-            blob: self.blob.clone(),
-        }
-    }
-}
-
-#[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")]
-#[derive(Clone, FromRef)]
-pub struct AppState {
-    /// The application configuration.
-    pub config: AppConfig,
-    /// The main database connection pool. Used for common PDS data, like invite codes.
-    pub db: Pool,
-    /// Actor-specific database connection pools. Hashed by DID.
-    pub db_actors: std::collections::HashMap<String, ActorPools>,
-
-    /// The HTTP client with middleware.
-    pub client: Client,
-    /// The simple HTTP client.
-    pub simple_client: reqwest::Client,
-    /// The firehose producer.
-    pub sequencer: Arc<SharedSequencer>,
-    /// The account manager.
-    pub account_manager: Arc<SharedAccountManager>,
-
-    /// The signing key.
-    pub signing_key: SigningKey,
-    /// The rotation key.
-    pub rotation_key: RotationKey,
-}
+pub use serve::run;

 /// The index (/) route.
-async fn index() -> impl IntoResponse {
+async fn index() -> impl axum::response::IntoResponse {
     r"
 __ __
 /\ \__ /\ \__
···
 Protocol: https://atproto.com
 "
 }
-
-/// The main application entry point.
-#[expect(
-    clippy::cognitive_complexity,
-    clippy::too_many_lines,
-    unused_qualifications,
-    reason = "main function has high complexity"
-)]
-pub async fn run() -> anyhow::Result<()> {
-    let args = Args::parse();
-
-    // Set up trace logging to console and account for the user-provided verbosity flag.
-    if args.verbosity.log_level_filter() != LevelFilter::Off {
-        let lvl = match args.verbosity.log_level_filter() {
-            LevelFilter::Error => tracing::Level::ERROR,
-            LevelFilter::Warn => tracing::Level::WARN,
-            LevelFilter::Info | LevelFilter::Off => tracing::Level::INFO,
-            LevelFilter::Debug => tracing::Level::DEBUG,
-            LevelFilter::Trace => tracing::Level::TRACE,
-        };
-        tracing_subscriber::fmt().with_max_level(lvl).init();
-    }
-
-    if !args.config.exists() {
-        // Throw up a warning if the config file does not exist.
-        //
-        // This is not fatal because users can specify all configuration settings via
-        // the environment, but the most likely scenario here is that a user accidentally
-        // omitted the config file for some reason (e.g. forgot to mount it into Docker).
-        warn!(
-            "configuration file {} does not exist",
-            args.config.display()
-        );
-    }
-
-    // Read and parse the user-provided configuration.
-    let config: AppConfig = Figment::new()
-        .admerge(figment::providers::Toml::file(args.config))
-        .admerge(figment::providers::Env::prefixed("BLUEPDS_"))
-        .extract()
-        .context("failed to load configuration")?;
-
-    if config.test {
-        warn!("BluePDS starting up in TEST mode.");
-        warn!("This means the application will not federate with the rest of the network.");
-        warn!(
-            "If you want to turn this off, either set `test` to false in the config or define `BLUEPDS_TEST = false`"
-        );
-    }
-
-    // Initialize metrics reporting.
-    metrics::setup(config.metrics.as_ref()).context("failed to set up metrics exporter")?;
-
-    // Create a reqwest client that will be used for all outbound requests.
-    let simple_client = reqwest::Client::builder()
-        .user_agent(APP_USER_AGENT)
-        .build()
-        .context("failed to build requester client")?;
-    let client = reqwest_middleware::ClientBuilder::new(simple_client.clone())
-        .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache {
-            mode: CacheMode::Default,
-            manager: MokaManager::default(),
-            options: HttpCacheOptions::default(),
-        }))
-        .build();
-
-    tokio::fs::create_dir_all(&config.key.parent().context("should have parent")?)
-        .await
-        .context("failed to create key directory")?;
-
-    // Check if crypto keys exist. If not, create new ones.
-    let (skey, rkey) = if let Ok(f) = std::fs::File::open(&config.key) {
-        let keys: KeyData = serde_ipld_dagcbor::from_reader(std::io::BufReader::new(f))
-            .context("failed to deserialize crypto keys")?;
-
-        let skey = Secp256k1Keypair::import(&keys.skey).context("failed to import signing key")?;
-        let rkey = Secp256k1Keypair::import(&keys.rkey).context("failed to import rotation key")?;
-
-        (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
-    } else {
-        info!("signing keys not found, generating new ones");
-
-        let skey = Secp256k1Keypair::create(&mut rand::thread_rng());
-        let rkey = Secp256k1Keypair::create(&mut rand::thread_rng());
-
-        let keys = KeyData {
-            skey: skey.export(),
-            rkey: rkey.export(),
-        };
-
-        let mut f = std::fs::File::create(&config.key).context("failed to create key file")?;
-        serde_ipld_dagcbor::to_writer(&mut f, &keys).context("failed to serialize crypto keys")?;
-
-        (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
-    };
-
-    tokio::fs::create_dir_all(&config.repo.path).await?;
-    tokio::fs::create_dir_all(&config.plc.path).await?;
-    tokio::fs::create_dir_all(&config.blob.path).await?;
-
-    // Create a database connection manager and pool for the main database.
-    let pool =
-        establish_pool(&config.db).context("failed to establish database connection pool")?;
-    // Create a dictionary of database connection pools for each actor.
-    let mut actor_pools = std::collections::HashMap::new();
-    // let mut actor_blob_pools = std::collections::HashMap::new();
-    // We'll determine actors by looking in the data/repo dir for .db files.
-    let mut actor_dbs = tokio::fs::read_dir(&config.repo.path)
-        .await
-        .context("failed to read repo directory")?;
-    while let Some(entry) = actor_dbs
-        .next_entry()
-        .await
-        .context("failed to read repo dir")?
-    {
-        let path = entry.path();
-        if path.extension().and_then(|s| s.to_str()) == Some("db") {
-            let did_path = path
-                .file_stem()
-                .and_then(|s| s.to_str())
-                .context("failed to get actor DID")?;
-            let did = Did::from_str(&format!("did:plc:{}", did_path))
-                .expect("should be able to parse actor DID");
-
-            // Create a new database connection manager and pool for the actor.
-            // The path for the SQLite connection needs to look like "sqlite://data/repo/<actor>.db"
-            let path_repo = format!("sqlite://{}", did_path);
-            let actor_repo_pool =
-                establish_pool(&path_repo).context("failed to create database connection pool")?;
-            // Create a new database connection manager and pool for the actor blobs.
-            // The path for the SQLite connection needs to look like "sqlite://data/blob/<actor>.db"
-            let path_blob = path_repo.replace("repo", "blob");
-            let actor_blob_pool =
-                establish_pool(&path_blob).context("failed to create database connection pool")?;
-            drop(actor_pools.insert(
-                did.to_string(),
-                ActorPools {
-                    repo: actor_repo_pool,
-                    blob: actor_blob_pool,
-                },
-            ));
-        }
-    }
-    // Apply pending migrations
-    // let conn = pool.get().await?;
-    // conn.run_pending_migrations(MIGRATIONS)
-    //     .expect("should be able to run migrations");
-
-    let hostname = config.host_name.clone();
-    let crawlers: Vec<String> = config
-        .firehose
-        .relays
-        .iter()
-        .map(|s| s.to_string())
-        .collect();
-    let sequencer = Arc::new(SharedSequencer {
-        sequencer: RwLock::new(Sequencer::new(
-            Crawlers::new(hostname, crawlers.clone()),
-            None,
-        )),
-    });
-    let account_manager = SharedAccountManager {
-        account_manager: RwLock::new(AccountManager::new(pool.clone())),
-    };
-
-    let addr = config
-        .listen_address
-        .unwrap_or(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000));
-
-    let app = Router::new()
-        .route("/", get(index))
-        .merge(oauth::routes())
-        .nest(
-            "/xrpc",
-            apis::routes()
-                .merge(actor_endpoints::routes())
-                .fallback(service_proxy),
-        )
-        // .layer(RateLimitLayer::new(30, Duration::from_secs(30)))
-        .layer(CorsLayer::permissive())
-        .layer(TraceLayer::new_for_http())
-        .with_state(AppState {
-            config: config.clone(),
-            db: pool.clone(),
-            db_actors: actor_pools.clone(),
-            client: client.clone(),
-            simple_client,
-            sequencer: sequencer.clone(),
-            account_manager: Arc::new(account_manager),
-            signing_key: skey,
-            rotation_key: rkey,
-        });
-
-    info!("listening on {addr}");
-    info!("connect to: http://127.0.0.1:{}", addr.port());
-
-    // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created).
-    // If so, create an invite code and share it via the console.
-    let conn = pool.get().await.context("failed to get db connection")?;
-
-    #[derive(QueryableByName)]
-    struct TotalCount {
-        #[diesel(sql_type = diesel::sql_types::Integer)]
-        total_count: i32,
-    }
-
-    let result = conn.interact(move |conn| {
-        diesel::sql_query(
-            "SELECT (SELECT COUNT(*) FROM account) + (SELECT COUNT(*) FROM invite_code) AS total_count",
-        )
-        .get_result::<TotalCount>(conn)
-    })
-    .await
-    .expect("should be able to query database")?;
-
-    let c = result.total_count;
-
-    #[expect(clippy::print_stdout)]
-    if c == 0 {
-        let uuid = Uuid::new_v4().to_string();
-
-        use crate::models::pds as models;
-        use crate::schema::pds::invite_code::dsl as InviteCode;
-        let uuid_clone = uuid.clone();
-        drop(
-            conn.interact(move |conn| {
-                diesel::insert_into(InviteCode::invite_code)
-                    .values(models::InviteCode {
-                        code: uuid_clone,
-                        available_uses: 1,
-                        disabled: 0,
-                        for_account: "None".to_owned(),
-                        created_by: "None".to_owned(),
-                        created_at: "None".to_owned(),
-                    })
-                    .execute(conn)
-                    .context("failed to create new invite code")
-            })
-            .await
-            .expect("should be able to create invite code"),
-        );
-
-        // N.B: This is a sensitive message, so we're bypassing `tracing` here and
-        // logging it directly to console.
-        println!("=====================================");
-        println!("          FIRST STARTUP          ");
-        println!("=====================================");
-        println!("Use this code to create an account:");
-        println!("{uuid}");
-        println!("=====================================");
-    }
-
-    let listener = TcpListener::bind(&addr)
-        .await
-        .context("failed to bind address")?;
-
-    // Serve the app, and request crawling from upstream relays.
-    let serve = tokio::spawn(async move {
-        axum::serve(listener, app.into_make_service())
-            .await
-            .context("failed to serve app")
-    });
-
-    // Now that the app is live, request a crawl from upstream relays.
-    let mut background_sequencer = sequencer.sequencer.write().await.clone();
-    drop(tokio::spawn(
-        async move { background_sequencer.start().await },
-    ));
-
-    serve
-        .await
-        .map_err(Into::into)
-        .and_then(|r| r)
-        .context("failed to serve app")
-}
src/main.rs  (+1 -3)
···
 //! BluePDS binary entry point.

 use anyhow::Context as _;
-use clap::Parser;

 #[tokio::main(flavor = "multi_thread")]
 async fn main() -> anyhow::Result<()> {
-    // Parse command line arguments and call into the library's run function
     bluepds::run().await.context("failed to run application")
-}
+}
src/mmap.rs  (-274, file deleted)
···
-#![allow(clippy::arbitrary_source_item_ordering)]
-use std::io::{ErrorKind, Read as _, Seek as _, Write as _};
-
-#[cfg(unix)]
-use std::os::fd::AsRawFd as _;
-#[cfg(windows)]
-use std::os::windows::io::AsRawHandle;
-
-use memmap2::{MmapMut, MmapOptions};
-
-pub(crate) struct MappedFile {
-    /// The underlying file handle.
-    file: std::fs::File,
-    /// The length of the file.
-    len: u64,
-    /// The mapped memory region.
-    map: MmapMut,
-    /// Our current offset into the file.
-    off: u64,
-}
-
-impl MappedFile {
-    pub(crate) fn new(mut f: std::fs::File) -> std::io::Result<Self> {
-        let len = f.seek(std::io::SeekFrom::End(0))?;
-
-        #[cfg(windows)]
-        let raw = f.as_raw_handle();
-        #[cfg(unix)]
-        let raw = f.as_raw_fd();
-
-        #[expect(unsafe_code)]
-        Ok(Self {
-            // SAFETY:
-            // All file-backed memory map constructors are marked \
-            // unsafe because of the potential for Undefined Behavior (UB) \
-            // using the map if the underlying file is subsequently modified, in or out of process.
-            map: unsafe { MmapOptions::new().map_mut(raw)? },
-            file: f,
-            len,
-            off: 0,
-        })
-    }
-
-    /// Resize the memory-mapped file. This will reallocate the memory mapping.
-    #[expect(unsafe_code)]
-    fn resize(&mut self, len: u64) -> std::io::Result<()> {
-        // Resize the file.
-        self.file.set_len(len)?;
-
-        #[cfg(windows)]
-        let raw = self.file.as_raw_handle();
-        #[cfg(unix)]
-        let raw = self.file.as_raw_fd();
-
-        // SAFETY:
-        // All file-backed memory map constructors are marked \
-        // unsafe because of the potential for Undefined Behavior (UB) \
-        // using the map if the underlying file is subsequently modified, in or out of process.
-        self.map = unsafe { MmapOptions::new().map_mut(raw)? };
-        self.len = len;
-
-        Ok(())
-    }
-}
-
-impl std::io::Read for MappedFile {
-    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
-        if self.off == self.len {
-            // If we're at EOF, return an EOF error code. `Ok(0)` tends to trip up some implementations.
-            return Err(std::io::Error::new(ErrorKind::UnexpectedEof, "eof"));
-        }
-
-        // Calculate the number of bytes we're going to read.
-        let remaining_bytes = self.len.saturating_sub(self.off);
-        let buf_len = u64::try_from(buf.len()).unwrap_or(u64::MAX);
-        let len = usize::try_from(std::cmp::min(remaining_bytes, buf_len)).unwrap_or(usize::MAX);
-
-        let off = usize::try_from(self.off).map_err(|e| {
-            std::io::Error::new(
-                ErrorKind::InvalidInput,
-                format!("offset too large for this platform: {e}"),
-            )
-        })?;
-
-        if let (Some(dest), Some(src)) = (
-            buf.get_mut(..len),
-            self.map.get(off..off.saturating_add(len)),
-        ) {
-            dest.copy_from_slice(src);
-            self.off = self.off.saturating_add(u64::try_from(len).unwrap_or(0));
-            Ok(len)
-        } else {
-            Err(std::io::Error::new(
-                ErrorKind::InvalidInput,
-                "invalid buffer range",
-            ))
-        }
-    }
-}
-
-impl std::io::Write for MappedFile {
-    fn flush(&mut self) -> std::io::Result<()> {
-        // This is done by the system.
-        Ok(())
-    }
-    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
-        // Determine if we need to resize the file.
-        let buf_len = u64::try_from(buf.len()).map_err(|e| {
-            std::io::Error::new(
-                ErrorKind::InvalidInput,
-                format!("buffer length too large for this platform: {e}"),
-            )
-        })?;
-
-        if self.off.saturating_add(buf_len) >= self.len {
-            self.resize(self.off.saturating_add(buf_len))?;
-        }
-
-        let off = usize::try_from(self.off).map_err(|e| {
-            std::io::Error::new(
-                ErrorKind::InvalidInput,
-                format!("offset too large for this platform: {e}"),
-            )
-        })?;
-        let len = buf.len();
-
-        if let Some(dest) = self.map.get_mut(off..off.saturating_add(len)) {
-            dest.copy_from_slice(buf);
-            self.off = self.off.saturating_add(buf_len);
-            Ok(len)
-        } else {
-            Err(std::io::Error::new(
-                ErrorKind::InvalidInput,
-                "invalid buffer range",
-            ))
-        }
-    }
-}
-
-impl std::io::Seek for MappedFile {
-    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
-        let off = match pos {
-            std::io::SeekFrom::Start(i) => i,
-            std::io::SeekFrom::End(i) => {
-                if i <= 0 {
-                    // If i is negative or zero, we're seeking backwards from the end
-                    // or exactly at the end
-                    self.len.saturating_sub(i.unsigned_abs())
-                } else {
-                    // If i is positive, we're seeking beyond the end, which is allowed
-                    // but requires extending the file
-                    self.len.saturating_add(i.unsigned_abs())
-                }
-            }
-            std::io::SeekFrom::Current(i) => {
-                if i >= 0 {
-                    self.off.saturating_add(i.unsigned_abs())
-                } else {
-                    self.off.saturating_sub(i.unsigned_abs())
-                }
-            }
-        };
-
-        // If the offset is beyond EOF, extend the file to the new size.
-        if off > self.len {
-            self.resize(off)?;
-        }
-
-        self.off = off;
-        Ok(off)
-    }
-}
-
-impl tokio::io::AsyncRead for MappedFile {
-    fn poll_read(
-        mut self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-        buf: &mut tokio::io::ReadBuf<'_>,
-    ) -> std::task::Poll<std::io::Result<()>> {
-        let wbuf = buf.initialize_unfilled();
-        let len = wbuf.len();
-
-        std::task::Poll::Ready(match self.read(wbuf) {
-            Ok(_) => {
-                buf.advance(len);
-                Ok(())
-            }
-            Err(e) => Err(e),
-        })
-    }
-}
-
-impl tokio::io::AsyncWrite for MappedFile {
-    fn poll_flush(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Result<(), std::io::Error>> {
-        std::task::Poll::Ready(Ok(()))
-    }
-
-    fn poll_shutdown(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Result<(), std::io::Error>> {
-        std::task::Poll::Ready(Ok(()))
-    }
-
-    fn poll_write(
-        mut self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-        buf: &[u8],
-    ) -> std::task::Poll<Result<usize, std::io::Error>> {
-        std::task::Poll::Ready(self.write(buf))
-    }
-}
-
-impl tokio::io::AsyncSeek for MappedFile {
-    fn poll_complete(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<std::io::Result<u64>> {
-        std::task::Poll::Ready(Ok(self.off))
-    }
-
-    fn start_seek(
-        mut self: std::pin::Pin<&mut Self>,
-        position: std::io::SeekFrom,
-    ) -> std::io::Result<()> {
-        self.seek(position).map(|_p| ())
-    }
-}
-
-#[cfg(test)]
-mod test {
-    use rand::Rng as _;
-    use std::io::Write as _;
-
-    use super::*;
-
-    #[test]
-    fn basic_rw() {
-        let tmp = std::env::temp_dir().join(
-            rand::thread_rng()
-                .sample_iter(rand::distributions::Alphanumeric)
-                .take(10)
-                .map(char::from)
-                .collect::<String>(),
-        );
-
-        let mut m = MappedFile::new(
-            std::fs::File::options()
-                .create(true)
-                .truncate(true)
-                .read(true)
-                .write(true)
-                .open(&tmp)
-                .expect("Failed to open temporary file"),
-        )
-        .expect("Failed to create MappedFile");
-
-        m.write_all(b"abcd123").expect("Failed to write data");
-        let _: u64 = m
-            .seek(std::io::SeekFrom::Start(0))
-            .expect("Failed to seek to start");
-
-        let mut buf = [0_u8; 7];
-        m.read_exact(&mut buf).expect("Failed to read data");
-
-        assert_eq!(&buf, b"abcd123");
-
-        drop(m);
-        std::fs::remove_file(tmp).expect("Failed to remove temporary file");
-    }
-}
src/oauth.rs  (+3 -1)
···
 //! OAuth endpoints
 #![allow(unnameable_types, unused_qualifications)]
+use crate::config::AppConfig;
+use crate::error::Error;
 use crate::metrics::AUTH_FAILED;
-use crate::{AppConfig, AppState, Client, Error, Result, SigningKey};
+use crate::serve::{AppState, Client, Result, SigningKey};
 use anyhow::{Context as _, anyhow};
 use argon2::{Argon2, PasswordHash, PasswordVerifier as _};
 use atrium_crypto::keypair::Did as _;
-114
src/plc.rs
-114
src/plc.rs
···
1
-
//! PLC operations.
2
-
use std::collections::HashMap;
3
-
4
-
use anyhow::{Context as _, bail};
5
-
use base64::Engine as _;
6
-
use serde::{Deserialize, Serialize};
7
-
use tracing::debug;
8
-
9
-
use crate::{Client, RotationKey};
10
-
11
-
/// The URL of the public PLC directory.
12
-
const PLC_DIRECTORY: &str = "https://plc.directory/";
13
-
14
-
#[derive(Debug, Deserialize, Serialize, Clone)]
15
-
#[serde(rename_all = "camelCase", tag = "type")]
16
-
/// A PLC service.
17
-
pub(crate) enum PlcService {
18
-
#[serde(rename = "AtprotoPersonalDataServer")]
19
-
/// A personal data server.
20
-
Pds {
21
-
/// The URL of the PDS.
22
-
endpoint: String,
23
-
},
24
-
}
25
-
26
-
#[expect(
27
-
clippy::arbitrary_source_item_ordering,
28
-
reason = "serialized data might be structured"
29
-
)]
30
-
#[derive(Debug, Deserialize, Serialize, Clone)]
31
-
#[serde(rename_all = "camelCase")]
32
-
pub(crate) struct PlcOperation {
33
-
#[serde(rename = "type")]
34
-
pub typ: String,
35
-
pub rotation_keys: Vec<String>,
36
-
pub verification_methods: HashMap<String, String>,
37
-
pub also_known_as: Vec<String>,
38
-
pub services: HashMap<String, PlcService>,
39
-
pub prev: Option<String>,
40
-
}
41
-
42
-
impl PlcOperation {
43
-
/// Sign an operation with the provided signature.
44
-
pub(crate) fn sign(self, sig: Vec<u8>) -> SignedPlcOperation {
45
-
SignedPlcOperation {
46
-
typ: self.typ,
47
-
rotation_keys: self.rotation_keys,
48
-
verification_methods: self.verification_methods,
49
-
also_known_as: self.also_known_as,
50
-
services: self.services,
51
-
prev: self.prev,
52
-
sig: base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sig),
53
-
}
54
-
}
55
-
}
56
-
57
-
#[expect(
58
-
clippy::arbitrary_source_item_ordering,
59
-
reason = "serialized data might be structured"
60
-
)]
61
-
#[derive(Debug, Deserialize, Serialize, Clone)]
62
-
#[serde(rename_all = "camelCase")]
63
-
/// A signed PLC operation.
64
-
pub(crate) struct SignedPlcOperation {
65
-
#[serde(rename = "type")]
66
-
pub typ: String,
67
-
pub rotation_keys: Vec<String>,
68
-
pub verification_methods: HashMap<String, String>,
69
-
pub also_known_as: Vec<String>,
70
-
pub services: HashMap<String, PlcService>,
71
-
pub prev: Option<String>,
72
-
pub sig: String,
73
-
}
74
-
75
-
pub(crate) fn sign_op(rkey: &RotationKey, op: PlcOperation) -> anyhow::Result<SignedPlcOperation> {
76
-
let bytes = serde_ipld_dagcbor::to_vec(&op).context("failed to encode op")?;
77
-
let bytes = rkey.sign(&bytes).context("failed to sign op")?;
78
-
79
-
Ok(op.sign(bytes))
80
-
}
81
-
82
-
/// Submit a PLC operation to the public directory.
83
-
pub(crate) async fn submit(
84
-
client: &Client,
85
-
did: &str,
86
-
op: &SignedPlcOperation,
87
-
) -> anyhow::Result<()> {
88
-
debug!(
89
-
"submitting {} {}",
90
-
did,
91
-
serde_json::to_string(&op).context("should serialize")?
92
-
);
93
-
94
-
let res = client
95
-
.post(format!("{PLC_DIRECTORY}{did}"))
96
-
.json(&op)
97
-
.send()
98
-
.await
99
-
.context("failed to send directory request")?;
100
-
101
-
if res.status().is_success() {
102
-
Ok(())
103
-
} else {
104
-
let e = res
105
-
.json::<serde_json::Value>()
106
-
.await
107
-
.context("failed to read error response")?;
108
-
109
-
bail!(
110
-
"error from PLC directory: {}",
111
-
serde_json::to_string(&e).context("should serialize")?
112
-
);
113
-
}
114
-
}
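For reference, the deleted module's flow was: DAG-CBOR-encode the unsigned operation, sign the bytes with the rotation key, embed the signature as base64url without padding, then POST the signed operation to `https://plc.directory/{did}`. A minimal sketch of the encode-and-sign step, using the same `serde_ipld_dagcbor` and `base64` crates with a caller-supplied signer:

```rust
use anyhow::{Context as _, Result};
use base64::Engine as _;
use serde::Serialize;

/// Encode an operation as DAG-CBOR, sign the bytes, and return the
/// base64url (no padding) signature string -- the steps `sign_op` performed.
fn sign_plc_op<T: Serialize>(
    op: &T,
    sign: impl Fn(&[u8]) -> Result<Vec<u8>>,
) -> Result<String> {
    let bytes = serde_ipld_dagcbor::to_vec(op).context("failed to encode op")?;
    let sig = sign(&bytes).context("failed to sign op")?;
    Ok(base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sig))
}
```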
+415
src/serve.rs
···
1
+
use super::account_manager::{AccountManager, SharedAccountManager};
2
+
use super::config::AppConfig;
3
+
use super::db::establish_pool;
4
+
pub use super::error::Error;
5
+
use super::service_proxy::service_proxy;
6
+
use anyhow::Context as _;
7
+
use atrium_api::types::string::Did;
8
+
use atrium_crypto::keypair::{Export as _, Secp256k1Keypair};
9
+
use axum::{Router, extract::FromRef, routing::get};
10
+
use clap::Parser;
11
+
use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter};
12
+
use deadpool_diesel::sqlite::Pool;
13
+
use diesel::prelude::*;
14
+
use diesel_migrations::{EmbeddedMigrations, embed_migrations};
15
+
use figment::{Figment, providers::Format as _};
16
+
use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager};
17
+
use rsky_pds::{crawlers::Crawlers, sequencer::Sequencer};
18
+
use serde::{Deserialize, Serialize};
19
+
use std::{
20
+
net::{IpAddr, Ipv4Addr, SocketAddr},
21
+
path::PathBuf,
22
+
str::FromStr as _,
23
+
sync::Arc,
24
+
};
25
+
use tokio::{net::TcpListener, sync::RwLock};
26
+
use tower_http::{cors::CorsLayer, trace::TraceLayer};
27
+
use tracing::{info, warn};
28
+
use uuid::Uuid;
29
+
30
+
/// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`.
31
+
pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
32
+
33
+
/// Embedded migrations
34
+
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
35
+
36
+
/// The application-wide result type.
37
+
pub type Result<T> = std::result::Result<T, Error>;
38
+
/// The reqwest client type with middleware.
39
+
pub type Client = reqwest_middleware::ClientWithMiddleware;
40
+
41
+
/// The Shared Sequencer which requests crawls from upstream relays and emits events to the firehose.
42
+
pub struct SharedSequencer {
43
+
/// The sequencer instance.
44
+
pub sequencer: RwLock<Sequencer>,
45
+
}
46
+
47
+
#[expect(
48
+
clippy::arbitrary_source_item_ordering,
49
+
reason = "serialized data might be structured"
50
+
)]
51
+
#[derive(Serialize, Deserialize, Debug, Clone)]
52
+
/// On-disk key material, serialized as DAG-CBOR at the path given by `config.key`.
53
+
struct KeyData {
54
+
/// Primary signing key for all repo operations.
55
+
skey: Vec<u8>,
56
+
/// Primary signing (rotation) key for all PLC operations.
57
+
rkey: Vec<u8>,
58
+
}
59
+
60
+
// FIXME: We should use P256Keypair instead. SecP256K1 is primarily used for cryptocurrencies,
61
+
// and the implementations of this algorithm are much more limited as compared to P256.
62
+
//
63
+
// Reference: https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/
64
+
#[derive(Clone)]
65
+
/// The signing key for PLC/DID operations.
66
+
pub struct SigningKey(Arc<Secp256k1Keypair>);
67
+
#[derive(Clone)]
68
+
/// The rotation key for PLC operations.
69
+
pub struct RotationKey(Arc<Secp256k1Keypair>);
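On the FIXME above: `atrium_crypto` also ships a `P256Keypair` with the same `create`/`import`/`export`/`sign` surface, so the migration should be mostly mechanical. A hedged sketch, assuming the newtypes simply wrap the other keypair type:

```rust
use atrium_crypto::keypair::{Did as _, P256Keypair};
use std::sync::Arc;

/// Hypothetical P-256 signing key; only the wrapped keypair type changes.
pub struct P256SigningKey(Arc<P256Keypair>);

fn generate_p256() -> P256SigningKey {
    // Same call shape as the Secp256k1Keypair usage below in `run()`.
    let key = P256Keypair::create(&mut rand::thread_rng());
    println!("new signing key: {}", key.did()); // did:key form
    P256SigningKey(Arc::new(key))
}
```

Existing identities would still need a key-rotation PLC operation, so the swap is not free for accounts already on secp256k1.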
70
+
71
+
impl std::ops::Deref for SigningKey {
72
+
type Target = Secp256k1Keypair;
73
+
74
+
fn deref(&self) -> &Self::Target {
75
+
&self.0
76
+
}
77
+
}
78
+
79
+
impl SigningKey {
80
+
/// Import from a private key.
81
+
pub fn import(key: &[u8]) -> Result<Self> {
82
+
let key = Secp256k1Keypair::import(key).context("failed to import signing key")?;
83
+
Ok(Self(Arc::new(key)))
84
+
}
85
+
}
86
+
87
+
impl std::ops::Deref for RotationKey {
88
+
type Target = Secp256k1Keypair;
89
+
90
+
fn deref(&self) -> &Self::Target {
91
+
&self.0
92
+
}
93
+
}
94
+
95
+
#[derive(Parser, Debug, Clone)]
96
+
/// Command line arguments.
97
+
pub struct Args {
98
+
/// Path to the configuration file
99
+
#[arg(short, long, default_value = "default.toml")]
100
+
pub config: PathBuf,
101
+
/// The verbosity level.
102
+
#[command(flatten)]
103
+
pub verbosity: Verbosity<InfoLevel>,
104
+
}
105
+
106
+
/// Per-actor storage: the repo database connection pool and the blob storage path.
107
+
pub struct ActorStorage {
108
+
/// The database connection pool for the actor's repository.
109
+
pub repo: Pool,
110
+
/// The file storage path for the actor's blobs.
111
+
pub blob: PathBuf,
112
+
}
113
+
114
+
impl Clone for ActorStorage {
115
+
fn clone(&self) -> Self {
116
+
Self {
117
+
repo: self.repo.clone(),
118
+
blob: self.blob.clone(),
119
+
}
120
+
}
121
+
}
122
+
123
+
#[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")]
124
+
#[derive(Clone, FromRef)]
125
+
/// The application state, shared across all routes.
126
+
pub struct AppState {
127
+
/// The application configuration.
128
+
pub(crate) config: AppConfig,
129
+
/// The main database connection pool. Used for common PDS data, like invite codes.
130
+
pub db: Pool,
131
+
/// Actor-specific storage (repo pool and blob path), keyed by DID.
132
+
pub db_actors: std::collections::HashMap<String, ActorStorage>,
133
+
134
+
/// The HTTP client with middleware.
135
+
pub client: Client,
136
+
/// The simple HTTP client.
137
+
pub simple_client: reqwest::Client,
138
+
/// The firehose producer.
139
+
pub sequencer: Arc<SharedSequencer>,
140
+
/// The account manager.
141
+
pub account_manager: Arc<SharedAccountManager>,
142
+
143
+
/// The signing key.
144
+
pub signing_key: SigningKey,
145
+
/// The rotation key.
146
+
pub rotation_key: RotationKey,
147
+
}
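A note on `#[derive(Clone, FromRef)]` above: axum implements `FromRef<AppState>` for each field's type, so handlers can extract just the sub-state they need instead of the whole `AppState`. A small illustration with a hypothetical handler:

```rust
use atrium_crypto::keypair::Did as _;
use axum::extract::State;

/// Hypothetical handler: the derived `FromRef` impl lets axum clone only
/// the `SigningKey` field out of `AppState` for this extractor.
async fn signing_did(State(skey): State<SigningKey>) -> String {
    skey.did() // via Deref to Secp256k1Keypair
}
```

This works because every field type in `AppState` is distinct; two fields of the same type would make the derived impls conflict.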
148
+
149
+
/// The main application entry point.
150
+
#[expect(
151
+
clippy::cognitive_complexity,
152
+
clippy::too_many_lines,
153
+
unused_qualifications,
154
+
reason = "main function has high complexity"
155
+
)]
156
+
pub async fn run() -> anyhow::Result<()> {
157
+
let args = Args::parse();
158
+
159
+
// Set up trace logging to console and account for the user-provided verbosity flag.
160
+
if args.verbosity.log_level_filter() != LevelFilter::Off {
161
+
let lvl = match args.verbosity.log_level_filter() {
162
+
LevelFilter::Error => tracing::Level::ERROR,
163
+
LevelFilter::Warn => tracing::Level::WARN,
164
+
LevelFilter::Info | LevelFilter::Off => tracing::Level::INFO,
165
+
LevelFilter::Debug => tracing::Level::DEBUG,
166
+
LevelFilter::Trace => tracing::Level::TRACE,
167
+
};
168
+
tracing_subscriber::fmt().with_max_level(lvl).init();
169
+
}
170
+
171
+
if !args.config.exists() {
172
+
// Emit a warning if the config file does not exist.
173
+
//
174
+
// This is not fatal because users can specify all configuration settings via
175
+
// the environment, but the most likely scenario here is that a user accidentally
176
+
// omitted the config file for some reason (e.g. forgot to mount it into Docker).
177
+
warn!(
178
+
"configuration file {} does not exist",
179
+
args.config.display()
180
+
);
181
+
}
182
+
183
+
// Read and parse the user-provided configuration.
184
+
let config: AppConfig = Figment::new()
185
+
.admerge(figment::providers::Toml::file(args.config))
186
+
.admerge(figment::providers::Env::prefixed("BLUEPDS_"))
187
+
.extract()
188
+
.context("failed to load configuration")?;
189
+
190
+
if config.test {
191
+
warn!("BluePDS starting up in TEST mode.");
192
+
warn!("This means the application will not federate with the rest of the network.");
193
+
warn!(
194
+
"If you want to turn this off, either set `test` to false in the config or define `BLUEPDS_TEST = false`"
195
+
);
196
+
}
197
+
198
+
// Initialize metrics reporting.
199
+
super::metrics::setup(config.metrics.as_ref()).context("failed to set up metrics exporter")?;
200
+
201
+
// Create a reqwest client that will be used for all outbound requests.
202
+
let simple_client = reqwest::Client::builder()
203
+
.user_agent(APP_USER_AGENT)
204
+
.build()
205
+
.context("failed to build requester client")?;
206
+
let client = reqwest_middleware::ClientBuilder::new(simple_client.clone())
207
+
.with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache {
208
+
mode: CacheMode::Default,
209
+
manager: MokaManager::default(),
210
+
options: HttpCacheOptions::default(),
211
+
}))
212
+
.build();
213
+
214
+
tokio::fs::create_dir_all(&config.key.parent().context("should have parent")?)
215
+
.await
216
+
.context("failed to create key directory")?;
217
+
218
+
// Check if crypto keys exist. If not, create new ones.
219
+
let (skey, rkey) = if let Ok(f) = std::fs::File::open(&config.key) {
220
+
let keys: KeyData = serde_ipld_dagcbor::from_reader(std::io::BufReader::new(f))
221
+
.context("failed to deserialize crypto keys")?;
222
+
223
+
let skey = Secp256k1Keypair::import(&keys.skey).context("failed to import signing key")?;
224
+
let rkey = Secp256k1Keypair::import(&keys.rkey).context("failed to import rotation key")?;
225
+
226
+
(SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
227
+
} else {
228
+
info!("signing keys not found, generating new ones");
229
+
230
+
let skey = Secp256k1Keypair::create(&mut rand::thread_rng());
231
+
let rkey = Secp256k1Keypair::create(&mut rand::thread_rng());
232
+
233
+
let keys = KeyData {
234
+
skey: skey.export(),
235
+
rkey: rkey.export(),
236
+
};
237
+
238
+
let mut f = std::fs::File::create(&config.key).context("failed to create key file")?;
239
+
serde_ipld_dagcbor::to_writer(&mut f, &keys).context("failed to serialize crypto keys")?;
240
+
241
+
(SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
242
+
};
243
+
244
+
tokio::fs::create_dir_all(&config.repo.path).await?;
245
+
tokio::fs::create_dir_all(&config.plc.path).await?;
246
+
tokio::fs::create_dir_all(&config.blob.path).await?;
247
+
248
+
// Create a database connection manager and pool for the main database.
249
+
let pool =
250
+
establish_pool(&config.db).context("failed to establish database connection pool")?;
251
+
252
+
// Create a dictionary of database connection pools for each actor.
253
+
let mut actor_pools = std::collections::HashMap::new();
254
+
// We'll determine actors by looking in the data/repo dir for .db files.
255
+
let mut actor_dbs = tokio::fs::read_dir(&config.repo.path)
256
+
.await
257
+
.context("failed to read repo directory")?;
258
+
while let Some(entry) = actor_dbs
259
+
.next_entry()
260
+
.await
261
+
.context("failed to read repo dir")?
262
+
{
263
+
let path = entry.path();
264
+
if path.extension().and_then(|s| s.to_str()) == Some("db") {
265
+
let actor_repo_pool = establish_pool(&format!("sqlite://{}", path.display()))
266
+
.context("failed to create database connection pool")?;
267
+
268
+
let did = Did::from_str(&format!(
269
+
"did:plc:{}",
270
+
path.file_stem()
271
+
.and_then(|s| s.to_str())
272
+
.context("failed to get actor DID")?
273
+
))
274
+
.expect("should be able to parse actor DID")
275
+
.to_string();
276
+
let blob_path = config.blob.path.to_path_buf();
277
+
let actor_storage = ActorStorage {
278
+
repo: actor_repo_pool,
279
+
blob: blob_path.clone(),
280
+
};
281
+
drop(actor_pools.insert(did, actor_storage));
282
+
}
283
+
}
284
+
// Apply pending migrations
285
+
// let conn = pool.get().await?;
286
+
// conn.run_pending_migrations(MIGRATIONS)
287
+
// .expect("should be able to run migrations");
288
+
289
+
let hostname = config.host_name.clone();
290
+
let crawlers: Vec<String> = config
291
+
.firehose
292
+
.relays
293
+
.iter()
294
+
.map(|s| s.to_string())
295
+
.collect();
296
+
let sequencer = Arc::new(SharedSequencer {
297
+
sequencer: RwLock::new(Sequencer::new(
298
+
Crawlers::new(hostname, crawlers.clone()),
299
+
None,
300
+
)),
301
+
});
302
+
let account_manager = SharedAccountManager {
303
+
account_manager: RwLock::new(AccountManager::new(pool.clone())),
304
+
};
305
+
306
+
let addr = config
307
+
.listen_address
308
+
.unwrap_or(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000));
309
+
310
+
let app = Router::new()
311
+
.route("/", get(super::index))
312
+
.merge(super::oauth::routes())
313
+
.nest(
314
+
"/xrpc",
315
+
super::apis::routes()
316
+
.merge(super::actor_endpoints::routes())
317
+
.fallback(service_proxy),
318
+
)
319
+
// .layer(RateLimitLayer::new(30, Duration::from_secs(30)))
320
+
.layer(CorsLayer::permissive())
321
+
.layer(TraceLayer::new_for_http())
322
+
.with_state(AppState {
323
+
config: config.clone(),
324
+
db: pool.clone(),
325
+
db_actors: actor_pools.clone(),
326
+
client: client.clone(),
327
+
simple_client,
328
+
sequencer: sequencer.clone(),
329
+
account_manager: Arc::new(account_manager),
330
+
signing_key: skey,
331
+
rotation_key: rkey,
332
+
});
333
+
334
+
info!("listening on {addr}");
335
+
info!("connect to: http://127.0.0.1:{}", addr.port());
336
+
337
+
// Determine whether this was the first startup (i.e. no accounts exist and no invite codes were created).
338
+
// If so, create an invite code and share it via the console.
339
+
let conn = pool.get().await.context("failed to get db connection")?;
340
+
341
+
#[derive(QueryableByName)]
342
+
struct TotalCount {
343
+
#[diesel(sql_type = diesel::sql_types::Integer)]
344
+
total_count: i32,
345
+
}
346
+
347
+
let result = conn.interact(move |conn| {
348
+
diesel::sql_query(
349
+
"SELECT (SELECT COUNT(*) FROM account) + (SELECT COUNT(*) FROM invite_code) AS total_count",
350
+
)
351
+
.get_result::<TotalCount>(conn)
352
+
})
353
+
.await
354
+
.expect("should be able to query database")?;
355
+
356
+
let c = result.total_count;
357
+
358
+
#[expect(clippy::print_stdout)]
359
+
if c == 0 {
360
+
let uuid = Uuid::new_v4().to_string();
361
+
362
+
use crate::models::pds as models;
363
+
use crate::schema::pds::invite_code::dsl as InviteCode;
364
+
let uuid_clone = uuid.clone();
365
+
drop(
366
+
conn.interact(move |conn| {
367
+
diesel::insert_into(InviteCode::invite_code)
368
+
.values(models::InviteCode {
369
+
code: uuid_clone,
370
+
available_uses: 1,
371
+
disabled: 0,
372
+
for_account: "None".to_owned(),
373
+
created_by: "None".to_owned(),
374
+
created_at: "None".to_owned(),
375
+
})
376
+
.execute(conn)
377
+
.context("failed to create new invite code")
378
+
})
379
+
.await
380
+
.expect("should be able to create invite code"),
381
+
);
382
+
383
+
// N.B.: This is a sensitive message, so we're bypassing `tracing` here and
384
+
// logging it directly to console.
385
+
println!("=====================================");
386
+
println!(" FIRST STARTUP ");
387
+
println!("=====================================");
388
+
println!("Use this code to create an account:");
389
+
println!("{uuid}");
390
+
println!("=====================================");
391
+
}
392
+
393
+
let listener = TcpListener::bind(&addr)
394
+
.await
395
+
.context("failed to bind address")?;
396
+
397
+
// Serve the app, and request crawling from upstream relays.
398
+
let serve = tokio::spawn(async move {
399
+
axum::serve(listener, app.into_make_service())
400
+
.await
401
+
.context("failed to serve app")
402
+
});
403
+
404
+
// Now that the app is live, request a crawl from upstream relays.
405
+
let mut background_sequencer = sequencer.sequencer.write().await.clone();
406
+
drop(tokio::spawn(
407
+
async move { background_sequencer.start().await },
408
+
));
409
+
410
+
serve
411
+
.await
412
+
.map_err(Into::into)
413
+
.and_then(|r| r)
414
+
.context("failed to serve app")
415
+
}
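Downstream of the first-startup banner: the printed code is redeemed through the standard `com.atproto.server.createAccount` XRPC call. A rough client-side sketch against the default listen address, with placeholder credentials (the server does not dictate these values):

```rust
use anyhow::Result;
use serde_json::json;

/// Redeem the invite code printed on first startup; handle, email, and
/// password here are placeholders.
async fn redeem_invite(code: &str) -> Result<()> {
    let res = reqwest::Client::new()
        .post("http://127.0.0.1:8000/xrpc/com.atproto.server.createAccount")
        .json(&json!({
            "handle": "alice.example.com",
            "email": "alice@example.com",
            "password": "correct-horse-battery",
            "inviteCode": code,
        }))
        .send()
        .await?;
    anyhow::ensure!(res.status().is_success(), "createAccount failed: {}", res.status());
    Ok(())
}
```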
+6
-26
src/service_proxy.rs
···
3
3
/// Reference: <https://atproto.com/specs/xrpc#service-proxying>
4
4
use anyhow::{Context as _, anyhow};
5
5
use atrium_api::types::string::Did;
6
-
use atrium_crypto::keypair::{Export as _, Secp256k1Keypair};
7
6
use axum::{
8
-
Router,
9
7
body::Body,
10
-
extract::{FromRef, Request, State},
8
+
extract::{Request, State},
11
9
http::{self, HeaderMap, Response, StatusCode, Uri},
12
-
response::IntoResponse,
13
-
routing::get,
14
10
};
15
-
use azure_core::credentials::TokenCredential;
16
-
use clap::Parser;
17
-
use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter};
18
-
use deadpool_diesel::sqlite::Pool;
19
-
use diesel::prelude::*;
20
-
use diesel_migrations::{EmbeddedMigrations, embed_migrations};
21
-
use figment::{Figment, providers::Format as _};
22
-
use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager};
23
11
use rand::Rng as _;
24
-
use serde::{Deserialize, Serialize};
25
-
use std::{
26
-
net::{IpAddr, Ipv4Addr, SocketAddr},
27
-
path::PathBuf,
28
-
str::FromStr as _,
29
-
sync::Arc,
30
-
};
31
-
use tokio::net::TcpListener;
32
-
use tower_http::{cors::CorsLayer, trace::TraceLayer};
33
-
use tracing::{info, warn};
34
-
use uuid::Uuid;
12
+
use std::str::FromStr as _;
35
13
36
-
use super::{Client, Error, Result};
37
-
use crate::{AuthenticatedUser, SigningKey};
14
+
use super::{
15
+
auth::AuthenticatedUser,
16
+
serve::{Client, Error, Result, SigningKey},
17
+
};
38
18
39
19
pub(super) async fn service_proxy(
40
20
uri: Uri,
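Per the spec linked above, this fallback proxies unrecognized XRPC calls to whatever service the client names in the `atproto-proxy` request header (a DID plus a service fragment). A client-side sketch, with the Bluesky AppView DID as a placeholder target:

```rust
use anyhow::Result;

/// Ask the PDS to forward an XRPC call to another service via the
/// `atproto-proxy` header (value is `<did>#<service id>`).
async fn proxied_get_profile(pds_base: &str, jwt: &str) -> Result<String> {
    Ok(reqwest::Client::new()
        .get(format!("{pds_base}/xrpc/app.bsky.actor.getProfile?actor=alice.example.com"))
        .bearer_auth(jwt)
        .header("atproto-proxy", "did:web:api.bsky.app#bsky_appview")
        .send()
        .await?
        .error_for_status()?
        .text()
        .await?)
}
```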
-459
src/tests.rs
···
1
-
//! Testing utilities for the PDS.
2
-
#![expect(clippy::arbitrary_source_item_ordering)]
3
-
use std::{
4
-
net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener},
5
-
path::PathBuf,
6
-
time::{Duration, Instant},
7
-
};
8
-
9
-
use anyhow::Result;
10
-
use atrium_api::{
11
-
com::atproto::server,
12
-
types::string::{AtIdentifier, Did, Handle, Nsid, RecordKey},
13
-
};
14
-
use figment::{Figment, providers::Format as _};
15
-
use futures::future::join_all;
16
-
use serde::{Deserialize, Serialize};
17
-
use tokio::sync::OnceCell;
18
-
use uuid::Uuid;
19
-
20
-
use crate::config::AppConfig;
21
-
22
-
/// Global test state, created once for all tests.
23
-
pub(crate) static TEST_STATE: OnceCell<TestState> = OnceCell::const_new();
24
-
25
-
/// A temporary test directory that will be cleaned up when the struct is dropped.
26
-
struct TempDir {
27
-
/// The path to the directory.
28
-
path: PathBuf,
29
-
}
30
-
31
-
impl TempDir {
32
-
/// Create a new temporary directory.
33
-
fn new() -> Result<Self> {
34
-
let path = std::env::temp_dir().join(format!("bluepds-test-{}", Uuid::new_v4()));
35
-
std::fs::create_dir_all(&path)?;
36
-
Ok(Self { path })
37
-
}
38
-
39
-
/// Get the path to the directory.
40
-
fn path(&self) -> &PathBuf {
41
-
&self.path
42
-
}
43
-
}
44
-
45
-
impl Drop for TempDir {
46
-
fn drop(&mut self) {
47
-
drop(std::fs::remove_dir_all(&self.path));
48
-
}
49
-
}
50
-
51
-
/// Test state for the application.
52
-
pub(crate) struct TestState {
53
-
/// The address the test server is listening on.
54
-
address: SocketAddr,
55
-
/// The HTTP client.
56
-
client: reqwest::Client,
57
-
/// The application configuration.
58
-
config: AppConfig,
59
-
/// The temporary directory for test data.
60
-
#[expect(dead_code)]
61
-
temp_dir: TempDir,
62
-
}
63
-
64
-
impl TestState {
65
-
/// Get a base URL for the test server.
66
-
pub(crate) fn base_url(&self) -> String {
67
-
format!("http://{}", self.address)
68
-
}
69
-
70
-
/// Create a test account.
71
-
pub(crate) async fn create_test_account(&self) -> Result<TestAccount> {
72
-
// Create the account
73
-
let handle = "test.handle";
74
-
let response = self
75
-
.client
76
-
.post(format!(
77
-
"http://{}/xrpc/com.atproto.server.createAccount",
78
-
self.address
79
-
))
80
-
.json(&server::create_account::InputData {
81
-
did: None,
82
-
verification_code: None,
83
-
verification_phone: None,
84
-
email: Some(format!("{}@example.com", &handle)),
85
-
handle: Handle::new(handle.to_owned()).expect("should be able to create handle"),
86
-
password: Some("password123".to_owned()),
87
-
invite_code: None,
88
-
recovery_key: None,
89
-
plc_op: None,
90
-
})
91
-
.send()
92
-
.await?;
93
-
94
-
let account: server::create_account::Output = response.json().await?;
95
-
96
-
Ok(TestAccount {
97
-
handle: handle.to_owned(),
98
-
did: account.did.to_string(),
99
-
access_token: account.access_jwt.clone(),
100
-
refresh_token: account.refresh_jwt.clone(),
101
-
})
102
-
}
103
-
104
-
/// Create a new test state.
105
-
#[expect(clippy::unused_async)]
106
-
async fn new() -> Result<Self> {
107
-
// Configure the test app
108
-
#[derive(Serialize, Deserialize)]
109
-
struct TestConfigInput {
110
-
db: Option<String>,
111
-
host_name: Option<String>,
112
-
key: Option<PathBuf>,
113
-
listen_address: Option<SocketAddr>,
114
-
test: Option<bool>,
115
-
}
116
-
// Create a temporary directory for test data
117
-
let temp_dir = TempDir::new()?;
118
-
119
-
// Find a free port
120
-
let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?;
121
-
let address = listener.local_addr()?;
122
-
drop(listener);
123
-
124
-
let test_config = TestConfigInput {
125
-
db: Some(format!("sqlite://{}/test.db", temp_dir.path().display())),
126
-
host_name: Some(format!("localhost:{}", address.port())),
127
-
key: Some(temp_dir.path().join("test.key")),
128
-
listen_address: Some(address),
129
-
test: Some(true),
130
-
};
131
-
132
-
let config: AppConfig = Figment::new()
133
-
.admerge(figment::providers::Toml::file("default.toml"))
134
-
.admerge(figment::providers::Env::prefixed("BLUEPDS_"))
135
-
.merge(figment::providers::Serialized::defaults(test_config))
136
-
.merge(
137
-
figment::providers::Toml::string(
138
-
r#"
139
-
[firehose]
140
-
relays = []
141
-
142
-
[repo]
143
-
path = "repo"
144
-
145
-
[plc]
146
-
path = "plc"
147
-
148
-
[blob]
149
-
path = "blob"
150
-
limit = 10485760 # 10 MB
151
-
"#,
152
-
)
153
-
.nested(),
154
-
)
155
-
.extract()?;
156
-
157
-
// Create directories
158
-
std::fs::create_dir_all(temp_dir.path().join("repo"))?;
159
-
std::fs::create_dir_all(temp_dir.path().join("plc"))?;
160
-
std::fs::create_dir_all(temp_dir.path().join("blob"))?;
161
-
162
-
// Create client
163
-
let client = reqwest::Client::builder()
164
-
.timeout(Duration::from_secs(30))
165
-
.build()?;
166
-
167
-
Ok(Self {
168
-
address,
169
-
client,
170
-
config,
171
-
temp_dir,
172
-
})
173
-
}
174
-
175
-
/// Start the application in a background task.
176
-
async fn start_app(&self) -> Result<()> {
177
-
// // Get a reference to the config that can be moved into the task
178
-
// let config = self.config.clone();
179
-
// let address = self.address;
180
-
181
-
// // Start the application in a background task
182
-
// let _handle = tokio::spawn(async move {
183
-
// // Set up the application
184
-
// use crate::*;
185
-
186
-
// // Initialize metrics (noop in test mode)
187
-
// drop(metrics::setup(None));
188
-
189
-
// // Create client
190
-
// let simple_client = reqwest::Client::builder()
191
-
// .user_agent(APP_USER_AGENT)
192
-
// .build()
193
-
// .context("failed to build requester client")?;
194
-
// let client = reqwest_middleware::ClientBuilder::new(simple_client.clone())
195
-
// .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache {
196
-
// mode: CacheMode::Default,
197
-
// manager: MokaManager::default(),
198
-
// options: HttpCacheOptions::default(),
199
-
// }))
200
-
// .build();
201
-
202
-
// // Create a test keypair
203
-
// std::fs::create_dir_all(config.key.parent().context("should have parent")?)?;
204
-
// let (skey, rkey) = {
205
-
// let skey = Secp256k1Keypair::create(&mut rand::thread_rng());
206
-
// let rkey = Secp256k1Keypair::create(&mut rand::thread_rng());
207
-
208
-
// let keys = KeyData {
209
-
// skey: skey.export(),
210
-
// rkey: rkey.export(),
211
-
// };
212
-
213
-
// let mut f =
214
-
// std::fs::File::create(&config.key).context("failed to create key file")?;
215
-
// serde_ipld_dagcbor::to_writer(&mut f, &keys)
216
-
// .context("failed to serialize crypto keys")?;
217
-
218
-
// (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
219
-
// };
220
-
221
-
// // Set up database
222
-
// let opts = SqliteConnectOptions::from_str(&config.db)
223
-
// .context("failed to parse database options")?
224
-
// .create_if_missing(true);
225
-
// let db = SqliteDbConn::connect_with(opts).await?;
226
-
227
-
// sqlx::migrate!()
228
-
// .run(&db)
229
-
// .await
230
-
// .context("failed to apply migrations")?;
231
-
232
-
// // Create firehose
233
-
// let (_fh, fhp) = firehose::spawn(client.clone(), config.clone());
234
-
235
-
// // Create the application state
236
-
// let app_state = AppState {
237
-
// cred: azure_identity::DefaultAzureCredential::new()?,
238
-
// config: config.clone(),
239
-
// db: db.clone(),
240
-
// client: client.clone(),
241
-
// simple_client,
242
-
// firehose: fhp,
243
-
// signing_key: skey,
244
-
// rotation_key: rkey,
245
-
// };
246
-
247
-
// // Create the router
248
-
// let app = Router::new()
249
-
// .route("/", get(index))
250
-
// .merge(oauth::routes())
251
-
// .nest(
252
-
// "/xrpc",
253
-
// endpoints::routes()
254
-
// .merge(actor_endpoints::routes())
255
-
// .fallback(service_proxy),
256
-
// )
257
-
// .layer(CorsLayer::permissive())
258
-
// .layer(TraceLayer::new_for_http())
259
-
// .with_state(app_state);
260
-
261
-
// // Listen for connections
262
-
// let listener = TcpListener::bind(&address)
263
-
// .await
264
-
// .context("failed to bind address")?;
265
-
266
-
// axum::serve(listener, app.into_make_service())
267
-
// .await
268
-
// .context("failed to serve app")
269
-
// });
270
-
271
-
// // Give the server a moment to start
272
-
// tokio::time::sleep(Duration::from_millis(500)).await;
273
-
274
-
Ok(())
275
-
}
276
-
}
277
-
278
-
/// A test account that can be used for testing.
279
-
pub(crate) struct TestAccount {
280
-
/// The access token for the account.
281
-
pub(crate) access_token: String,
282
-
/// The account DID.
283
-
pub(crate) did: String,
284
-
/// The account handle.
285
-
pub(crate) handle: String,
286
-
/// The refresh token for the account.
287
-
#[expect(dead_code)]
288
-
pub(crate) refresh_token: String,
289
-
}
290
-
291
-
/// Initialize the test state.
292
-
pub(crate) async fn init_test_state() -> Result<&'static TestState> {
293
-
async fn init_test_state() -> std::result::Result<TestState, anyhow::Error> {
294
-
let state = TestState::new().await?;
295
-
state.start_app().await?;
296
-
Ok(state)
297
-
}
298
-
TEST_STATE.get_or_try_init(init_test_state).await
299
-
}
300
-
301
-
/// Create a record benchmark that creates records and measures the time it takes.
302
-
#[expect(
303
-
clippy::arithmetic_side_effects,
304
-
clippy::integer_division,
305
-
clippy::integer_division_remainder_used,
306
-
clippy::use_debug,
307
-
clippy::print_stdout
308
-
)]
309
-
pub(crate) async fn create_record_benchmark(count: usize, concurrent: usize) -> Result<Duration> {
310
-
// Initialize the test state
311
-
let state = init_test_state().await?;
312
-
313
-
// Create a test account
314
-
let account = state.create_test_account().await?;
315
-
316
-
// Create the client with authorization
317
-
let client = reqwest::Client::builder()
318
-
.timeout(Duration::from_secs(30))
319
-
.build()?;
320
-
321
-
let start = Instant::now();
322
-
323
-
// Split the work into batches
324
-
let mut handles = Vec::new();
325
-
for batch_idx in 0..concurrent {
326
-
let batch_size = count / concurrent;
327
-
let client = client.clone();
328
-
let base_url = state.base_url();
329
-
let account_did = account.did.clone();
330
-
let account_handle = account.handle.clone();
331
-
let access_token = account.access_token.clone();
332
-
333
-
let handle = tokio::spawn(async move {
334
-
let mut results = Vec::new();
335
-
336
-
for i in 0..batch_size {
337
-
let request_start = Instant::now();
338
-
let record_idx = batch_idx * batch_size + i;
339
-
340
-
let result = client
341
-
.post(format!("{base_url}/xrpc/com.atproto.repo.createRecord"))
342
-
.header("Authorization", format!("Bearer {access_token}"))
343
-
.json(&atrium_api::com::atproto::repo::create_record::InputData {
344
-
repo: AtIdentifier::Did(Did::new(account_did.clone()).expect("valid DID")),
345
-
collection: Nsid::new("app.bsky.feed.post".to_owned()).expect("valid NSID"),
346
-
rkey: Some(
347
-
RecordKey::new(format!("test-{record_idx}")).expect("valid record key"),
348
-
),
349
-
validate: None,
350
-
record: serde_json::from_str(
351
-
&serde_json::json!({
352
-
"$type": "app.bsky.feed.post",
353
-
"text": format!("Test post {record_idx} from {account_handle}"),
354
-
"createdAt": chrono::Utc::now().to_rfc3339(),
355
-
})
356
-
.to_string(),
357
-
)
358
-
.expect("valid JSON record"),
359
-
swap_commit: None,
360
-
})
361
-
.send()
362
-
.await;
363
-
364
-
// Fetch the record we just created
365
-
let get_response = client
366
-
.get(format!(
367
-
"{base_url}/xrpc/com.atproto.sync.getRecord?did={account_did}&collection=app.bsky.feed.post&rkey={record_idx}"
368
-
))
369
-
.header("Authorization", format!("Bearer {access_token}"))
370
-
.send()
371
-
.await;
372
-
if get_response.is_err() {
373
-
println!("Failed to fetch record {record_idx}: {get_response:?}");
374
-
results.push(get_response);
375
-
continue;
376
-
}
377
-
378
-
let request_duration = request_start.elapsed();
379
-
if record_idx % 10 == 0 {
380
-
println!("Created record {record_idx} in {request_duration:?}");
381
-
}
382
-
results.push(result);
383
-
}
384
-
385
-
results
386
-
});
387
-
388
-
handles.push(handle);
389
-
}
390
-
391
-
// Wait for all batches to complete
392
-
let results = join_all(handles).await;
393
-
394
-
// Check for errors
395
-
for batch_result in results {
396
-
let batch_responses = batch_result?;
397
-
for response_result in batch_responses {
398
-
match response_result {
399
-
Ok(response) => {
400
-
if !response.status().is_success() {
401
-
return Err(anyhow::anyhow!(
402
-
"Failed to create record: {}",
403
-
response.status()
404
-
));
405
-
}
406
-
}
407
-
Err(err) => {
408
-
return Err(anyhow::anyhow!("Failed to create record: {}", err));
409
-
}
410
-
}
411
-
}
412
-
}
413
-
414
-
let duration = start.elapsed();
415
-
Ok(duration)
416
-
}
417
-
418
-
#[cfg(test)]
419
-
#[expect(clippy::module_inception, clippy::use_debug, clippy::print_stdout)]
420
-
mod tests {
421
-
use super::*;
422
-
use anyhow::anyhow;
423
-
424
-
#[tokio::test]
425
-
async fn test_create_account() -> Result<()> {
426
-
return Ok(());
427
-
#[expect(unreachable_code, reason = "Disabled")]
428
-
let state = init_test_state().await?;
429
-
let account = state.create_test_account().await?;
430
-
431
-
println!("Created test account: {}", account.handle);
432
-
if account.handle.is_empty() {
433
-
return Err(anyhow::anyhow!("Account handle is empty"));
434
-
}
435
-
if account.did.is_empty() {
436
-
return Err(anyhow::anyhow!("Account DID is empty"));
437
-
}
438
-
if account.access_token.is_empty() {
439
-
return Err(anyhow::anyhow!("Account access token is empty"));
440
-
}
441
-
442
-
Ok(())
443
-
}
444
-
445
-
#[tokio::test]
446
-
async fn test_create_record_benchmark() -> Result<()> {
447
-
return Ok(());
448
-
#[expect(unreachable_code, reason = "Disabled")]
449
-
let duration = create_record_benchmark(100, 1).await?;
450
-
451
-
println!("Created 100 records in {duration:?}");
452
-
453
-
if duration.as_secs() >= 10 {
454
-
return Err(anyhow!("Benchmark took too long"));
455
-
}
456
-
457
-
Ok(())
458
-
}
459
-
}
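Two small patterns from the deleted harness are worth keeping in mind if tests are reintroduced: an RAII temp directory, and binding port 0 so the OS picks a free port before the server starts. A condensed sketch of both, lifted from the removed code:

```rust
use std::net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener};
use std::path::PathBuf;

/// RAII temp dir: best-effort removal on drop, like the deleted `TempDir`.
struct TempDir(PathBuf);

impl TempDir {
    fn new() -> std::io::Result<Self> {
        let path = std::env::temp_dir().join(format!("bluepds-test-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&path)?;
        Ok(Self(path))
    }
}

impl Drop for TempDir {
    fn drop(&mut self) {
        drop(std::fs::remove_dir_all(&self.0));
    }
}

/// Bind port 0 to let the OS choose a free port, then drop the listener
/// so the app under test can bind the same address.
fn free_addr() -> std::io::Result<SocketAddr> {
    TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?.local_addr()
}
```

Note the small race: the port can be reclaimed between `free_addr()` and the server's bind, which is generally acceptable for local tests.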