+491
src/account_manager/mod.rs
+491
src/account_manager/mod.rs
···
1
+
use anyhow::Result;
2
+
use chrono::DateTime;
3
+
use chrono::offset::Utc as UtcOffset;
4
+
use cidv10::Cid;
5
+
use futures::try_join;
6
+
use rsky_common::RFC3339_VARIANT;
7
+
use rsky_common::time::{HOUR, from_micros_to_str, from_str_to_micros};
8
+
use rsky_lexicon::com::atproto::admin::StatusAttr;
9
+
use rsky_lexicon::com::atproto::server::{AccountCodes, CreateAppPasswordOutput};
10
+
use rsky_pds::account_manager::CreateAccountOpts;
11
+
use rsky_pds::account_manager::helpers::account::{
12
+
AccountStatus, ActorAccount, AvailabilityFlags, GetAccountAdminStatusOutput,
13
+
};
14
+
use rsky_pds::account_manager::helpers::auth::{
15
+
AuthHelperError, CreateTokensOpts, RefreshGracePeriodOpts,
16
+
};
17
+
use rsky_pds::account_manager::helpers::invite::CodeDetail;
18
+
use rsky_pds::account_manager::helpers::password::UpdateUserPasswordOpts;
19
+
use rsky_pds::account_manager::helpers::repo;
20
+
use rsky_pds::account_manager::helpers::{account, auth, email_token, invite, password};
21
+
use rsky_pds::auth_verifier::AuthScope;
22
+
use rsky_pds::models::models::EmailTokenPurpose;
23
+
use secp256k1::{Keypair, Secp256k1, SecretKey};
24
+
use std::collections::BTreeMap;
25
+
use std::env;
26
+
use std::sync::Arc;
27
+
use std::time::SystemTime;
28
+
29
+
use crate::db::DbConn;
30
+
31
+
#[derive(Clone, Debug)]
32
+
pub struct AccountManager {
33
+
pub db: Arc<DbConn>,
34
+
}
35
+
36
+
pub type AccountManagerCreator = Box<dyn Fn(Arc<DbConn>) -> AccountManager + Send + Sync>;
37
+
38
+
impl AccountManager {
39
+
pub fn new(db: Arc<DbConn>) -> Self {
40
+
Self { db }
41
+
}
42
+
43
+
pub fn creator() -> AccountManagerCreator {
44
+
Box::new(move |db: Arc<DbConn>| -> AccountManager { AccountManager::new(db) })
45
+
}
46
+
47
+
pub async fn get_account(
48
+
&self,
49
+
handle_or_did: &str,
50
+
flags: Option<AvailabilityFlags>,
51
+
) -> Result<Option<ActorAccount>> {
52
+
let db = self.db.clone();
53
+
account::get_account(handle_or_did, flags, db.as_ref()).await
54
+
}
55
+
56
+
pub async fn get_account_by_email(
57
+
&self,
58
+
email: &str,
59
+
flags: Option<AvailabilityFlags>,
60
+
) -> Result<Option<ActorAccount>> {
61
+
let db = self.db.clone();
62
+
account::get_account_by_email(email, flags, db.as_ref()).await
63
+
}
64
+
65
+
pub async fn is_account_activated(&self, did: &str) -> Result<bool> {
66
+
let account = self
67
+
.get_account(
68
+
did,
69
+
Some(AvailabilityFlags {
70
+
include_taken_down: None,
71
+
include_deactivated: Some(true),
72
+
}),
73
+
)
74
+
.await?;
75
+
if let Some(account) = account {
76
+
Ok(account.deactivated_at.is_none())
77
+
} else {
78
+
Ok(false)
79
+
}
80
+
}
81
+
82
+
pub async fn get_did_for_actor(
83
+
&self,
84
+
handle_or_did: &str,
85
+
flags: Option<AvailabilityFlags>,
86
+
) -> Result<Option<String>> {
87
+
match self.get_account(handle_or_did, flags).await {
88
+
Ok(Some(got)) => Ok(Some(got.did)),
89
+
_ => Ok(None),
90
+
}
91
+
}
92
+
93
+
pub async fn create_account(&self, opts: CreateAccountOpts) -> Result<(String, String)> {
94
+
let db = self.db.clone();
95
+
let CreateAccountOpts {
96
+
did,
97
+
handle,
98
+
email,
99
+
password,
100
+
repo_cid,
101
+
repo_rev,
102
+
invite_code,
103
+
deactivated,
104
+
} = opts;
105
+
let password_encrypted: Option<String> = match password {
106
+
Some(password) => Some(password::gen_salt_and_hash(password)?),
107
+
None => None,
108
+
};
109
+
// Should be a global var so this only happens once
110
+
let secp = Secp256k1::new();
111
+
let private_key = env::var("PDS_JWT_KEY_K256_PRIVATE_KEY_HEX")?;
112
+
let secret_key =
113
+
SecretKey::from_slice(&Result::unwrap(hex::decode(private_key.as_bytes())))?;
114
+
let jwt_key = Keypair::from_secret_key(&secp, &secret_key);
115
+
let (access_jwt, refresh_jwt) = auth::create_tokens(CreateTokensOpts {
116
+
did: did.clone(),
117
+
jwt_key,
118
+
service_did: env::var("PDS_SERVICE_DID").unwrap(),
119
+
scope: Some(AuthScope::Access),
120
+
jti: None,
121
+
expires_in: None,
122
+
})?;
123
+
let refresh_payload = auth::decode_refresh_token(refresh_jwt.clone(), jwt_key)?;
124
+
let now = rsky_common::now();
125
+
126
+
if let Some(invite_code) = invite_code.clone() {
127
+
invite::ensure_invite_is_available(invite_code, db.as_ref()).await?;
128
+
}
129
+
account::register_actor(did.clone(), handle, deactivated, db.as_ref()).await?;
130
+
if let (Some(email), Some(password_encrypted)) = (email, password_encrypted) {
131
+
account::register_account(did.clone(), email, password_encrypted, db.as_ref()).await?;
132
+
}
133
+
invite::record_invite_use(did.clone(), invite_code, now, db.as_ref()).await?;
134
+
auth::store_refresh_token(refresh_payload, None, db.as_ref()).await?;
135
+
repo::update_root(did, repo_cid, repo_rev, db.as_ref()).await?;
136
+
Ok((access_jwt, refresh_jwt))
137
+
}
138
+
139
+
pub async fn get_account_admin_status(
140
+
&self,
141
+
did: &str,
142
+
) -> Result<Option<GetAccountAdminStatusOutput>> {
143
+
let db = self.db.clone();
144
+
account::get_account_admin_status(did, db.as_ref()).await
145
+
}
146
+
147
+
pub async fn update_repo_root(&self, did: String, cid: Cid, rev: String) -> Result<()> {
148
+
let db = self.db.clone();
149
+
repo::update_root(did, cid, rev, db.as_ref()).await
150
+
}
151
+
152
+
pub async fn delete_account(&self, did: &str) -> Result<()> {
153
+
let db = self.db.clone();
154
+
account::delete_account(did, db.as_ref()).await
155
+
}
156
+
157
+
pub async fn takedown_account(&self, did: &str, takedown: StatusAttr) -> Result<()> {
158
+
(_, _) = try_join!(
159
+
account::update_account_takedown_status(did, takedown, self.db.as_ref()),
160
+
auth::revoke_refresh_tokens_by_did(did, self.db.as_ref())
161
+
)?;
162
+
Ok(())
163
+
}
164
+
165
+
// @NOTE should always be paired with a sequenceHandle().
166
+
pub async fn update_handle(&self, did: &str, handle: &str) -> Result<()> {
167
+
let db = self.db.clone();
168
+
account::update_handle(did, handle, db.as_ref()).await
169
+
}
170
+
171
+
pub async fn deactivate_account(&self, did: &str, delete_after: Option<String>) -> Result<()> {
172
+
account::deactivate_account(did, delete_after, self.db.as_ref()).await
173
+
}
174
+
175
+
pub async fn activate_account(&self, did: &str) -> Result<()> {
176
+
let db = self.db.clone();
177
+
account::activate_account(did, db.as_ref()).await
178
+
}
179
+
180
+
pub async fn get_account_status(&self, handle_or_did: &str) -> Result<AccountStatus> {
181
+
let got = account::get_account(
182
+
handle_or_did,
183
+
Some(AvailabilityFlags {
184
+
include_deactivated: Some(true),
185
+
include_taken_down: Some(true),
186
+
}),
187
+
self.db.as_ref(),
188
+
)
189
+
.await?;
190
+
let res = account::format_account_status(got);
191
+
match res.active {
192
+
true => Ok(AccountStatus::Active),
193
+
false => Ok(res.status.expect("Account status not properly formatted.")),
194
+
}
195
+
}
196
+
197
+
// Auth
198
+
// ----------
199
+
pub async fn create_session(
200
+
&self,
201
+
did: String,
202
+
app_password_name: Option<String>,
203
+
) -> Result<(String, String)> {
204
+
let db = self.db.clone();
205
+
let secp = Secp256k1::new();
206
+
let private_key = env::var("PDS_JWT_KEY_K256_PRIVATE_KEY_HEX")?;
207
+
let secret_key = SecretKey::from_slice(&hex::decode(private_key.as_bytes())?)?;
208
+
let jwt_key = Keypair::from_secret_key(&secp, &secret_key);
209
+
let scope = if app_password_name.is_none() {
210
+
AuthScope::Access
211
+
} else {
212
+
AuthScope::AppPass
213
+
};
214
+
let (access_jwt, refresh_jwt) = auth::create_tokens(CreateTokensOpts {
215
+
did,
216
+
jwt_key,
217
+
service_did: env::var("PDS_SERVICE_DID").unwrap(),
218
+
scope: Some(scope),
219
+
jti: None,
220
+
expires_in: None,
221
+
})?;
222
+
let refresh_payload = auth::decode_refresh_token(refresh_jwt.clone(), jwt_key)?;
223
+
auth::store_refresh_token(refresh_payload, app_password_name, db.as_ref()).await?;
224
+
Ok((access_jwt, refresh_jwt))
225
+
}
226
+
227
+
pub async fn rotate_refresh_token(&self, id: &String) -> Result<Option<(String, String)>> {
228
+
let token = auth::get_refresh_token(id, self.db.as_ref()).await?;
229
+
if let Some(token) = token {
230
+
let system_time = SystemTime::now();
231
+
let dt: DateTime<UtcOffset> = system_time.into();
232
+
let now = format!("{}", dt.format(RFC3339_VARIANT));
233
+
234
+
// take the chance to tidy all of a user's expired tokens
235
+
// does not need to be transactional since this is just best-effort
236
+
auth::delete_expired_refresh_tokens(&token.did, now, self.db.as_ref()).await?;
237
+
238
+
// Shorten the refresh token lifespan down from its
239
+
// original expiration time to its revocation grace period.
240
+
let prev_expires_at = from_str_to_micros(&token.expires_at);
241
+
242
+
const REFRESH_GRACE_MS: i32 = 2 * HOUR;
243
+
let grace_expires_at = dt.timestamp_micros() + REFRESH_GRACE_MS as i64;
244
+
245
+
let expires_at = if grace_expires_at < prev_expires_at {
246
+
grace_expires_at
247
+
} else {
248
+
prev_expires_at
249
+
};
250
+
251
+
if expires_at <= dt.timestamp_micros() {
252
+
return Ok(None);
253
+
}
254
+
255
+
// Determine the next refresh token id: upon refresh token
256
+
// reuse you always receive a refresh token with the same id.
257
+
let next_id = token.next_id.unwrap_or_else(auth::get_refresh_token_id);
258
+
259
+
let secp = Secp256k1::new();
260
+
let private_key = env::var("PDS_JWT_KEY_K256_PRIVATE_KEY_HEX").unwrap();
261
+
let secret_key =
262
+
SecretKey::from_slice(&hex::decode(private_key.as_bytes()).unwrap()).unwrap();
263
+
let jwt_key = Keypair::from_secret_key(&secp, &secret_key);
264
+
265
+
let (access_jwt, refresh_jwt) = auth::create_tokens(CreateTokensOpts {
266
+
did: token.did,
267
+
jwt_key,
268
+
service_did: env::var("PDS_SERVICE_DID").unwrap(),
269
+
scope: Some(if token.app_password_name.is_none() {
270
+
AuthScope::Access
271
+
} else {
272
+
AuthScope::AppPass
273
+
}),
274
+
jti: Some(next_id.clone()),
275
+
expires_in: None,
276
+
})?;
277
+
let refresh_payload = auth::decode_refresh_token(refresh_jwt.clone(), jwt_key)?;
278
+
match try_join!(
279
+
auth::add_refresh_grace_period(
280
+
RefreshGracePeriodOpts {
281
+
id: id.clone(),
282
+
expires_at: from_micros_to_str(expires_at),
283
+
next_id
284
+
},
285
+
self.db.as_ref()
286
+
),
287
+
auth::store_refresh_token(
288
+
refresh_payload,
289
+
token.app_password_name,
290
+
self.db.as_ref()
291
+
)
292
+
) {
293
+
Ok(_) => Ok(Some((access_jwt, refresh_jwt))),
294
+
Err(e) => match e.downcast_ref() {
295
+
Some(AuthHelperError::ConcurrentRefresh) => {
296
+
Box::pin(self.rotate_refresh_token(id)).await
297
+
}
298
+
_ => Err(e),
299
+
},
300
+
}
301
+
} else {
302
+
Ok(None)
303
+
}
304
+
}
305
+
306
+
pub async fn revoke_refresh_token(&self, id: String) -> Result<bool> {
307
+
auth::revoke_refresh_token(id, self.db.as_ref()).await
308
+
}
309
+
310
+
// Invites
311
+
// ----------
312
+
313
+
pub async fn create_invite_codes(
314
+
&self,
315
+
to_create: Vec<AccountCodes>,
316
+
use_count: i32,
317
+
) -> Result<()> {
318
+
let db = self.db.clone();
319
+
invite::create_invite_codes(to_create, use_count, db.as_ref()).await
320
+
}
321
+
322
+
pub async fn create_account_invite_codes(
323
+
&self,
324
+
for_account: &str,
325
+
codes: Vec<String>,
326
+
expected_total: usize,
327
+
disabled: bool,
328
+
) -> Result<Vec<CodeDetail>> {
329
+
invite::create_account_invite_codes(
330
+
for_account,
331
+
codes,
332
+
expected_total,
333
+
disabled,
334
+
self.db.as_ref(),
335
+
)
336
+
.await
337
+
}
338
+
339
+
pub async fn get_account_invite_codes(&self, did: &str) -> Result<Vec<CodeDetail>> {
340
+
let db = self.db.clone();
341
+
invite::get_account_invite_codes(did, db.as_ref()).await
342
+
}
343
+
344
+
pub async fn get_invited_by_for_accounts(
345
+
&self,
346
+
dids: Vec<String>,
347
+
) -> Result<BTreeMap<String, CodeDetail>> {
348
+
let db = self.db.clone();
349
+
invite::get_invited_by_for_accounts(dids, db.as_ref()).await
350
+
}
351
+
352
+
pub async fn set_account_invites_disabled(&self, did: &str, disabled: bool) -> Result<()> {
353
+
invite::set_account_invites_disabled(did, disabled, self.db.as_ref()).await
354
+
}
355
+
356
+
pub async fn disable_invite_codes(&self, opts: DisableInviteCodesOpts) -> Result<()> {
357
+
invite::disable_invite_codes(opts, self.db.as_ref()).await
358
+
}
359
+
360
+
// Passwords
361
+
// ----------
362
+
363
+
pub async fn create_app_password(
364
+
&self,
365
+
did: String,
366
+
name: String,
367
+
) -> Result<CreateAppPasswordOutput> {
368
+
password::create_app_password(did, name, self.db.as_ref()).await
369
+
}
370
+
371
+
pub async fn list_app_passwords(&self, did: &str) -> Result<Vec<(String, String)>> {
372
+
password::list_app_passwords(did, self.db.as_ref()).await
373
+
}
374
+
375
+
pub async fn verify_account_password(&self, did: &str, password_str: &String) -> Result<bool> {
376
+
let db = self.db.clone();
377
+
password::verify_account_password(did, password_str, db.as_ref()).await
378
+
}
379
+
380
+
pub async fn verify_app_password(
381
+
&self,
382
+
did: &str,
383
+
password_str: &str,
384
+
) -> Result<Option<String>> {
385
+
let db = self.db.clone();
386
+
password::verify_app_password(did, password_str, db.as_ref()).await
387
+
}
388
+
389
+
pub async fn reset_password(&self, opts: ResetPasswordOpts) -> Result<()> {
390
+
let db = self.db.clone();
391
+
let did = email_token::assert_valid_token_and_find_did(
392
+
EmailTokenPurpose::ResetPassword,
393
+
&opts.token,
394
+
None,
395
+
db.as_ref(),
396
+
)
397
+
.await?;
398
+
self.update_account_password(UpdateAccountPasswordOpts {
399
+
did,
400
+
password: opts.password,
401
+
})
402
+
.await
403
+
}
404
+
405
+
pub async fn update_account_password(&self, opts: UpdateAccountPasswordOpts) -> Result<()> {
406
+
let db = self.db.clone();
407
+
let UpdateAccountPasswordOpts { did, .. } = opts;
408
+
let password_encrypted = password::gen_salt_and_hash(opts.password)?;
409
+
try_join!(
410
+
password::update_user_password(
411
+
UpdateUserPasswordOpts {
412
+
did: did.clone(),
413
+
password_encrypted
414
+
},
415
+
self.db.as_ref()
416
+
),
417
+
email_token::delete_email_token(&did, EmailTokenPurpose::ResetPassword, db.as_ref()),
418
+
auth::revoke_refresh_tokens_by_did(&did, self.db.as_ref())
419
+
)?;
420
+
Ok(())
421
+
}
422
+
423
+
pub async fn revoke_app_password(&self, did: String, name: String) -> Result<()> {
424
+
try_join!(
425
+
password::delete_app_password(&did, &name, self.db.as_ref()),
426
+
auth::revoke_app_password_refresh_token(&did, &name, self.db.as_ref())
427
+
)?;
428
+
Ok(())
429
+
}
430
+
431
+
// Email Tokens
432
+
// ----------
433
+
pub async fn confirm_email<'em>(&self, opts: ConfirmEmailOpts<'em>) -> Result<()> {
434
+
let db = self.db.clone();
435
+
let ConfirmEmailOpts { did, token } = opts;
436
+
email_token::assert_valid_token(
437
+
did,
438
+
EmailTokenPurpose::ConfirmEmail,
439
+
token,
440
+
None,
441
+
db.as_ref(),
442
+
)
443
+
.await?;
444
+
let now = rsky_common::now();
445
+
try_join!(
446
+
email_token::delete_email_token(did, EmailTokenPurpose::ConfirmEmail, db.as_ref()),
447
+
account::set_email_confirmed_at(did, now, self.db.as_ref())
448
+
)?;
449
+
Ok(())
450
+
}
451
+
452
+
pub async fn update_email(&self, opts: UpdateEmailOpts) -> Result<()> {
453
+
let db = self.db.clone();
454
+
let UpdateEmailOpts { did, email } = opts;
455
+
try_join!(
456
+
account::update_email(&did, &email, db.as_ref()),
457
+
email_token::delete_all_email_tokens(&did, db.as_ref())
458
+
)?;
459
+
Ok(())
460
+
}
461
+
462
+
pub async fn assert_valid_email_token(
463
+
&self,
464
+
did: &str,
465
+
purpose: EmailTokenPurpose,
466
+
token: &str,
467
+
) -> Result<()> {
468
+
let db = self.db.clone();
469
+
email_token::assert_valid_token(did, purpose, token, None, db.as_ref()).await
470
+
}
471
+
472
+
pub async fn assert_valid_email_token_and_cleanup(
473
+
&self,
474
+
did: &str,
475
+
purpose: EmailTokenPurpose,
476
+
token: &str,
477
+
) -> Result<()> {
478
+
let db = self.db.clone();
479
+
email_token::assert_valid_token(did, purpose, token, None, db.as_ref()).await?;
480
+
email_token::delete_email_token(did, purpose, db.as_ref()).await
481
+
}
482
+
483
+
pub async fn create_email_token(
484
+
&self,
485
+
did: &str,
486
+
purpose: EmailTokenPurpose,
487
+
) -> Result<String> {
488
+
let db = self.db.clone();
489
+
email_token::create_email_token(did, purpose, db.as_ref()).await
490
+
}
491
+
}
+76
src/actor_endpoints.rs
+76
src/actor_endpoints.rs
···
1
+
use atrium_api::app::bsky::actor;
2
+
use axum::{Json, routing::post};
3
+
use constcat::concat;
4
+
use diesel::prelude::*;
5
+
6
+
use super::*;
7
+
8
+
async fn put_preferences(
9
+
user: AuthenticatedUser,
10
+
State(db): State<Db>,
11
+
Json(input): Json<actor::put_preferences::Input>,
12
+
) -> Result<()> {
13
+
let did = user.did();
14
+
let json_string =
15
+
serde_json::to_string(&input.preferences).context("failed to serialize preferences")?;
16
+
17
+
// Use the db connection pool to execute the update
18
+
let conn = &mut db.get().context("failed to get database connection")?;
19
+
diesel::sql_query("UPDATE accounts SET private_prefs = ? WHERE did = ?")
20
+
.bind::<diesel::sql_types::Text, _>(json_string)
21
+
.bind::<diesel::sql_types::Text, _>(did)
22
+
.execute(conn)
23
+
.context("failed to update user preferences")?;
24
+
25
+
Ok(())
26
+
}
27
+
28
+
async fn get_preferences(
29
+
user: AuthenticatedUser,
30
+
State(db): State<Db>,
31
+
) -> Result<Json<actor::get_preferences::Output>> {
32
+
let did = user.did();
33
+
let conn = &mut db.get().context("failed to get database connection")?;
34
+
35
+
#[derive(QueryableByName)]
36
+
struct Prefs {
37
+
#[diesel(sql_type = diesel::sql_types::Text)]
38
+
private_prefs: Option<String>,
39
+
}
40
+
41
+
let result = diesel::sql_query("SELECT private_prefs FROM accounts WHERE did = ?")
42
+
.bind::<diesel::sql_types::Text, _>(did)
43
+
.get_result::<Prefs>(conn)
44
+
.context("failed to fetch preferences")?;
45
+
46
+
if let Some(prefs_json) = result.private_prefs {
47
+
let prefs: actor::defs::Preferences =
48
+
serde_json::from_str(&prefs_json).context("failed to deserialize preferences")?;
49
+
50
+
Ok(Json(
51
+
actor::get_preferences::OutputData { preferences: prefs }.into(),
52
+
))
53
+
} else {
54
+
Ok(Json(
55
+
actor::get_preferences::OutputData {
56
+
preferences: Vec::new(),
57
+
}
58
+
.into(),
59
+
))
60
+
}
61
+
}
62
+
63
+
/// Register all actor endpoints.
64
+
pub(crate) fn routes() -> Router<AppState> {
65
+
// AP /xrpc/app.bsky.actor.putPreferences
66
+
// AG /xrpc/app.bsky.actor.getPreferences
67
+
Router::new()
68
+
.route(
69
+
concat!("/", actor::put_preferences::NSID),
70
+
post(put_preferences),
71
+
)
72
+
.route(
73
+
concat!("/", actor::get_preferences::NSID),
74
+
get(get_preferences),
75
+
)
76
+
}
+1
-1
src/actor_store/mod.rs
+1
-1
src/actor_store/mod.rs
src/db/mod.rs
src/db.rs
src/db/mod.rs
src/db.rs
+132
-434
src/endpoints/repo/apply_writes.rs
+132
-434
src/endpoints/repo/apply_writes.rs
···
1
1
//! Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS.
2
-
use std::{collections::HashSet, str::FromStr};
3
-
2
+
use crate::{
3
+
AppState, Db, Error, Result, SigningKey,
4
+
actor_store::ActorStore,
5
+
actor_store::sql_blob::BlobStoreSql,
6
+
auth::AuthenticatedUser,
7
+
config::AppConfig,
8
+
error::ErrorMessage,
9
+
firehose::{self, FirehoseProducer, RepoOp},
10
+
metrics::{REPO_COMMITS, REPO_OP_CREATE, REPO_OP_DELETE, REPO_OP_UPDATE},
11
+
storage,
12
+
};
13
+
use anyhow::bail;
4
14
use anyhow::{Context as _, anyhow};
5
15
use atrium_api::com::atproto::repo::apply_writes::{self, InputWritesItem, OutputResultsItem};
6
16
use atrium_api::{
···
10
20
string::{AtIdentifier, Nsid, Tid},
11
21
},
12
22
};
13
-
use atrium_repo::{Cid, blockstore::CarStore};
23
+
use atrium_repo::blockstore::CarStore;
14
24
use axum::{
15
25
Json, Router,
16
26
body::Body,
···
18
28
http::{self, StatusCode},
19
29
routing::{get, post},
20
30
};
31
+
use cidv10::Cid;
21
32
use constcat::concat;
22
33
use futures::TryStreamExt as _;
34
+
use futures::stream::{self, StreamExt};
23
35
use metrics::counter;
36
+
use rsky_lexicon::com::atproto::repo::{ApplyWritesInput, ApplyWritesInputRefWrite};
37
+
use rsky_pds::SharedSequencer;
38
+
use rsky_pds::account_manager::AccountManager;
39
+
use rsky_pds::account_manager::helpers::account::AvailabilityFlags;
40
+
use rsky_pds::apis::ApiError;
41
+
use rsky_pds::auth_verifier::AccessStandardIncludeChecks;
42
+
use rsky_pds::repo::prepare::{
43
+
PrepareCreateOpts, PrepareDeleteOpts, PrepareUpdateOpts, prepare_create, prepare_delete,
44
+
prepare_update,
45
+
};
46
+
use rsky_repo::types::PreparedWrite;
24
47
use rsky_syntax::aturi::AtUri;
25
48
use serde::Deserialize;
49
+
use std::{collections::HashSet, str::FromStr};
26
50
use tokio::io::AsyncWriteExt as _;
27
51
28
-
use crate::repo::block_map::cid_for_cbor;
29
-
use crate::repo::types::PreparedCreateOrUpdate;
30
-
use crate::{
31
-
AppState, Db, Error, Result, SigningKey,
32
-
actor_store::{ActorStoreTransactor, ActorStoreWriter},
33
-
auth::AuthenticatedUser,
34
-
config::AppConfig,
35
-
error::ErrorMessage,
36
-
firehose::{self, FirehoseProducer, RepoOp},
37
-
metrics::{REPO_COMMITS, REPO_OP_CREATE, REPO_OP_DELETE, REPO_OP_UPDATE},
38
-
repo::types::{PreparedWrite, WriteOpAction},
39
-
storage,
40
-
};
41
-
42
52
use super::resolve_did;
43
53
44
54
/// Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS.
···
57
67
State(config): State<AppConfig>,
58
68
State(db): State<Db>,
59
69
State(fhp): State<FirehoseProducer>,
60
-
Json(input): Json<repo::apply_writes::Input>,
70
+
Json(input): Json<ApplyWritesInput>,
61
71
) -> Result<Json<repo::apply_writes::Output>> {
62
-
todo!();
63
-
// // TODO: `input.validate`
64
-
65
-
// // Resolve DID from identifier
66
-
// let (target_did, _) = resolve_did(&db, &input.repo)
67
-
// .await
68
-
// .context("failed to resolve did")?;
69
-
70
-
// // Ensure that we are updating the correct repository
71
-
// if target_did.as_str() != user.did() {
72
-
// return Err(Error::with_status(
73
-
// StatusCode::BAD_REQUEST,
74
-
// anyhow!("repo did not match the authenticated user"),
75
-
// ));
76
-
// }
77
-
78
-
// // Validate writes count
79
-
// if input.writes.len() > 200 {
80
-
// return Err(Error::with_status(
81
-
// StatusCode::BAD_REQUEST,
82
-
// anyhow!("Too many writes. Max: 200"),
83
-
// ));
84
-
// }
85
-
86
-
// // Convert input writes to prepared format
87
-
// let mut prepared_writes = Vec::with_capacity(input.writes.len());
88
-
// for write in input.writes.iter() {
89
-
// match write {
90
-
// InputWritesItem::Create(create) => {
91
-
// let uri = AtUri::make(
92
-
// user.did(),
93
-
// &create.collection.as_str(),
94
-
// create
95
-
// .rkey
96
-
// .as_deref()
97
-
// .unwrap_or(&Tid::now(LimitedU32::MIN).to_string()),
98
-
// );
99
-
100
-
// let cid = match cid_for_cbor(&create.value) {
101
-
// Ok(cid) => cid,
102
-
// Err(e) => {
103
-
// return Err(Error::with_status(
104
-
// StatusCode::BAD_REQUEST,
105
-
// anyhow!("Failed to encode record: {}", e),
106
-
// ));
107
-
// }
108
-
// };
109
-
110
-
// let blobs = scan_blobs(&create.value)
111
-
// .unwrap_or_default()
112
-
// .into_iter()
113
-
// .map(|cid| {
114
-
// // TODO: Create BlobRef from cid with proper metadata
115
-
// BlobRef {
116
-
// cid,
117
-
// mime_type: "application/octet-stream".to_string(), // Default
118
-
// size: 0, // Unknown at this point
119
-
// }
120
-
// })
121
-
// .collect();
122
-
123
-
// prepared_writes.push(PreparedCreateOrUpdate {
124
-
// action: WriteOpAction::Create,
125
-
// uri: uri?.to_string(),
126
-
// cid,
127
-
// record: create.value.clone(),
128
-
// blobs,
129
-
// swap_cid: None,
130
-
// });
131
-
// }
132
-
// InputWritesItem::Update(update) => {
133
-
// let uri = AtUri::make(
134
-
// user.did(),
135
-
// Some(update.collection.to_string()),
136
-
// Some(update.rkey.to_string()),
137
-
// );
138
-
139
-
// let cid = match cid_for_cbor(&update.value) {
140
-
// Ok(cid) => cid,
141
-
// Err(e) => {
142
-
// return Err(Error::with_status(
143
-
// StatusCode::BAD_REQUEST,
144
-
// anyhow!("Failed to encode record: {}", e),
145
-
// ));
146
-
// }
147
-
// };
148
-
149
-
// let blobs = scan_blobs(&update.value)
150
-
// .unwrap_or_default()
151
-
// .into_iter()
152
-
// .map(|cid| {
153
-
// // TODO: Create BlobRef from cid with proper metadata
154
-
// BlobRef {
155
-
// cid,
156
-
// mime_type: "application/octet-stream".to_string(),
157
-
// size: 0,
158
-
// }
159
-
// })
160
-
// .collect();
161
-
162
-
// prepared_writes.push(PreparedCreateOrUpdate {
163
-
// action: WriteOpAction::Update,
164
-
// uri: uri?.to_string(),
165
-
// cid,
166
-
// record: update.value.clone(),
167
-
// blobs,
168
-
// swap_cid: None,
169
-
// });
170
-
// }
171
-
// InputWritesItem::Delete(delete) => {
172
-
// let uri = AtUri::make(user.did(), &delete.collection.as_str(), &delete.rkey);
173
-
174
-
// prepared_writes.push(PreparedCreateOrUpdate {
175
-
// action: WriteOpAction::Delete,
176
-
// uri: uri?.to_string(),
177
-
// cid: Cid::default(), // Not needed for delete
178
-
// record: serde_json::Value::Null,
179
-
// blobs: vec![],
180
-
// swap_cid: None,
181
-
// });
182
-
// }
183
-
// }
184
-
// }
185
-
186
-
// // Get swap commit CID if provided
187
-
// let swap_commit_cid = input.swap_commit.as_ref().map(|cid| *cid.as_ref());
188
-
189
-
// let did_str = user.did();
190
-
// let mut repo = storage::open_repo_db(&config.repo, &db, did_str)
191
-
// .await
192
-
// .context("failed to open user repo")?;
193
-
// let orig_cid = repo.root();
194
-
// let orig_rev = repo.commit().rev();
195
-
196
-
// let mut blobs = vec![];
197
-
// let mut res = vec![];
198
-
// let mut ops = vec![];
199
-
200
-
// for write in &prepared_writes {
201
-
// let (builder, key) = match write.action {
202
-
// WriteOpAction::Create => {
203
-
// let key = format!("{}/{}", write.uri.collection, write.uri.rkey);
204
-
// let uri = format!("at://{}/{}", user.did(), key);
205
-
206
-
// let (builder, cid) = repo
207
-
// .add_raw(&key, &write.record)
208
-
// .await
209
-
// .context("failed to add record")?;
210
-
211
-
// // Extract and track blobs
212
-
// if let Ok(new_blobs) = scan_blobs(&write.record) {
213
-
// blobs.extend(
214
-
// new_blobs
215
-
// .into_iter()
216
-
// .map(|blob_cid| (key.clone(), blob_cid)),
217
-
// );
218
-
// }
219
-
220
-
// ops.push(RepoOp::Create {
221
-
// cid,
222
-
// path: key.clone(),
223
-
// });
224
-
225
-
// res.push(OutputResultsItem::CreateResult(Box::new(
226
-
// apply_writes::CreateResultData {
227
-
// cid: atrium_api::types::string::Cid::new(cid),
228
-
// uri,
229
-
// validation_status: None,
230
-
// }
231
-
// .into(),
232
-
// )));
233
-
234
-
// (builder, key)
235
-
// }
236
-
// WriteOpAction::Update => {
237
-
// let key = format!("{}/{}", write.uri.collection, write.uri.rkey);
238
-
// let uri = format!("at://{}/{}", user.did(), key);
239
-
240
-
// let prev = repo
241
-
// .tree()
242
-
// .get(&key)
243
-
// .await
244
-
// .context("failed to search MST")?;
245
-
246
-
// if prev.is_none() {
247
-
// // No existing record, treat as create
248
-
// let (create_builder, cid) = repo
249
-
// .add_raw(&key, &write.record)
250
-
// .await
251
-
// .context("failed to add record")?;
252
-
253
-
// if let Ok(new_blobs) = scan_blobs(&write.record) {
254
-
// blobs.extend(
255
-
// new_blobs
256
-
// .into_iter()
257
-
// .map(|blob_cid| (key.clone(), blob_cid)),
258
-
// );
259
-
// }
260
-
261
-
// ops.push(RepoOp::Create {
262
-
// cid,
263
-
// path: key.clone(),
264
-
// });
265
-
266
-
// res.push(OutputResultsItem::CreateResult(Box::new(
267
-
// apply_writes::CreateResultData {
268
-
// cid: atrium_api::types::string::Cid::new(cid),
269
-
// uri,
270
-
// validation_status: None,
271
-
// }
272
-
// .into(),
273
-
// )));
274
-
275
-
// (create_builder, key)
276
-
// } else {
277
-
// // Update existing record
278
-
// let prev = prev.context("should be able to find previous record")?;
279
-
// let (update_builder, cid) = repo
280
-
// .update_raw(&key, &write.record)
281
-
// .await
282
-
// .context("failed to add record")?;
283
-
284
-
// if let Ok(new_blobs) = scan_blobs(&write.record) {
285
-
// blobs.extend(
286
-
// new_blobs
287
-
// .into_iter()
288
-
// .map(|blob_cid| (key.clone(), blob_cid)),
289
-
// );
290
-
// }
291
-
292
-
// ops.push(RepoOp::Update {
293
-
// cid,
294
-
// path: key.clone(),
295
-
// prev,
296
-
// });
297
-
298
-
// res.push(OutputResultsItem::UpdateResult(Box::new(
299
-
// apply_writes::UpdateResultData {
300
-
// cid: atrium_api::types::string::Cid::new(cid),
301
-
// uri,
302
-
// validation_status: None,
303
-
// }
304
-
// .into(),
305
-
// )));
72
+
let tx: ApplyWritesInput = input;
73
+
let ApplyWritesInput {
74
+
repo,
75
+
validate,
76
+
swap_commit,
77
+
..
78
+
} = tx;
79
+
let account = account_manager
80
+
.get_account(
81
+
&repo,
82
+
Some(AvailabilityFlags {
83
+
include_deactivated: Some(true),
84
+
include_taken_down: None,
85
+
}),
86
+
)
87
+
.await?;
306
88
307
-
// (update_builder, key)
308
-
// }
309
-
// }
310
-
// WriteOpAction::Delete => {
311
-
// let key = format!("{}/{}", write.uri.collection, write.uri.rkey);
89
+
if let Some(account) = account {
90
+
if account.deactivated_at.is_some() {
91
+
return Err(Error::with_message(
92
+
StatusCode::FORBIDDEN,
93
+
anyhow!("Account is deactivated"),
94
+
ErrorMessage::new("AccountDeactivated", "Account is deactivated"),
95
+
));
96
+
}
97
+
let did = account.did;
98
+
if did != user.did() {
99
+
return Err(Error::with_message(
100
+
StatusCode::FORBIDDEN,
101
+
anyhow!("AuthRequiredError"),
102
+
ErrorMessage::new("AuthRequiredError", "Auth required"),
103
+
));
104
+
}
105
+
let did: &String = &did;
106
+
if tx.writes.len() > 200 {
107
+
return Err(Error::with_message(
108
+
StatusCode::BAD_REQUEST,
109
+
anyhow!("Too many writes. Max: 200"),
110
+
ErrorMessage::new("TooManyWrites", "Too many writes. Max: 200"),
111
+
));
112
+
}
312
113
313
-
// let prev = repo
314
-
// .tree()
315
-
// .get(&key)
316
-
// .await
317
-
// .context("failed to search MST")?
318
-
// .context("previous record does not exist")?;
319
-
320
-
// ops.push(RepoOp::Delete {
321
-
// path: key.clone(),
322
-
// prev,
323
-
// });
324
-
325
-
// res.push(OutputResultsItem::DeleteResult(Box::new(
326
-
// apply_writes::DeleteResultData {}.into(),
327
-
// )));
114
+
let writes: Vec<PreparedWrite> = stream::iter(tx.writes)
115
+
.then(|write| async move {
116
+
Ok::<PreparedWrite, anyhow::Error>(match write {
117
+
ApplyWritesInputRefWrite::Create(write) => PreparedWrite::Create(
118
+
prepare_create(PrepareCreateOpts {
119
+
did: did.clone(),
120
+
collection: write.collection,
121
+
rkey: write.rkey,
122
+
swap_cid: None,
123
+
record: serde_json::from_value(write.value)?,
124
+
validate,
125
+
})
126
+
.await?,
127
+
),
128
+
ApplyWritesInputRefWrite::Update(write) => PreparedWrite::Update(
129
+
prepare_update(PrepareUpdateOpts {
130
+
did: did.clone(),
131
+
collection: write.collection,
132
+
rkey: write.rkey,
133
+
swap_cid: None,
134
+
record: serde_json::from_value(write.value)?,
135
+
validate,
136
+
})
137
+
.await?,
138
+
),
139
+
ApplyWritesInputRefWrite::Delete(write) => {
140
+
PreparedWrite::Delete(prepare_delete(PrepareDeleteOpts {
141
+
did: did.clone(),
142
+
collection: write.collection,
143
+
rkey: write.rkey,
144
+
swap_cid: None,
145
+
})?)
146
+
}
147
+
})
148
+
})
149
+
.collect::<Vec<_>>()
150
+
.await
151
+
.into_iter()
152
+
.collect::<Result<Vec<PreparedWrite>, _>>()?;
328
153
329
-
// let builder = repo
330
-
// .delete_raw(&key)
331
-
// .await
332
-
// .context("failed to add record")?;
154
+
let swap_commit_cid = match swap_commit {
155
+
Some(swap_commit) => Some(Cid::from_str(&swap_commit)?),
156
+
None => None,
157
+
};
333
158
334
-
// (builder, key)
335
-
// }
336
-
// };
159
+
let mut actor_store = ActorStore::new(did.clone(), BlobStoreSql::new(did.clone(), db), db);
337
160
338
-
// let sig = skey
339
-
// .sign(&builder.bytes())
340
-
// .context("failed to sign commit")?;
161
+
let commit = actor_store
162
+
.process_writes(writes.clone(), swap_commit_cid)
163
+
.await?;
341
164
342
-
// _ = builder
343
-
// .finalize(sig)
344
-
// .await
345
-
// .context("failed to write signed commit")?;
346
-
// }
347
-
348
-
// // Construct a firehose record
349
-
// let mut mem = Vec::new();
350
-
// let mut store = CarStore::create_with_roots(std::io::Cursor::new(&mut mem), [repo.root()])
351
-
// .await
352
-
// .context("failed to create temp store")?;
353
-
354
-
// // Extract the records out of the user's repository
355
-
// for write in &prepared_writes {
356
-
// let key = format!("{}/{}", write.uri.collection, write.uri.rkey);
357
-
// repo.extract_raw_into(&key, &mut store)
358
-
// .await
359
-
// .context("failed to extract key")?;
360
-
// }
361
-
362
-
// let mut tx = db.begin().await.context("failed to begin transaction")?;
363
-
364
-
// if !swap_commit(
365
-
// &mut *tx,
366
-
// repo.root(),
367
-
// repo.commit().rev(),
368
-
// input.swap_commit.as_ref().map(|cid| *cid.as_ref()),
369
-
// &user.did(),
370
-
// )
371
-
// .await
372
-
// .context("failed to swap commit")?
373
-
// {
374
-
// // This should always succeed.
375
-
// let old = input
376
-
// .swap_commit
377
-
// .clone()
378
-
// .context("swap_commit should always be Some")?;
379
-
380
-
// // The swap failed. Return the old commit and do not update the repository.
381
-
// return Ok(Json(
382
-
// apply_writes::OutputData {
383
-
// results: None,
384
-
// commit: Some(
385
-
// CommitMetaData {
386
-
// cid: old,
387
-
// rev: orig_rev,
388
-
// }
389
-
// .into(),
390
-
// ),
391
-
// }
392
-
// .into(),
393
-
// ));
394
-
// }
395
-
396
-
// // For updates and removals, unlink the old/deleted record from the blob_ref table
397
-
// for op in &ops {
398
-
// match op {
399
-
// &RepoOp::Update { ref path, .. } | &RepoOp::Delete { ref path, .. } => {
400
-
// // FIXME: This may cause issues if a user deletes more than one record referencing the same blob.
401
-
// _ = &sqlx::query!(
402
-
// r#"UPDATE blob_ref SET record = NULL WHERE did = ? AND record = ?"#,
403
-
// did_str,
404
-
// path
405
-
// )
406
-
// .execute(&mut *tx)
407
-
// .await
408
-
// .context("failed to remove blob_ref")?;
409
-
// }
410
-
// &RepoOp::Create { .. } => {}
411
-
// }
412
-
// }
413
-
414
-
// // Process blobs
415
-
// for (key, cid) in &blobs {
416
-
// let cid_str = cid.to_string();
417
-
418
-
// // Handle the case where a new record references an existing blob
419
-
// if sqlx::query!(
420
-
// r#"UPDATE blob_ref SET record = ? WHERE cid = ? AND did = ? AND record IS NULL"#,
421
-
// key,
422
-
// cid_str,
423
-
// did_str,
424
-
// )
425
-
// .execute(&mut *tx)
426
-
// .await
427
-
// .context("failed to update blob_ref")?
428
-
// .rows_affected()
429
-
// == 0
430
-
// {
431
-
// _ = sqlx::query!(
432
-
// r#"INSERT INTO blob_ref (record, cid, did) VALUES (?, ?, ?)"#,
433
-
// key,
434
-
// cid_str,
435
-
// did_str,
436
-
// )
437
-
// .execute(&mut *tx)
438
-
// .await
439
-
// .context("failed to update blob_ref")?;
440
-
// }
441
-
// }
442
-
443
-
// tx.commit()
444
-
// .await
445
-
// .context("failed to commit blob ref to database")?;
446
-
447
-
// // Update counters
448
-
// counter!(REPO_COMMITS).increment(1);
449
-
// for op in &ops {
450
-
// match *op {
451
-
// RepoOp::Create { .. } => counter!(REPO_OP_CREATE).increment(1),
452
-
// RepoOp::Update { .. } => counter!(REPO_OP_UPDATE).increment(1),
453
-
// RepoOp::Delete { .. } => counter!(REPO_OP_DELETE).increment(1),
454
-
// }
455
-
// }
456
-
457
-
// // We've committed the transaction to the database, and the commit is now stored in the user's
458
-
// // canonical repository.
459
-
// // We can now broadcast this on the firehose.
460
-
// fhp.commit(firehose::Commit {
461
-
// car: mem,
462
-
// ops,
463
-
// cid: repo.root(),
464
-
// rev: repo.commit().rev().to_string(),
465
-
// did: atrium_api::types::string::Did::new(user.did()).expect("should be valid DID"),
466
-
// pcid: Some(orig_cid),
467
-
// blobs: blobs.into_iter().map(|(_, cid)| cid).collect::<Vec<_>>(),
468
-
// })
469
-
// .await;
470
-
471
-
// Ok(Json(
472
-
// apply_writes::OutputData {
473
-
// results: Some(res),
474
-
// commit: Some(
475
-
// CommitMetaData {
476
-
// cid: atrium_api::types::string::Cid::new(repo.root()),
477
-
// rev: repo.commit().rev(),
478
-
// }
479
-
// .into(),
480
-
// ),
481
-
// }
482
-
// .into(),
483
-
// ))
165
+
let mut lock = sequencer.sequencer.write().await;
166
+
lock.sequence_commit(did.clone(), commit.clone()).await?;
167
+
account_manager
168
+
.update_repo_root(
169
+
did.to_string(),
170
+
commit.commit_data.cid,
171
+
commit.commit_data.rev,
172
+
)
173
+
.await?;
174
+
Ok(())
175
+
} else {
176
+
Err(Error::with_message(
177
+
StatusCode::NOT_FOUND,
178
+
anyhow!("Could not find repo: `{repo}`"),
179
+
ErrorMessage::new("RepoNotFound", "Could not find repo"),
180
+
))
181
+
}
484
182
}
+38
-97
src/main.rs
+38
-97
src/main.rs
···
1
1
//! PDS implementation.
2
+
mod account_manager;
2
3
mod actor_store;
3
4
mod auth;
4
5
mod config;
···
11
12
mod mmap;
12
13
mod oauth;
13
14
mod plc;
14
-
mod storage;
15
15
#[cfg(test)]
16
16
mod tests;
17
17
···
19
19
///
20
20
/// We shouldn't have to know about any bsky endpoints to store private user data.
21
21
/// This will _very likely_ be changed in the future.
22
-
mod actor_endpoints {
23
-
use atrium_api::app::bsky::actor;
24
-
use axum::{Json, routing::post};
25
-
use constcat::concat;
26
-
27
-
use super::*;
28
-
29
-
async fn put_preferences(
30
-
user: AuthenticatedUser,
31
-
State(db): State<Db>,
32
-
Json(input): Json<actor::put_preferences::Input>,
33
-
) -> Result<()> {
34
-
let did = user.did();
35
-
let prefs = sqlx::types::Json(input.preferences.clone());
36
-
_ = sqlx::query!(
37
-
r#"UPDATE accounts SET private_prefs = ? WHERE did = ?"#,
38
-
prefs,
39
-
did
40
-
)
41
-
.execute(&db)
42
-
.await
43
-
.context("failed to update user preferences")?;
44
-
45
-
Ok(())
46
-
}
47
-
48
-
async fn get_preferences(
49
-
user: AuthenticatedUser,
50
-
State(db): State<Db>,
51
-
) -> Result<Json<actor::get_preferences::Output>> {
52
-
let did = user.did();
53
-
let json: Option<sqlx::types::Json<actor::defs::Preferences>> =
54
-
sqlx::query_scalar("SELECT private_prefs FROM accounts WHERE did = ?")
55
-
.bind(did)
56
-
.fetch_one(&db)
57
-
.await
58
-
.context("failed to fetch preferences")?;
59
-
60
-
if let Some(prefs) = json {
61
-
Ok(Json(
62
-
actor::get_preferences::OutputData {
63
-
preferences: prefs.0,
64
-
}
65
-
.into(),
66
-
))
67
-
} else {
68
-
Ok(Json(
69
-
actor::get_preferences::OutputData {
70
-
preferences: Vec::new(),
71
-
}
72
-
.into(),
73
-
))
74
-
}
75
-
}
76
-
77
-
/// Register all actor endpoints.
78
-
pub(crate) fn routes() -> Router<AppState> {
79
-
// AP /xrpc/app.bsky.actor.putPreferences
80
-
// AG /xrpc/app.bsky.actor.getPreferences
81
-
Router::new()
82
-
.route(
83
-
concat!("/", actor::put_preferences::NSID),
84
-
post(put_preferences),
85
-
)
86
-
.route(
87
-
concat!("/", actor::get_preferences::NSID),
88
-
get(get_preferences),
89
-
)
90
-
}
91
-
}
22
+
mod actor_endpoints;
92
23
93
24
use anyhow::{Context as _, anyhow};
94
25
use atrium_api::types::string::Did;
···
106
37
use clap::Parser;
107
38
use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter};
108
39
use config::AppConfig;
40
+
use diesel::prelude::*;
41
+
use diesel::r2d2::{self, ConnectionManager};
42
+
use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};
109
43
#[expect(clippy::pub_use, clippy::useless_attribute)]
110
44
pub use error::Error;
111
45
use figment::{Figment, providers::Format as _};
···
113
47
use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager};
114
48
use rand::Rng as _;
115
49
use serde::{Deserialize, Serialize};
116
-
use sqlx::{SqlitePool, sqlite::SqliteConnectOptions};
117
50
use std::{
118
51
net::{IpAddr, Ipv4Addr, SocketAddr},
119
52
path::PathBuf,
···
128
61
/// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`.
129
62
pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
130
63
64
+
/// Embedded migrations
65
+
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
66
+
131
67
/// The application-wide result type.
132
68
pub type Result<T> = std::result::Result<T, Error>;
133
69
/// The reqwest client type with middleware.
134
70
pub type Client = reqwest_middleware::ClientWithMiddleware;
135
71
/// The database connection pool.
136
-
pub type Db = SqlitePool;
72
+
pub type Db = r2d2::Pool<ConnectionManager<SqliteConnection>>;
137
73
/// The Azure credential type.
138
74
pub type Cred = Arc<dyn TokenCredential>;
139
75
···
451
387
452
388
let cred = azure_identity::DefaultAzureCredential::new()
453
389
.context("failed to create Azure credential")?;
454
-
let opts = SqliteConnectOptions::from_str(&config.db)
455
-
.context("failed to parse database options")?
456
-
.create_if_missing(true);
457
-
let db = SqlitePool::connect_with(opts).await?;
458
390
459
-
sqlx::migrate!()
460
-
.run(&db)
461
-
.await
462
-
.context("failed to apply migrations")?;
391
+
// Create a database connection manager and pool
392
+
let manager = ConnectionManager::<SqliteConnection>::new(&config.db);
393
+
let db = r2d2::Pool::builder()
394
+
.build(manager)
395
+
.context("failed to create database connection pool")?;
396
+
397
+
// Apply pending migrations
398
+
let conn = &mut db
399
+
.get()
400
+
.context("failed to get database connection for migrations")?;
401
+
conn.run_pending_migrations(MIGRATIONS)
402
+
.expect("should be able to run migrations");
463
403
464
404
let (_fh, fhp) = firehose::spawn(client.clone(), config.clone());
465
405
···
495
435
496
436
// Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created).
497
437
// If so, create an invite code and share it via the console.
498
-
let c = sqlx::query_scalar!(
499
-
r#"
500
-
SELECT
501
-
(SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites)
502
-
AS total_count
503
-
"#
438
+
let conn = &mut db.get().context("failed to get database connection")?;
439
+
440
+
#[derive(QueryableByName)]
441
+
struct TotalCount {
442
+
#[diesel(sql_type = diesel::sql_types::Integer)]
443
+
total_count: i32,
444
+
}
445
+
446
+
let result = diesel::sql_query(
447
+
"SELECT (SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites) AS total_count",
504
448
)
505
-
.fetch_one(&db)
506
-
.await
449
+
.get_result::<TotalCount>(conn)
507
450
.context("failed to query database")?;
508
451
452
+
let c = result.total_count;
453
+
509
454
#[expect(clippy::print_stdout)]
510
455
if c == 0 {
511
456
let uuid = Uuid::new_v4().to_string();
512
457
513
-
_ = sqlx::query!(
514
-
r#"
515
-
INSERT INTO invites (id, did, count, created_at)
516
-
VALUES (?, NULL, 1, datetime('now'))
517
-
"#,
518
-
uuid,
458
+
diesel::sql_query(
459
+
"INSERT INTO invites (id, did, count, created_at) VALUES (?, NULL, 1, datetime('now'))",
519
460
)
520
-
.execute(&db)
521
-
.await
461
+
.bind::<diesel::sql_types::Text, _>(uuid.clone())
462
+
.execute(conn)
522
463
.context("failed to create new invite code")?;
523
464
524
465
// N.B: This is a sensitive message, so we're bypassing `tracing` here and
-28
src/storage/car.rs
-28
src/storage/car.rs
···
1
-
//! CAR file-based repository storage
2
-
3
-
use anyhow::{Context as _, Result};
4
-
use atrium_repo::blockstore::{AsyncBlockStoreRead, AsyncBlockStoreWrite, CarStore};
5
-
6
-
use crate::{config::RepoConfig, mmap::MappedFile};
7
-
8
-
/// Open a CAR block store for a given DID.
9
-
pub(crate) async fn open_car_store(
10
-
config: &RepoConfig,
11
-
did: impl AsRef<str>,
12
-
) -> Result<impl AsyncBlockStoreRead + AsyncBlockStoreWrite> {
13
-
let id = did
14
-
.as_ref()
15
-
.strip_prefix("did:plc:")
16
-
.context("did in unknown format")?;
17
-
18
-
let p = config.path.join(id).with_extension("car");
19
-
20
-
let f = std::fs::File::options()
21
-
.read(true)
22
-
.write(true)
23
-
.open(p)
24
-
.context("failed to open repository file")?;
25
-
let f = MappedFile::new(f).context("failed to map repo")?;
26
-
27
-
CarStore::open(f).await.context("failed to open car store")
28
-
}
-159
src/storage/mod.rs
-159
src/storage/mod.rs
···
1
-
//! `ATProto` user repository datastore functionality.
2
-
3
-
pub(crate) mod car;
4
-
mod sqlite;
5
-
6
-
use anyhow::{Context as _, Result};
7
-
use atrium_repo::{
8
-
Cid, Repository,
9
-
blockstore::{AsyncBlockStoreRead, AsyncBlockStoreWrite},
10
-
};
11
-
use std::str::FromStr as _;
12
-
13
-
use crate::{Db, config::RepoConfig};
14
-
15
-
// Re-export public items
16
-
pub(crate) use car::open_car_store;
17
-
pub(crate) use sqlite::{SQLiteStore, open_sqlite_store};
18
-
19
-
/// Open a repository for a given DID.
20
-
pub(crate) async fn open_repo_db(
21
-
config: &RepoConfig,
22
-
db: &Db,
23
-
did: impl Into<String>,
24
-
) -> Result<Repository<impl AsyncBlockStoreRead + AsyncBlockStoreWrite>> {
25
-
let did = did.into();
26
-
let cid = sqlx::query_scalar!(
27
-
r#"
28
-
SELECT root FROM accounts
29
-
WHERE did = ?
30
-
"#,
31
-
did
32
-
)
33
-
.fetch_one(db)
34
-
.await
35
-
.context("failed to query database")?;
36
-
37
-
open_repo(
38
-
config,
39
-
did,
40
-
Cid::from_str(&cid).context("should be valid CID")?,
41
-
)
42
-
.await
43
-
}
44
-
45
-
/// Open a repository for a given DID and CID.
46
-
pub(crate) async fn open_repo(
47
-
config: &RepoConfig,
48
-
did: impl Into<String>,
49
-
cid: Cid,
50
-
) -> Result<Repository<impl AsyncBlockStoreRead + AsyncBlockStoreWrite>> {
51
-
let store = open_car_store(config, did.into()).await?;
52
-
Repository::open(store, cid)
53
-
.await
54
-
.context("failed to open repo")
55
-
}
56
-
/// Open a repository for a given DID and CID.
57
-
/// SQLite backend.
58
-
pub(crate) async fn open_repo_sqlite(
59
-
config: &RepoConfig,
60
-
did: impl Into<String>,
61
-
cid: Cid,
62
-
) -> Result<Repository<impl AsyncBlockStoreRead + AsyncBlockStoreWrite>> {
63
-
let store = open_sqlite_store(config, did.into()).await?;
64
-
return Repository::open(store, cid)
65
-
.await
66
-
.context("failed to open repo");
67
-
}
68
-
69
-
/// Open a block store for a given DID.
70
-
pub(crate) async fn open_store(
71
-
config: &RepoConfig,
72
-
did: impl Into<String>,
73
-
) -> Result<impl AsyncBlockStoreRead + AsyncBlockStoreWrite> {
74
-
let did = did.into();
75
-
76
-
// if config.use_sqlite {
77
-
return open_sqlite_store(config, did.clone()).await;
78
-
// }
79
-
// Default to CAR store
80
-
// open_car_store(config, &did).await
81
-
}
82
-
83
-
/// Create a storage backend for a DID
84
-
pub(crate) async fn create_storage_for_did(
85
-
config: &RepoConfig,
86
-
did_hash: &str,
87
-
) -> Result<impl AsyncBlockStoreRead + AsyncBlockStoreWrite> {
88
-
// Use standard file structure but change extension based on type
89
-
// if config.use_sqlite {
90
-
// For SQLite, create a new database file
91
-
let db_path = config.path.join(format!("{}.db", did_hash));
92
-
93
-
// Ensure parent directory exists
94
-
if let Some(parent) = db_path.parent() {
95
-
tokio::fs::create_dir_all(parent)
96
-
.await
97
-
.context("failed to create directory")?;
98
-
}
99
-
100
-
// Create SQLite store
101
-
let pool = sqlx::sqlite::SqlitePoolOptions::new()
102
-
.max_connections(5)
103
-
.connect_with(
104
-
sqlx::sqlite::SqliteConnectOptions::new()
105
-
.filename(&db_path)
106
-
.create_if_missing(true),
107
-
)
108
-
.await
109
-
.context("failed to connect to SQLite database")?;
110
-
111
-
// Initialize tables
112
-
_ = sqlx::query(
113
-
"
114
-
CREATE TABLE IF NOT EXISTS blocks (
115
-
cid TEXT PRIMARY KEY NOT NULL,
116
-
data BLOB NOT NULL,
117
-
multicodec INTEGER NOT NULL,
118
-
multihash INTEGER NOT NULL
119
-
);
120
-
CREATE TABLE IF NOT EXISTS tree_nodes (
121
-
repo_did TEXT NOT NULL,
122
-
key TEXT NOT NULL,
123
-
value_cid TEXT NOT NULL,
124
-
PRIMARY KEY (repo_did, key),
125
-
FOREIGN KEY (value_cid) REFERENCES blocks(cid)
126
-
);
127
-
CREATE INDEX IF NOT EXISTS idx_blocks_cid ON blocks(cid);
128
-
CREATE INDEX IF NOT EXISTS idx_tree_nodes_repo ON tree_nodes(repo_did);
129
-
PRAGMA journal_mode=WAL;
130
-
",
131
-
)
132
-
.execute(&pool)
133
-
.await
134
-
.context("failed to create tables")?;
135
-
136
-
Ok(SQLiteStore {
137
-
pool,
138
-
did: format!("did:plc:{}", did_hash),
139
-
})
140
-
// } else {
141
-
// // For CAR files, create a new file
142
-
// let file_path = config.path.join(format!("{}.car", did_hash));
143
-
144
-
// // Ensure parent directory exists
145
-
// if let Some(parent) = file_path.parent() {
146
-
// tokio::fs::create_dir_all(parent)
147
-
// .await
148
-
// .context("failed to create directory")?;
149
-
// }
150
-
151
-
// let file = tokio::fs::File::create_new(file_path)
152
-
// .await
153
-
// .context("failed to create repo file")?;
154
-
155
-
// CarStore::create(file)
156
-
// .await
157
-
// .context("failed to create carstore")
158
-
// }
159
-
}
-149
src/storage/sqlite.rs
-149
src/storage/sqlite.rs
···
1
-
//! SQLite-based repository storage implementation.
2
-
3
-
use anyhow::{Context as _, Result};
4
-
use atrium_repo::{
5
-
Cid, Multihash,
6
-
blockstore::{AsyncBlockStoreRead, AsyncBlockStoreWrite, Error as BlockstoreError},
7
-
};
8
-
use sha2::Digest;
9
-
use sqlx::SqlitePool;
10
-
11
-
use crate::config::RepoConfig;
12
-
13
-
/// SQLite-based implementation of block storage.
14
-
pub(crate) struct SQLiteStore {
15
-
pub did: String,
16
-
pub pool: SqlitePool,
17
-
}
18
-
19
-
impl AsyncBlockStoreRead for SQLiteStore {
20
-
async fn read_block(&mut self, cid: Cid) -> Result<Vec<u8>, BlockstoreError> {
21
-
let mut contents = Vec::new();
22
-
self.read_block_into(cid, &mut contents).await?;
23
-
Ok(contents)
24
-
}
25
-
async fn read_block_into(
26
-
&mut self,
27
-
cid: Cid,
28
-
contents: &mut Vec<u8>,
29
-
) -> Result<(), BlockstoreError> {
30
-
let cid_str = cid.to_string();
31
-
let record = sqlx::query!(r#"SELECT data FROM blocks WHERE cid = ?"#, cid_str)
32
-
.fetch_optional(&self.pool)
33
-
.await
34
-
.map_err(|e| BlockstoreError::Other(Box::new(e)))?
35
-
.ok_or(BlockstoreError::CidNotFound)?;
36
-
37
-
contents.clear();
38
-
contents.extend_from_slice(&record.data);
39
-
Ok(())
40
-
}
41
-
}
42
-
43
-
impl AsyncBlockStoreWrite for SQLiteStore {
44
-
async fn write_block(
45
-
&mut self,
46
-
codec: u64,
47
-
hash: u64,
48
-
contents: &[u8],
49
-
) -> Result<Cid, BlockstoreError> {
50
-
let digest = match hash {
51
-
atrium_repo::blockstore::SHA2_256 => sha2::Sha256::digest(&contents),
52
-
_ => return Err(BlockstoreError::UnsupportedHash(hash)),
53
-
};
54
-
55
-
let multihash = Multihash::wrap(hash, digest.as_slice())
56
-
.map_err(|_| BlockstoreError::UnsupportedHash(hash))?;
57
-
58
-
let cid = Cid::new_v1(codec, multihash);
59
-
let cid_str = cid.to_string();
60
-
61
-
// Use a transaction for atomicity
62
-
let mut tx = self
63
-
.pool
64
-
.begin()
65
-
.await
66
-
.map_err(|e| BlockstoreError::Other(Box::new(e)))?;
67
-
68
-
// Check if block already exists
69
-
let exists = sqlx::query_scalar!(r#"SELECT COUNT(*) FROM blocks WHERE cid = ?"#, cid_str)
70
-
.fetch_one(&mut *tx)
71
-
.await
72
-
.map_err(|e| BlockstoreError::Other(Box::new(e)))?;
73
-
74
-
// Only insert if block doesn't exist
75
-
let codec = codec as i64;
76
-
let hash = hash as i64;
77
-
if exists == 0 {
78
-
_ = sqlx::query!(
79
-
r#"INSERT INTO blocks (cid, data, multicodec, multihash) VALUES (?, ?, ?, ?)"#,
80
-
cid_str,
81
-
contents,
82
-
codec,
83
-
hash
84
-
)
85
-
.execute(&mut *tx)
86
-
.await
87
-
.map_err(|e| BlockstoreError::Other(Box::new(e)))?;
88
-
}
89
-
90
-
tx.commit()
91
-
.await
92
-
.map_err(|e| BlockstoreError::Other(Box::new(e)))?;
93
-
94
-
Ok(cid)
95
-
}
96
-
}
97
-
98
-
/// Open a SQLite store for the given DID.
99
-
pub(crate) async fn open_sqlite_store(
100
-
config: &RepoConfig,
101
-
did: impl Into<String>,
102
-
) -> Result<impl AsyncBlockStoreRead + AsyncBlockStoreWrite> {
103
-
tracing::info!("Opening SQLite store for DID");
104
-
let did_str = did.into();
105
-
106
-
// Extract the PLC ID from the DID
107
-
let id = did_str
108
-
.strip_prefix("did:plc:")
109
-
.context("DID in unknown format")?;
110
-
111
-
// Create database connection pool
112
-
let db_path = config.path.join(format!("{id}.db"));
113
-
114
-
let pool = sqlx::sqlite::SqlitePoolOptions::new()
115
-
.max_connections(5)
116
-
.connect_with(
117
-
sqlx::sqlite::SqliteConnectOptions::new()
118
-
.filename(&db_path)
119
-
.create_if_missing(true),
120
-
)
121
-
.await
122
-
.context("failed to connect to SQLite database")?;
123
-
124
-
// Ensure tables exist
125
-
_ = sqlx::query(
126
-
"
127
-
CREATE TABLE IF NOT EXISTS blocks (
128
-
cid TEXT PRIMARY KEY NOT NULL,
129
-
data BLOB NOT NULL,
130
-
multicodec INTEGER NOT NULL,
131
-
multihash INTEGER NOT NULL
132
-
);
133
-
CREATE TABLE IF NOT EXISTS tree_nodes (
134
-
repo_did TEXT NOT NULL,
135
-
key TEXT NOT NULL,
136
-
value_cid TEXT NOT NULL,
137
-
PRIMARY KEY (repo_did, key),
138
-
FOREIGN KEY (value_cid) REFERENCES blocks(cid)
139
-
);
140
-
CREATE INDEX IF NOT EXISTS idx_blocks_cid ON blocks(cid);
141
-
CREATE INDEX IF NOT EXISTS idx_tree_nodes_repo ON tree_nodes(repo_did);
142
-
",
143
-
)
144
-
.execute(&pool)
145
-
.await
146
-
.context("failed to create tables")?;
147
-
148
-
Ok(SQLiteStore { pool, did: did_str })
149
-
}