Cargo.lock (+2)
···
 "reqwest 0.12.15",
 "reqwest-middleware",
 "rsky-common",
+"rsky-identity",
 "rsky-lexicon",
 "rsky-pds",
 "rsky-repo",
···
 "tower-http",
 "tracing",
 "tracing-subscriber",
+"ubyte",
 "url",
 "urlencoding",
 "uuid 1.16.0",
Cargo.toml (+3 -1)
···
 # unstable-features = "allow"
 # # Temporary Allows
 dead_code = "allow"
-unused_imports = "allow"
+# unused_imports = "allow"

 [lints.clippy]
 # Groups
···
 rsky-pds = { git = "https://github.com/blacksky-algorithms/rsky.git" }
 rsky-common = { git = "https://github.com/blacksky-algorithms/rsky.git" }
 rsky-lexicon = { git = "https://github.com/blacksky-algorithms/rsky.git" }
+rsky-identity = { git = "https://github.com/blacksky-algorithms/rsky.git" }

 # async in streams
 # async-stream = "0.3"
···
 "sqlite",
 "tracing",
 ] }
+ubyte = "0.10.4"
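The new `ubyte` dependency backs the repo-import size limit added later in this diff (see `import_repo.rs` at the end). A minimal sketch of the pattern, assuming only the `ToByteUnit` usage visible there; `check_size` is a hypothetical helper, not part of the PR:

```rust
// Sketch of the ubyte usage that import_repo.rs below relies on.
use ubyte::ToByteUnit;

fn check_size(payload_len: usize) -> Result<(), String> {
    // 100 MB mirrors the IMPORT_REPO_LIMIT default used in the diff.
    let limit = 100.megabytes();
    // ubyte lets a raw length be compared against a ByteUnit directly,
    // as the diff does with `bytes.len() > max_import_size`.
    if payload_len > limit {
        return Err(format!("Content-Length is greater than maximum of {limit}"));
    }
    Ok(())
}
```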
README.md (+31 -118)
···
 \/_/
 ```

-This is an implementation of an ATProto PDS, built with [Axum](https://github.com/tokio-rs/axum) and [Atrium](https://github.com/sugyan/atrium).
-This PDS implementation uses a SQLite database to store private account information and file storage to store canonical user data.
+This is an implementation of an ATProto PDS, built with [Axum](https://github.com/tokio-rs/axum), [rsky](https://github.com/blacksky-algorithms/rsky/), and [Atrium](https://github.com/sugyan/atrium).
+This PDS implementation uses a SQLite database and the [diesel.rs](https://diesel.rs/) ORM to store canonical user data, and file system storage to store user blobs.

 Heavily inspired by David Buchanan's [millipds](https://github.com/DavidBuchanan314/millipds).
-This implementation forked from the [azure-rust-app](https://github.com/DrChat/azure-rust-app) starter template and the upstream [DrChat/bluepds](https://github.com/DrChat/bluepds).
-See TODO below for this fork's changes from upstream.
+This implementation forked from [DrChat/bluepds](https://github.com/DrChat/bluepds), and now makes heavy use of the [rsky-repo](https://github.com/blacksky-algorithms/rsky/tree/main/rsky-repo) repository implementation.
+The `actor_store` and `account_manager` modules have been reimplemented from [rsky-pds](https://github.com/blacksky-algorithms/rsky/tree/main/rsky-pds) to use a SQLite backend and file storage; they are themselves adapted from the [original Bluesky implementation](https://github.com/bluesky-social/atproto), which uses SQLite in TypeScript.
+

 If you want to see this fork in action, there is a live account hosted by this PDS at [@teq.shatteredsky.net](https://bsky.app/profile/teq.shatteredsky.net)!

 > [!WARNING]
-> This PDS is undergoing heavy development. Do _NOT_ use this to host your primary account or any important data!
+> This PDS is undergoing heavy development, and this branch is not at an operable release. Do _NOT_ use this to host your primary account or any important data!

 ## Quick Start
 ```
···
 - Size: 47 GB
 - VPUs/GB: 10

-This is about half of the 3,000 OCPU hours and 18,000 GB hours available per month for free on the VM.Standard.A1.Flex shape. This is _without_ optimizing for costs. The PDS can likely be made much cheaper.
-
-## Code map
-```
-* migrations/ - SQLite database migrations
-* src/
-  * endpoints/ - ATProto API endpoints
-  * auth.rs - Authentication primitives
-  * config.rs - Application configuration
-  * did.rs - Decentralized Identifier helpers
-  * error.rs - Axum error helpers
-  * firehose.rs - ATProto firehose producer
-  * main.rs - Main entrypoint
-  * metrics.rs - Definitions for telemetry instruments
-  * oauth.rs - OAuth routes
-  * plc.rs - Functionality to access the Public Ledger of Credentials
-  * storage.rs - Helpers to access user repository storage
-```
+This is about half of the 3,000 OCPU hours and 18,000 GB hours available per month for free on the VM.Standard.A1.Flex shape. This is _without_ optimizing for costs. The PDS can likely be made to run on far fewer resources.

 ## To-do
-### Teq's fork
-- [ ] OAuth
-  - [X] `/.well-known/oauth-protected-resource` - Authorization Server Metadata
-  - [X] `/.well-known/oauth-authorization-server`
-  - [X] `/par` - Pushed Authorization Request
-  - [X] `/client-metadata.json` - Client metadata discovery
-  - [X] `/oauth/authorize`
-  - [X] `/oauth/authorize/sign-in`
-  - [X] `/oauth/token`
-  - [ ] Authorization flow - Backend client
-  - [X] Authorization flow - Serverless browser app
-  - [ ] DPoP-Nonce
-  - [ ] Verify JWT signature with JWK
-- [ ] Email verification
-- [ ] 2FA
-- [ ] Admin endpoints
-- [ ] App passwords
-- [X] `listRecords` fixes
-  - [X] Fix collection prefixing (terminate with `/`)
-  - [X] Fix cursor handling (return `cid` instead of `key`)
-- [X] Session management (JWT)
-  - [X] Match token fields to reference implementation
-  - [X] RefreshSession from Bluesky Client
-  - [X] Respond with JSON error message `ExpiredToken`
-- [X] Cursor handling
-  - [X] Implement time-based unix microsecond sequences
-  - [X] Startup with present cursor
-- [X] Respond `RecordNotFound`, required for:
-  - [X] app.bsky.feed.postgate
-  - [X] app.bsky.feed.threadgate
-  - [ ] app.bsky... (profile creation?)
-- [X] Linting
-  - [X] Rustfmt
-  - [X] warnings
-    - [X] deprecated-safe
-    - [X] future-incompatible
-    - [X] keyword-idents
-    - [X] let-underscore
-    - [X] nonstandard-style
-    - [X] refining-impl-trait
-    - [X] rust-2018-idioms
-    - [X] rust-2018/2021/2024-compatibility
-    - [X] ungrouped
-  - [X] Clippy
-    - [X] nursery
-    - [X] correctness
-    - [X] suspicious
-    - [X] complexity
-    - [X] perf
-    - [X] style
-    - [X] pedantic
-    - [X] cargo
-    - [X] ungrouped
-
-### High-level features
-- [ ] Storage backend abstractions
-  - [ ] Azure blob storage backend
-  - [ ] Backblaze b2(?)
-- [ ] Telemetry
-  - [X] [Metrics](https://github.com/metrics-rs/metrics) (counters/gauges/etc)
-  - [X] Exporters for common backends (Prometheus/etc)
-
 ### APIs
-- [X] [Service proxying](https://atproto.com/specs/xrpc#service-proxying)
-- [X] UG /xrpc/_health (undocumented, but impl by reference PDS)
+- [ ] [Service proxying](https://atproto.com/specs/xrpc#service-proxying)
+- [ ] UG /xrpc/_health (undocumented, but impl by reference PDS)
 <!-- - [ ] xx /xrpc/app.bsky.notification.registerPush
 - app.bsky.actor
-  - [X] AG /xrpc/app.bsky.actor.getPreferences
+  - [ ] AG /xrpc/app.bsky.actor.getPreferences
   - [ ] xx /xrpc/app.bsky.actor.getProfile
   - [ ] xx /xrpc/app.bsky.actor.getProfiles
-  - [X] AP /xrpc/app.bsky.actor.putPreferences
+  - [ ] AP /xrpc/app.bsky.actor.putPreferences
 - app.bsky.feed
   - [ ] xx /xrpc/app.bsky.feed.getActorLikes
   - [ ] xx /xrpc/app.bsky.feed.getAuthorFeed
···
 - com.atproto.identity
   - [ ] xx /xrpc/com.atproto.identity.getRecommendedDidCredentials
   - [ ] AP /xrpc/com.atproto.identity.requestPlcOperationSignature
-  - [X] UG /xrpc/com.atproto.identity.resolveHandle
+  - [ ] UG /xrpc/com.atproto.identity.resolveHandle
   - [ ] AP /xrpc/com.atproto.identity.signPlcOperation
   - [ ] xx /xrpc/com.atproto.identity.submitPlcOperation
-  - [X] AP /xrpc/com.atproto.identity.updateHandle
+  - [ ] AP /xrpc/com.atproto.identity.updateHandle
 <!-- - com.atproto.moderation
   - [ ] xx /xrpc/com.atproto.moderation.createReport -->
 - com.atproto.repo
···
   - [X] AP /xrpc/com.atproto.repo.deleteRecord
   - [X] UG /xrpc/com.atproto.repo.describeRepo
   - [X] UG /xrpc/com.atproto.repo.getRecord
-  - [ ] xx /xrpc/com.atproto.repo.importRepo
-  - [ ] xx /xrpc/com.atproto.repo.listMissingBlobs
+  - [X] xx /xrpc/com.atproto.repo.importRepo
+  - [X] xx /xrpc/com.atproto.repo.listMissingBlobs
   - [X] UG /xrpc/com.atproto.repo.listRecords
   - [X] AP /xrpc/com.atproto.repo.putRecord
   - [X] AP /xrpc/com.atproto.repo.uploadBlob
···
   - [ ] xx /xrpc/com.atproto.server.activateAccount
   - [ ] xx /xrpc/com.atproto.server.checkAccountStatus
   - [ ] xx /xrpc/com.atproto.server.confirmEmail
-  - [X] UP /xrpc/com.atproto.server.createAccount
+  - [ ] UP /xrpc/com.atproto.server.createAccount
   - [ ] xx /xrpc/com.atproto.server.createAppPassword
-  - [X] AP /xrpc/com.atproto.server.createInviteCode
+  - [ ] AP /xrpc/com.atproto.server.createInviteCode
   - [ ] xx /xrpc/com.atproto.server.createInviteCodes
-  - [X] UP /xrpc/com.atproto.server.createSession
+  - [ ] UP /xrpc/com.atproto.server.createSession
   - [ ] xx /xrpc/com.atproto.server.deactivateAccount
   - [ ] xx /xrpc/com.atproto.server.deleteAccount
   - [ ] xx /xrpc/com.atproto.server.deleteSession
-  - [X] UG /xrpc/com.atproto.server.describeServer
+  - [ ] UG /xrpc/com.atproto.server.describeServer
   - [ ] xx /xrpc/com.atproto.server.getAccountInviteCodes
-  - [X] AG /xrpc/com.atproto.server.getServiceAuth
-  - [X] AG /xrpc/com.atproto.server.getSession
+  - [ ] AG /xrpc/com.atproto.server.getServiceAuth
+  - [ ] AG /xrpc/com.atproto.server.getSession
   - [ ] xx /xrpc/com.atproto.server.listAppPasswords
   - [ ] xx /xrpc/com.atproto.server.refreshSession
   - [ ] xx /xrpc/com.atproto.server.requestAccountDelete
···
   - [ ] xx /xrpc/com.atproto.server.revokeAppPassword
   - [ ] xx /xrpc/com.atproto.server.updateEmail
 - com.atproto.sync
-  - [X] UG /xrpc/com.atproto.sync.getBlob
-  - [X] UG /xrpc/com.atproto.sync.getBlocks
-  - [X] UG /xrpc/com.atproto.sync.getLatestCommit
-  - [X] UG /xrpc/com.atproto.sync.getRecord
-  - [X] UG /xrpc/com.atproto.sync.getRepo
-  - [X] UG /xrpc/com.atproto.sync.getRepoStatus
-  - [X] UG /xrpc/com.atproto.sync.listBlobs
-  - [X] UG /xrpc/com.atproto.sync.listRepos
-  - [X] UG /xrpc/com.atproto.sync.subscribeRepos
+  - [ ] UG /xrpc/com.atproto.sync.getBlob
+  - [ ] UG /xrpc/com.atproto.sync.getBlocks
+  - [ ] UG /xrpc/com.atproto.sync.getLatestCommit
+  - [ ] UG /xrpc/com.atproto.sync.getRecord
+  - [ ] UG /xrpc/com.atproto.sync.getRepo
+  - [ ] UG /xrpc/com.atproto.sync.getRepoStatus
+  - [ ] UG /xrpc/com.atproto.sync.listBlobs
+  - [ ] UG /xrpc/com.atproto.sync.listRepos
+  - [ ] UG /xrpc/com.atproto.sync.subscribeRepos

-## Quick Deployment (Azure CLI)
-```
-az group create --name "webapp" --location southcentralus
-az deployment group create --resource-group "webapp" --template-file .\deployment.bicep --parameters webAppName=testapp
-
-az acr login --name <insert name of ACR resource here>
-docker build -t <ACR>.azurecr.io/testapp:latest .
-docker push <ACR>.azurecr.io/testapp:latest
-```
-## Quick Deployment (NixOS)
+## Deployment (NixOS)
 ```nix
 {
   inputs = {
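The README changes above describe the new storage split: canonical user data in per-actor SQLite databases via diesel, blobs on the file system. An illustrative sketch of the on-disk layout implied by paths that appear later in this diff (`account.rs`, `account_manager/mod.rs`, `blob_fs.rs`); these helpers are hypothetical, not part of the PR:

```rust
// Illustrative only: the on-disk layout implied by this diff.
use std::path::PathBuf;

/// Canonical user data: one SQLite repo database per actor.
fn repo_db_path(did_plc_suffix: &str) -> PathBuf {
    PathBuf::from(format!("data/repo/{did_plc_suffix}.db"))
}

/// User blobs: a per-actor directory tree under the blob base directory.
fn blob_dir(did: &str) -> PathBuf {
    PathBuf::from("data/blob").join(did)
}
```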
migrations/2025-05-15-182818_init_diff/down.sql (-12)
···
--- This file should undo anything in `up.sql`
-DROP TABLE IF EXISTS `oauth_refresh_tokens`;
 DROP TABLE IF EXISTS `repo_seq`;
-DROP TABLE IF EXISTS `blob`;
-DROP TABLE IF EXISTS `oauth_used_jtis`;
 DROP TABLE IF EXISTS `app_password`;
-DROP TABLE IF EXISTS `repo_block`;
 DROP TABLE IF EXISTS `device_account`;
-DROP TABLE IF EXISTS `backlink`;
 DROP TABLE IF EXISTS `actor`;
 DROP TABLE IF EXISTS `device`;
 DROP TABLE IF EXISTS `did_doc`;
 DROP TABLE IF EXISTS `email_token`;
 DROP TABLE IF EXISTS `invite_code`;
-DROP TABLE IF EXISTS `oauth_par_requests`;
-DROP TABLE IF EXISTS `record`;
-DROP TABLE IF EXISTS `repo_root`;
 DROP TABLE IF EXISTS `used_refresh_token`;
 DROP TABLE IF EXISTS `invite_code_use`;
-DROP TABLE IF EXISTS `oauth_authorization_codes`;
 DROP TABLE IF EXISTS `authorization_request`;
 DROP TABLE IF EXISTS `token`;
 DROP TABLE IF EXISTS `refresh_token`;
-DROP TABLE IF EXISTS `account_pref`;
-DROP TABLE IF EXISTS `record_blob`;
 DROP TABLE IF EXISTS `account`;
migrations/2025-05-15-182818_init_diff/up.sql (-108)
···
-CREATE TABLE `oauth_refresh_tokens`(
-    `token` VARCHAR NOT NULL PRIMARY KEY,
-    `client_id` VARCHAR NOT NULL,
-    `subject` VARCHAR NOT NULL,
-    `dpop_thumbprint` VARCHAR NOT NULL,
-    `scope` VARCHAR,
-    `created_at` INT8 NOT NULL,
-    `expires_at` INT8 NOT NULL,
-    `revoked` BOOL NOT NULL
-);
-
 CREATE TABLE `repo_seq`(
     `seq` INT8 NOT NULL PRIMARY KEY,
     `did` VARCHAR NOT NULL,
···
     `sequencedat` VARCHAR NOT NULL
 );

-CREATE TABLE `blob`(
-    `cid` VARCHAR NOT NULL,
-    `did` VARCHAR NOT NULL,
-    `mimetype` VARCHAR NOT NULL,
-    `size` INT4 NOT NULL,
-    `tempkey` VARCHAR,
-    `width` INT4,
-    `height` INT4,
-    `createdat` VARCHAR NOT NULL,
-    `takedownref` VARCHAR,
-    PRIMARY KEY(`cid`, `did`)
-);
-
-CREATE TABLE `oauth_used_jtis`(
-    `jti` VARCHAR NOT NULL PRIMARY KEY,
-    `issuer` VARCHAR NOT NULL,
-    `created_at` INT8 NOT NULL,
-    `expires_at` INT8 NOT NULL
-);
-
 CREATE TABLE `app_password`(
     `did` VARCHAR NOT NULL,
     `name` VARCHAR NOT NULL,
···
     PRIMARY KEY(`did`, `name`)
 );

-CREATE TABLE `repo_block`(
-    `cid` VARCHAR NOT NULL,
-    `did` VARCHAR NOT NULL,
-    `reporev` VARCHAR NOT NULL,
-    `size` INT4 NOT NULL,
-    `content` BYTEA NOT NULL,
-    PRIMARY KEY(`cid`, `did`)
-);
-
 CREATE TABLE `device_account`(
     `did` VARCHAR NOT NULL,
     `deviceid` VARCHAR NOT NULL,
···
     `remember` BOOL NOT NULL,
     `authorizedclients` VARCHAR NOT NULL,
     PRIMARY KEY(`deviceId`, `did`)
-);
-
-CREATE TABLE `backlink`(
-    `uri` VARCHAR NOT NULL,
-    `path` VARCHAR NOT NULL,
-    `linkto` VARCHAR NOT NULL,
-    PRIMARY KEY(`uri`, `path`)
 );

 CREATE TABLE `actor`(
···
     `createdat` VARCHAR NOT NULL
 );

-CREATE TABLE `oauth_par_requests`(
-    `request_uri` VARCHAR NOT NULL PRIMARY KEY,
-    `client_id` VARCHAR NOT NULL,
-    `response_type` VARCHAR NOT NULL,
-    `code_challenge` VARCHAR NOT NULL,
-    `code_challenge_method` VARCHAR NOT NULL,
-    `state` VARCHAR,
-    `login_hint` VARCHAR,
-    `scope` VARCHAR,
-    `redirect_uri` VARCHAR,
-    `response_mode` VARCHAR,
-    `display` VARCHAR,
-    `created_at` INT8 NOT NULL,
-    `expires_at` INT8 NOT NULL
-);
-
-CREATE TABLE `record`(
-    `uri` VARCHAR NOT NULL PRIMARY KEY,
-    `cid` VARCHAR NOT NULL,
-    `did` VARCHAR NOT NULL,
-    `collection` VARCHAR NOT NULL,
-    `rkey` VARCHAR NOT NULL,
-    `reporev` VARCHAR,
-    `indexedat` VARCHAR NOT NULL,
-    `takedownref` VARCHAR
-);
-
-CREATE TABLE `repo_root`(
-    `did` VARCHAR NOT NULL PRIMARY KEY,
-    `cid` VARCHAR NOT NULL,
-    `rev` VARCHAR NOT NULL,
-    `indexedat` VARCHAR NOT NULL
-);
-
 CREATE TABLE `used_refresh_token`(
     `refreshtoken` VARCHAR NOT NULL PRIMARY KEY,
     `tokenid` VARCHAR NOT NULL
···
     PRIMARY KEY(`code`, `usedBy`)
 );

-CREATE TABLE `oauth_authorization_codes`(
-    `code` VARCHAR NOT NULL PRIMARY KEY,
-    `client_id` VARCHAR NOT NULL,
-    `subject` VARCHAR NOT NULL,
-    `code_challenge` VARCHAR NOT NULL,
-    `code_challenge_method` VARCHAR NOT NULL,
-    `redirect_uri` VARCHAR NOT NULL,
-    `scope` VARCHAR,
-    `created_at` INT8 NOT NULL,
-    `expires_at` INT8 NOT NULL,
-    `used` BOOL NOT NULL
-);
-
 CREATE TABLE `authorization_request`(
     `id` VARCHAR NOT NULL PRIMARY KEY,
     `did` VARCHAR,
···
     `expiresat` VARCHAR NOT NULL,
     `nextid` VARCHAR,
     `apppasswordname` VARCHAR
-);
-
-CREATE TABLE `account_pref`(
-    `id` INT4 NOT NULL PRIMARY KEY,
-    `did` VARCHAR NOT NULL,
-    `name` VARCHAR NOT NULL,
-    `valuejson` TEXT
-);
-
-CREATE TABLE `record_blob`(
-    `blobcid` VARCHAR NOT NULL,
-    `recorduri` VARCHAR NOT NULL,
-    `did` VARCHAR NOT NULL,
-    PRIMARY KEY(`blobCid`, `recordUri`)
 );

 CREATE TABLE `account`(
migrations/2025-05-17-094600_oauth_temp/down.sql (+4)
migrations/2025-05-17-094600_oauth_temp/up.sql (+46)
···
+CREATE TABLE `oauth_refresh_tokens`(
+    `token` VARCHAR NOT NULL PRIMARY KEY,
+    `client_id` VARCHAR NOT NULL,
+    `subject` VARCHAR NOT NULL,
+    `dpop_thumbprint` VARCHAR NOT NULL,
+    `scope` VARCHAR,
+    `created_at` INT8 NOT NULL,
+    `expires_at` INT8 NOT NULL,
+    `revoked` BOOL NOT NULL
+);
+
+CREATE TABLE `oauth_used_jtis`(
+    `jti` VARCHAR NOT NULL PRIMARY KEY,
+    `issuer` VARCHAR NOT NULL,
+    `created_at` INT8 NOT NULL,
+    `expires_at` INT8 NOT NULL
+);
+
+CREATE TABLE `oauth_par_requests`(
+    `request_uri` VARCHAR NOT NULL PRIMARY KEY,
+    `client_id` VARCHAR NOT NULL,
+    `response_type` VARCHAR NOT NULL,
+    `code_challenge` VARCHAR NOT NULL,
+    `code_challenge_method` VARCHAR NOT NULL,
+    `state` VARCHAR,
+    `login_hint` VARCHAR,
+    `scope` VARCHAR,
+    `redirect_uri` VARCHAR,
+    `response_mode` VARCHAR,
+    `display` VARCHAR,
+    `created_at` INT8 NOT NULL,
+    `expires_at` INT8 NOT NULL
+);
+
+CREATE TABLE `oauth_authorization_codes`(
+    `code` VARCHAR NOT NULL PRIMARY KEY,
+    `client_id` VARCHAR NOT NULL,
+    `subject` VARCHAR NOT NULL,
+    `code_challenge` VARCHAR NOT NULL,
+    `code_challenge_method` VARCHAR NOT NULL,
+    `redirect_uri` VARCHAR NOT NULL,
+    `scope` VARCHAR,
+    `created_at` INT8 NOT NULL,
+    `expires_at` INT8 NOT NULL,
+    `used` BOOL NOT NULL
+);
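These four OAuth tables move out of the init migration (see the deletions above) into their own temporary migration. As a hypothetical sketch, the diesel `table!` item mirroring the first of them would look like this under diesel's usual SQLite type mapping (VARCHAR to Text, INT8 to BigInt, BOOL to Bool); the diff does not show the project's schema module, so the exact location and naming are assumptions:

```rust
// Hypothetical diesel schema entry for the oauth_refresh_tokens table above.
diesel::table! {
    oauth_refresh_tokens (token) {
        token -> Text,
        client_id -> Text,
        subject -> Text,
        dpop_thumbprint -> Text,
        scope -> Nullable<Text>,
        created_at -> BigInt,
        expires_at -> BigInt,
        revoked -> Bool,
    }
}
```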
src/account_manager/helpers/account.rs (+17 -6)
···
 use thiserror::Error;

 use diesel::dsl::{LeftJoinOn, exists, not};
-use diesel::helper_types::{Eq, IntoBoxed};
+use diesel::helper_types::Eq;

 #[derive(Error, Debug)]
 pub enum AccountHelperError {
···
     })
     .await
     .expect("Failed to delete actor")?;
-    let did = did.to_owned();
+    let did_clone = did.to_owned();
     _ = db
         .get()
         .await?
         .interact(move |conn| {
             _ = delete(EmailTokenSchema::email_token)
-                .filter(EmailTokenSchema::did.eq(&did))
+                .filter(EmailTokenSchema::did.eq(&did_clone))
                 .execute(conn)?;
             _ = delete(RefreshTokenSchema::refresh_token)
-                .filter(RefreshTokenSchema::did.eq(&did))
+                .filter(RefreshTokenSchema::did.eq(&did_clone))
                 .execute(conn)?;
             _ = delete(AccountSchema::account)
-                .filter(AccountSchema::did.eq(&did))
+                .filter(AccountSchema::did.eq(&did_clone))
                 .execute(conn)?;
             delete(ActorSchema::actor)
-                .filter(ActorSchema::did.eq(&did))
+                .filter(ActorSchema::did.eq(&did_clone))
                 .execute(conn)
         })
         .await
         .expect("Failed to delete account")?;
+
+    let data_repo_file = format!("data/repo/{}.db", did.to_owned());
+    let data_blob_path = format!("data/blob/{}", did);
+    let data_blob_path = std::path::Path::new(&data_blob_path);
+    let data_repo_file = std::path::Path::new(&data_repo_file);
+    if data_repo_file.exists() {
+        std::fs::remove_file(data_repo_file)?;
+    };
+    if data_blob_path.exists() {
+        std::fs::remove_dir_all(data_blob_path)?;
+    };
     Ok(())
 }

src/account_manager/mod.rs (+18 -15)
···
 //! blacksky-algorithms/rsky is licensed under the Apache License 2.0
 //!
 //! Modified for SQLite backend
-use crate::ActorPools;
 use crate::account_manager::helpers::account::{
     AccountStatus, ActorAccount, AvailabilityFlags, GetAccountAdminStatusOutput,
 };
···
 use crate::account_manager::helpers::invite::CodeDetail;
 use crate::account_manager::helpers::password::UpdateUserPasswordOpts;
 use crate::models::pds::EmailTokenPurpose;
+use crate::serve::ActorStorage;
 use anyhow::Result;
-use axum::extract::FromRef;
 use chrono::DateTime;
 use chrono::offset::Utc as UtcOffset;
 use cidv10::Cid;
-use deadpool_diesel::sqlite::Pool;
 use diesel::*;
 use futures::try_join;
 use helpers::{account, auth, email_token, invite, password, repo};
···
     pub async fn create_account(
         &self,
         opts: CreateAccountOpts,
-        actor_pools: &mut std::collections::HashMap<String, ActorPools>,
+        actor_pools: &mut std::collections::HashMap<String, ActorStorage>,
     ) -> Result<(String, String)> {
         let CreateAccountOpts {
             did,
···
         let did_path = did
             .strip_prefix("did:plc:")
             .ok_or_else(|| anyhow::anyhow!("Invalid DID"))?;
-        let path_repo = format!("sqlite://{}", did_path);
+        let repo_path = format!("sqlite://data/repo/{}.db", did_path);
         let actor_repo_pool =
-            crate::establish_pool(path_repo.as_str()).expect("Failed to establish pool");
-        let path_blob = path_repo.replace("repo", "blob");
-        let actor_blob_pool = crate::establish_pool(&path_blob).expect("Failed to establish pool");
-        let actor_pool = ActorPools {
+            crate::db::establish_pool(repo_path.as_str()).expect("Failed to establish pool");
+        let blob_path = std::path::Path::new("data/blob").to_path_buf();
+        let actor_pool = ActorStorage {
             repo: actor_repo_pool,
-            blob: actor_blob_pool,
+            blob: blob_path.clone(),
         };
-        actor_pools
-            .insert(did.clone(), actor_pool)
-            .expect("Failed to insert actor pools");
+        let blob_path = blob_path.join(did_path);
+        tokio::fs::create_dir_all(&blob_path)
+            .await
+            .map_err(|_| anyhow::anyhow!("Failed to create blob path"))?;
+        drop(
+            actor_pools
+                .insert(did.clone(), actor_pool)
+                .expect("Failed to insert actor pools"),
+        );
         let db = actor_pools
             .get(&did)
             .ok_or_else(|| anyhow::anyhow!("Actor not found"))?
···
         did: String,
         cid: Cid,
         rev: String,
-        actor_pools: &std::collections::HashMap<String, ActorPools>,
+        actor_pools: &std::collections::HashMap<String, ActorStorage>,
     ) -> Result<()> {
         let db = actor_pools
             .get(&did)
···
     pub async fn delete_account(
         &self,
         did: &str,
-        actor_pools: &std::collections::HashMap<String, ActorPools>,
+        actor_pools: &std::collections::HashMap<String, ActorStorage>,
     ) -> Result<()> {
         let db = actor_pools
             .get(did)
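The diff refers to `crate::serve::ActorStorage` throughout but never shows its definition. A plausible shape, inferred only from how it is used above (`repo` is cloned as a deadpool-diesel SQLite pool, `blob` is joined as a filesystem path, and the struct itself is cloned); this is an assumption, not copied from the PR:

```rust
// Inferred sketch of the per-actor storage handle used across this diff.
#[derive(Clone)]
pub struct ActorStorage {
    /// Connection pool for the actor's own SQLite repo database.
    pub repo: deadpool_diesel::sqlite::Pool,
    /// Base directory for the actor's blob storage on disk.
    pub blob: std::path::PathBuf,
}
```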
src/actor_endpoints.rs (+20 -8)
···
 /// We shouldn't have to know about any bsky endpoints to store private user data.
 /// This will _very likely_ be changed in the future.
 use atrium_api::app::bsky::actor;
-use axum::{Json, routing::post};
+use axum::{
+    Json, Router,
+    extract::State,
+    routing::{get, post},
+};
 use constcat::concat;
-use diesel::prelude::*;

-use crate::actor_store::ActorStore;
+use crate::auth::AuthenticatedUser;

-use super::*;
+use super::serve::*;

 async fn put_preferences(
     user: AuthenticatedUser,
-    State(actor_pools): State<std::collections::HashMap<String, ActorPools>>,
+    State(actor_pools): State<std::collections::HashMap<String, ActorStorage>>,
     Json(input): Json<actor::put_preferences::Input>,
 ) -> Result<()> {
     let did = user.did();
-    let json_string =
-        serde_json::to_string(&input.preferences).context("failed to serialize preferences")?;
+    // let json_string =
+    //     serde_json::to_string(&input.preferences).context("failed to serialize preferences")?;

     // let conn = &mut actor_pools
     //     .get(&did)
···
     //         .context("failed to update user preferences")
     //     });
     todo!("Use actor_store's preferences writer instead");
+    // let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await;
+    // let values = actor::defs::Preferences {
+    //     private_prefs: Some(json_string),
+    //     ..Default::default()
+    // };
+    // let namespace = actor::defs::PreferencesNamespace::Private;
+    // let scope = actor::defs::PreferencesScope::User;
+    // actor_store.pref.put_preferences(values, namespace, scope);
+
     Ok(())
 }

 async fn get_preferences(
     user: AuthenticatedUser,
-    State(actor_pools): State<std::collections::HashMap<String, ActorPools>>,
+    State(actor_pools): State<std::collections::HashMap<String, ActorStorage>>,
 ) -> Result<Json<actor::get_preferences::Output>> {
     let did = user.did();
     // let conn = &mut actor_pools
src/actor_store/blob.rs (+9 -7)
···

 use crate::models::actor_store as models;
 use anyhow::{Result, bail};
+use axum::body::Bytes;
 use cidv10::Cid;
 use diesel::dsl::{count_distinct, exists, not};
 use diesel::sql_types::{Integer, Nullable, Text};
···
 use rsky_repo::types::{PreparedBlobRef, PreparedWrite};
 use std::str::FromStr as _;

-use super::sql_blob::{BlobStoreSql, ByteStream};
+use super::blob_fs::{BlobStoreFs, ByteStream};

 pub struct GetBlobOutput {
     pub size: i32,
···
 /// Handles blob operations for an actor store
 pub struct BlobReader {
     /// SQL-based blob storage
-    pub blobstore: BlobStoreSql,
+    pub blobstore: BlobStoreFs,
     /// DID of the actor
     pub did: String,
     /// Database connection
···
 impl BlobReader {
     /// Create a new blob reader
     pub fn new(
-        blobstore: BlobStoreSql,
+        blobstore: BlobStoreFs,
         db: deadpool_diesel::Pool<
             deadpool_diesel::Manager<SqliteConnection>,
             deadpool_diesel::sqlite::Object,
···
     pub async fn upload_blob_and_get_metadata(
         &self,
         user_suggested_mime: String,
-        blob: Vec<u8>,
+        blob: Bytes,
     ) -> Result<BlobMetadata> {
         let bytes = blob;
         let size = bytes.len() as i64;

         let (temp_key, sha256, img_info, sniffed_mime) = try_join!(
             self.blobstore.put_temp(bytes.clone()),
-            sha256_stream(bytes.clone()),
-            image::maybe_get_info(bytes.clone()),
-            image::mime_type_from_bytes(bytes.clone())
+            // TODO: reimpl funcs to use Bytes instead of Vec<u8>
+            sha256_stream(bytes.to_vec()),
+            image::maybe_get_info(bytes.to_vec()),
+            image::mime_type_from_bytes(bytes.to_vec())
         )?;

         let cid = sha256_raw_to_cid(sha256);
src/actor_store/blob_fs.rs (+287, new file)
···
+//! File system implementation of blob storage
+//! Based on the S3 implementation but using local file system instead
+use anyhow::Result;
+use axum::body::Bytes;
+use cidv10::Cid;
+use rsky_common::get_random_str;
+use rsky_repo::error::BlobError;
+use std::path::PathBuf;
+use std::str::FromStr;
+use tokio::fs as async_fs;
+use tokio::io::AsyncWriteExt;
+use tracing::{debug, error, warn};
+
+/// ByteStream implementation for blob data
+pub struct ByteStream {
+    pub bytes: Bytes,
+}
+
+impl ByteStream {
+    /// Create a new ByteStream with the given bytes
+    pub const fn new(bytes: Bytes) -> Self {
+        Self { bytes }
+    }
+
+    /// Collect the bytes from the stream
+    pub async fn collect(self) -> Result<Bytes> {
+        Ok(self.bytes)
+    }
+}
+
+/// Path information for moving a blob
+struct MoveObject {
+    from: PathBuf,
+    to: PathBuf,
+}
+
+/// File system implementation of blob storage
+pub struct BlobStoreFs {
+    /// Base directory for storing blobs
+    pub base_dir: PathBuf,
+    /// DID of the actor
+    pub did: String,
+}
+
+impl BlobStoreFs {
+    /// Create a new file system blob store for the given DID and base directory
+    pub const fn new(did: String, base_dir: PathBuf) -> Self {
+        Self { base_dir, did }
+    }
+
+    /// Create a factory function for blob stores
+    pub fn creator(base_dir: PathBuf) -> Box<dyn Fn(String) -> Self> {
+        let base_dir_clone = base_dir;
+        Box::new(move |did: String| Self::new(did, base_dir_clone.clone()))
+    }
+
+    /// Generate a random key for temporary storage
+    fn gen_key(&self) -> String {
+        get_random_str()
+    }
+
+    /// Get path to the temporary blob storage
+    fn get_tmp_path(&self, key: &str) -> PathBuf {
+        self.base_dir.join("tmp").join(&self.did).join(key)
+    }
+
+    /// Get path to the stored blob with appropriate sharding
+    fn get_stored_path(&self, cid: Cid) -> PathBuf {
+        let cid_str = cid.to_string();
+
+        // Create two-level sharded structure based on CID
+        // First 10 chars for level 1, next 10 chars for level 2
+        let first_level = if cid_str.len() >= 10 {
+            &cid_str[0..10]
+        } else {
+            "short"
+        };
+
+        let second_level = if cid_str.len() >= 20 {
+            &cid_str[10..20]
+        } else {
+            "short"
+        };
+
+        self.base_dir
+            .join("blocks")
+            .join(&self.did)
+            .join(first_level)
+            .join(second_level)
+            .join(&cid_str)
+    }
+
+    /// Get path to the quarantined blob
+    fn get_quarantined_path(&self, cid: Cid) -> PathBuf {
+        let cid_str = cid.to_string();
+        self.base_dir
+            .join("quarantine")
+            .join(&self.did)
+            .join(&cid_str)
+    }
+
+    /// Store a blob temporarily
+    pub async fn put_temp(&self, bytes: Bytes) -> Result<String> {
+        let key = self.gen_key();
+        let temp_path = self.get_tmp_path(&key);
+
+        // Ensure the directory exists
+        if let Some(parent) = temp_path.parent() {
+            async_fs::create_dir_all(parent).await?;
+        }
+
+        // Write the temporary blob
+        let mut file = async_fs::File::create(&temp_path).await?;
+        file.write_all(&bytes).await?;
+        file.flush().await?;
+
+        debug!("Stored temp blob at: {:?}", temp_path);
+        Ok(key)
+    }
+
+    /// Make a temporary blob permanent by moving it to the blob store
+    pub async fn make_permanent(&self, key: String, cid: Cid) -> Result<()> {
+        let already_has = self.has_stored(cid).await?;
+
+        if !already_has {
+            // Move the temporary blob to permanent storage
+            self.move_object(MoveObject {
+                from: self.get_tmp_path(&key),
+                to: self.get_stored_path(cid),
+            })
+            .await?;
+            debug!("Moved temp blob to permanent: {} -> {}", key, cid);
+        } else {
+            // Already saved, so just delete the temp
+            let temp_path = self.get_tmp_path(&key);
+            if temp_path.exists() {
+                async_fs::remove_file(temp_path).await?;
+                debug!("Deleted temp blob as permanent already exists: {}", key);
+            }
+        }
+
+        Ok(())
+    }
+
+    /// Store a blob directly as permanent
+    pub async fn put_permanent(&self, cid: Cid, bytes: Bytes) -> Result<()> {
+        let target_path = self.get_stored_path(cid);
+
+        // Ensure the directory exists
+        if let Some(parent) = target_path.parent() {
+            async_fs::create_dir_all(parent).await?;
+        }
+
+        // Write the blob
+        let mut file = async_fs::File::create(&target_path).await?;
+        file.write_all(&bytes).await?;
+        file.flush().await?;
+
+        debug!("Stored permanent blob: {}", cid);
+        Ok(())
+    }
+
+    /// Quarantine a blob by moving it to the quarantine area
+    pub async fn quarantine(&self, cid: Cid) -> Result<()> {
+        self.move_object(MoveObject {
+            from: self.get_stored_path(cid),
+            to: self.get_quarantined_path(cid),
+        })
+        .await?;
+
+        debug!("Quarantined blob: {}", cid);
+        Ok(())
+    }
+
+    /// Unquarantine a blob by moving it back to regular storage
+    pub async fn unquarantine(&self, cid: Cid) -> Result<()> {
+        self.move_object(MoveObject {
+            from: self.get_quarantined_path(cid),
+            to: self.get_stored_path(cid),
+        })
+        .await?;
+
+        debug!("Unquarantined blob: {}", cid);
+        Ok(())
+    }
+
+    /// Get a blob as a stream
+    async fn get_object(&self, cid: Cid) -> Result<ByteStream> {
+        let blob_path = self.get_stored_path(cid);
+
+        match async_fs::read(&blob_path).await {
+            Ok(bytes) => Ok(ByteStream::new(Bytes::from(bytes))),
+            Err(e) => {
+                error!("Failed to read blob at path {:?}: {}", blob_path, e);
+                Err(anyhow::Error::new(BlobError::BlobNotFoundError))
+            }
+        }
+    }
+
+    /// Get blob bytes
+    pub async fn get_bytes(&self, cid: Cid) -> Result<Bytes> {
+        let stream = self.get_object(cid).await?;
+        stream.collect().await
+    }
+
+    /// Get a blob as a stream
+    pub async fn get_stream(&self, cid: Cid) -> Result<ByteStream> {
+        self.get_object(cid).await
+    }
+
+    /// Delete a blob by CID string
+    pub async fn delete(&self, cid_str: String) -> Result<()> {
+        match Cid::from_str(&cid_str) {
+            Ok(cid) => self.delete_path(self.get_stored_path(cid)).await,
+            Err(e) => {
+                warn!("Invalid CID: {} - {}", cid_str, e);
+                Err(anyhow::anyhow!("Invalid CID: {}", e))
+            }
+        }
+    }
+
+    /// Delete multiple blobs by CID
+    pub async fn delete_many(&self, cids: Vec<Cid>) -> Result<()> {
+        let mut futures = Vec::with_capacity(cids.len());
+
+        for cid in cids {
+            futures.push(self.delete_path(self.get_stored_path(cid)));
+        }
+
+        // Execute all delete operations concurrently
+        let results = futures::future::join_all(futures).await;
+
+        // Count errors but don't fail the operation
+        let error_count = results.iter().filter(|r| r.is_err()).count();
+        if error_count > 0 {
+            warn!(
+                "{} errors occurred while deleting {} blobs",
+                error_count,
+                results.len()
+            );
+        }
+
+        Ok(())
+    }
+
+    /// Check if a blob is stored in the regular storage
+    pub async fn has_stored(&self, cid: Cid) -> Result<bool> {
+        let blob_path = self.get_stored_path(cid);
+        Ok(blob_path.exists())
+    }
+
+    /// Check if a temporary blob exists
+    pub async fn has_temp(&self, key: String) -> Result<bool> {
+        let temp_path = self.get_tmp_path(&key);
+        Ok(temp_path.exists())
+    }
+
+    /// Helper function to delete a file at the given path
+    async fn delete_path(&self, path: PathBuf) -> Result<()> {
+        if path.exists() {
+            async_fs::remove_file(&path).await?;
+            debug!("Deleted file at: {:?}", path);
+            Ok(())
+        } else {
+            Err(anyhow::Error::new(BlobError::BlobNotFoundError))
+        }
+    }
+
+    /// Move a blob from one path to another
+    async fn move_object(&self, mov: MoveObject) -> Result<()> {
+        // Ensure the source exists
+        if !mov.from.exists() {
+            return Err(anyhow::Error::new(BlobError::BlobNotFoundError));
+        }
+
+        // Ensure the target directory exists
+        if let Some(parent) = mov.to.parent() {
+            async_fs::create_dir_all(parent).await?;
+        }
+
+        // Move the file
+        async_fs::rename(&mov.from, &mov.to).await?;
+
+        debug!("Moved blob: {:?} -> {:?}", mov.from, mov.to);
+        Ok(())
+    }
+}
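A short usage sketch of the new `BlobStoreFs`, using only methods defined in the file above; the DID is a placeholder and `cid` is assumed to be the blob's already-computed CID:

```rust
use crate::actor_store::blob_fs::BlobStoreFs; // module added in this diff
use axum::body::Bytes;
use std::path::PathBuf;

async fn store_blob(cid: cidv10::Cid) -> anyhow::Result<()> {
    let store = BlobStoreFs::new("did:plc:example".to_owned(), PathBuf::from("data/blob"));

    // Uploads land in tmp/<did>/<random-key> first...
    let key = store.put_temp(Bytes::from_static(b"hello")).await?;

    // ...and are promoted into the sharded
    // blocks/<did>/<cid[0..10]>/<cid[10..20]>/<cid> layout once the CID is known.
    store.make_permanent(key, cid).await?;

    assert!(store.has_stored(cid).await?);
    let bytes = store.get_bytes(cid).await?;
    assert_eq!(&bytes[..], b"hello");
    Ok(())
}
```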
src/actor_store/mod.rs (+7 -6)
···
 //! Modified for SQLite backend

 mod blob;
+pub(crate) mod blob_fs;
 mod preference;
 mod record;
 pub(crate) mod sql_blob;
···
 use tokio::sync::RwLock;

 use blob::BlobReader;
+use blob_fs::BlobStoreFs;
 use preference::PreferenceReader;
 use record::RecordReader;
-use sql_blob::BlobStoreSql;
 use sql_repo::SqlRepoReader;

-use crate::ActorPools;
+use crate::serve::ActorStorage;

 #[derive(Debug)]
 enum FormatCommitError {
···

 // Combination of RepoReader/Transactor, BlobReader/Transactor, SqlRepoReader/Transactor
 impl ActorStore {
-    /// Concrete reader of an individual repo (hence BlobStoreSql which takes `did` param)
+    /// Concrete reader of an individual repo (hence BlobStoreFs which takes `did` param)
     pub fn new(
         did: String,
-        blobstore: BlobStoreSql,
+        blobstore: BlobStoreFs,
         db: deadpool_diesel::Pool<
             deadpool_diesel::Manager<SqliteConnection>,
             deadpool_diesel::sqlite::Object,
···
     /// Create a new ActorStore taking ActorPools HashMap as input
     pub async fn from_actor_pools(
         did: &String,
-        hashmap_actor_pools: &std::collections::HashMap<String, ActorPools>,
+        hashmap_actor_pools: &std::collections::HashMap<String, ActorStorage>,
     ) -> Self {
         let actor_pool = hashmap_actor_pools
             .get(did)
             .expect("Actor pool not found")
             .clone();
-        let blobstore = BlobStoreSql::new(did.clone(), actor_pool.blob);
+        let blobstore = BlobStoreFs::new(did.clone(), actor_pool.blob);
         let conn = actor_pool
             .repo
             .clone()
src/apis/com/atproto/repo/apply_writes.rs (+14 -53)
···
 //! Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS.
-use crate::SharedSequencer;
-use crate::account_manager::helpers::account::AvailabilityFlags;
-use crate::account_manager::{AccountManager, AccountManagerCreator, SharedAccountManager};
-use crate::{
-    ActorPools, AppState, SigningKey,
-    actor_store::{ActorStore, sql_blob::BlobStoreSql},
-    auth::AuthenticatedUser,
-    config::AppConfig,
-    error::{ApiError, ErrorMessage},
-};
-use anyhow::{Result, bail};
-use axum::{
-    Json, Router,
-    body::Body,
-    extract::{Query, Request, State},
-    http::{self, StatusCode},
-    routing::{get, post},
-};
-use cidv10::Cid;
-use deadpool_diesel::sqlite::Pool;
-use futures::stream::{self, StreamExt};
-use rsky_lexicon::com::atproto::repo::{ApplyWritesInput, ApplyWritesInputRefWrite};
-use rsky_pds::auth_verifier::AccessStandardIncludeChecks;
-use rsky_pds::repo::prepare::{
-    PrepareCreateOpts, PrepareDeleteOpts, PrepareUpdateOpts, prepare_create, prepare_delete,
-    prepare_update,
-};
-use rsky_pds::sequencer::Sequencer;
-use rsky_repo::types::PreparedWrite;
-use std::str::FromStr;
-use std::sync::Arc;
-use tokio::sync::RwLock;
+
+use super::*;

 async fn inner_apply_writes(
     body: ApplyWritesInput,
-    user: AuthenticatedUser,
-    sequencer: &RwLock<Sequencer>,
-    actor_pools: std::collections::HashMap<String, ActorPools>,
-    account_manager: &RwLock<AccountManager>,
+    auth: AuthenticatedUser,
+    sequencer: Arc<RwLock<Sequencer>>,
+    actor_pools: HashMap<String, ActorStorage>,
+    account_manager: Arc<RwLock<AccountManager>>,
 ) -> Result<()> {
     let tx: ApplyWritesInput = body;
     let ApplyWritesInput {
···
         bail!("Account is deactivated")
     }
     let did = account.did;
-    if did != user.did() {
+    if did != auth.did() {
         bail!("AuthRequiredError")
     }
     let did: &String = &did;
···
     }

     let writes: Vec<PreparedWrite> = stream::iter(tx.writes)
-        .then(|write| async move {
+        .then(async |write| {
             Ok::<PreparedWrite, anyhow::Error>(match write {
                 ApplyWritesInputRefWrite::Create(write) => PreparedWrite::Create(
                     prepare_create(PrepareCreateOpts {
···
         .process_writes(writes.clone(), swap_commit_cid)
         .await?;

-    sequencer
+    _ = sequencer
         .write()
         .await
         .sequence_commit(did.clone(), commit.clone())
···
 /// - `swap_commit`: `cid` // If provided, the entire operation will fail if the current repo commit CID does not match this value. Used to prevent conflicting repo mutations.
 #[axum::debug_handler(state = AppState)]
 pub(crate) async fn apply_writes(
-    user: AuthenticatedUser,
-    State(state): State<AppState>,
+    auth: AuthenticatedUser,
+    State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>,
+    State(account_manager): State<Arc<RwLock<AccountManager>>>,
+    State(sequencer): State<Arc<RwLock<Sequencer>>>,
     Json(body): Json<ApplyWritesInput>,
 ) -> Result<(), ApiError> {
     tracing::debug!("@LOG: debug apply_writes {body:#?}");
-    let db_actors = state.db_actors.clone();
-    let sequencer = &state.sequencer.sequencer;
-    let account_manager = &state.account_manager.account_manager;
-    match inner_apply_writes(
-        body,
-        user,
-        sequencer.clone(),
-        db_actors,
-        account_manager.clone(),
-    )
-    .await
-    {
+    match inner_apply_writes(body, auth, sequencer, actor_pools, account_manager).await {
         Ok(()) => Ok(()),
         Err(error) => {
             tracing::error!("@LOG: ERROR: {error}");
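The rewritten handler above extracts three separate `State<..>` values from one `AppState` instead of cloning fields out of a single `State<AppState>`. A minimal sketch of the axum `FromRef` wiring that makes this pattern work, assuming the types used throughout this diff are in scope; the struct fields are assumptions, not shown in the PR:

```rust
use axum::extract::FromRef;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::RwLock;

#[derive(Clone)]
pub struct AppState {
    pub actor_pools: HashMap<String, ActorStorage>,
    pub account_manager: Arc<RwLock<AccountManager>>,
    pub sequencer: Arc<RwLock<Sequencer>>,
}

// One impl per substate lets handlers take State<Arc<RwLock<Sequencer>>>
// directly; axum's #[derive(FromRef)] can generate these impls instead.
impl FromRef<AppState> for Arc<RwLock<Sequencer>> {
    fn from_ref(state: &AppState) -> Self {
        state.sequencer.clone()
    }
}
```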
src/apis/com/atproto/repo/create_record.rs (+140, new file)
···
+//! Create a single new repository record. Requires auth, implemented by PDS.
+
+use super::*;
+
+async fn inner_create_record(
+    body: CreateRecordInput,
+    user: AuthenticatedUser,
+    sequencer: Arc<RwLock<Sequencer>>,
+    actor_pools: HashMap<String, ActorStorage>,
+    account_manager: Arc<RwLock<AccountManager>>,
+) -> Result<CreateRecordOutput> {
+    let CreateRecordInput {
+        repo,
+        collection,
+        record,
+        rkey,
+        validate,
+        swap_commit,
+    } = body;
+    let account = account_manager
+        .read()
+        .await
+        .get_account(
+            &repo,
+            Some(AvailabilityFlags {
+                include_deactivated: Some(true),
+                include_taken_down: None,
+            }),
+        )
+        .await?;
+    if let Some(account) = account {
+        if account.deactivated_at.is_some() {
+            bail!("Account is deactivated")
+        }
+        let did = account.did;
+        // if did != auth.access.credentials.unwrap().did.unwrap() {
+        if did != user.did() {
+            bail!("AuthRequiredError")
+        }
+        let swap_commit_cid = match swap_commit {
+            Some(swap_commit) => Some(Cid::from_str(&swap_commit)?),
+            None => None,
+        };
+        let write = prepare_create(PrepareCreateOpts {
+            did: did.clone(),
+            collection: collection.clone(),
+            record: serde_json::from_value(record)?,
+            rkey,
+            validate,
+            swap_cid: None,
+        })
+        .await?;
+
+        let did: &String = &did;
+        let mut actor_store = ActorStore::from_actor_pools(did, &actor_pools).await;
+        let backlink_conflicts: Vec<AtUri> = match validate {
+            Some(true) => {
+                let write_at_uri: AtUri = write.uri.clone().try_into()?;
+                actor_store
+                    .record
+                    .get_backlink_conflicts(&write_at_uri, &write.record)
+                    .await?
+            }
+            _ => Vec::new(),
+        };
+
+        let backlink_deletions: Vec<PreparedDelete> = backlink_conflicts
+            .iter()
+            .map(|at_uri| {
+                prepare_delete(PrepareDeleteOpts {
+                    did: at_uri.get_hostname().to_string(),
+                    collection: at_uri.get_collection(),
+                    rkey: at_uri.get_rkey(),
+                    swap_cid: None,
+                })
+            })
+            .collect::<Result<Vec<PreparedDelete>>>()?;
+        let mut writes: Vec<PreparedWrite> = vec![PreparedWrite::Create(write.clone())];
+        for delete in backlink_deletions {
+            writes.push(PreparedWrite::Delete(delete));
+        }
+        let commit = actor_store
+            .process_writes(writes.clone(), swap_commit_cid)
+            .await?;
+
+        _ = sequencer
+            .write()
+            .await
+            .sequence_commit(did.clone(), commit.clone())
+            .await?;
+        account_manager
+            .write()
+            .await
+            .update_repo_root(
+                did.to_string(),
+                commit.commit_data.cid,
+                commit.commit_data.rev,
+                &actor_pools,
+            )
+            .await?;
+
+        Ok(CreateRecordOutput {
+            uri: write.uri.clone(),
+            cid: write.cid.to_string(),
+        })
+    } else {
+        bail!("Could not find repo: `{repo}`")
+    }
+}
+
+/// Create a single new repository record. Requires auth, implemented by PDS.
+/// - POST /xrpc/com.atproto.repo.createRecord
+/// ### Request Body
+/// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account).
+/// - `collection`: `nsid` // The NSID of the record collection.
+/// - `rkey`: `string` // The record key. <= 512 characters.
+/// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data, 'true' to require it, or leave unset to validate only for known Lexicons.
+/// - `record`
+/// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID.
+/// ### Responses
+/// - 200 OK: {`cid`: `cid`, `uri`: `at-uri`, `commit`: {`cid`: `cid`, `rev`: `tid`}, `validation_status`: [`valid`, `unknown`]}
+/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidSwap`]}
+/// - 401 Unauthorized
+#[axum::debug_handler(state = AppState)]
+pub async fn create_record(
+    user: AuthenticatedUser,
+    State(db_actors): State<HashMap<String, ActorStorage, RandomState>>,
+    State(account_manager): State<Arc<RwLock<AccountManager>>>,
+    State(sequencer): State<Arc<RwLock<Sequencer>>>,
+    Json(body): Json<CreateRecordInput>,
+) -> Result<Json<CreateRecordOutput>, ApiError> {
+    tracing::debug!("@LOG: debug create_record {body:#?}");
+    match inner_create_record(body, user, sequencer, db_actors, account_manager).await {
+        Ok(res) => Ok(Json(res)),
+        Err(error) => {
+            tracing::error!("@LOG: ERROR: {error}");
+            Err(ApiError::RuntimeError)
+        }
+    }
+}
src/apis/com/atproto/repo/delete_record.rs (+117, new file)
···
+//! Delete a repository record, or ensure it doesn't exist. Requires auth, implemented by PDS.
+use super::*;
+
+async fn inner_delete_record(
+    body: DeleteRecordInput,
+    user: AuthenticatedUser,
+    sequencer: Arc<RwLock<Sequencer>>,
+    actor_pools: HashMap<String, ActorStorage>,
+    account_manager: Arc<RwLock<AccountManager>>,
+) -> Result<()> {
+    let DeleteRecordInput {
+        repo,
+        collection,
+        rkey,
+        swap_record,
+        swap_commit,
+    } = body;
+    let account = account_manager
+        .read()
+        .await
+        .get_account(
+            &repo,
+            Some(AvailabilityFlags {
+                include_deactivated: Some(true),
+                include_taken_down: None,
+            }),
+        )
+        .await?;
+    match account {
+        None => bail!("Could not find repo: `{repo}`"),
+        Some(account) if account.deactivated_at.is_some() => bail!("Account is deactivated"),
+        Some(account) => {
+            let did = account.did;
+            // if did != auth.access.credentials.unwrap().did.unwrap() {
+            if did != user.did() {
+                bail!("AuthRequiredError")
+            }
+
+            let swap_commit_cid = match swap_commit {
+                Some(swap_commit) => Some(Cid::from_str(&swap_commit)?),
+                None => None,
+            };
+            let swap_record_cid = match swap_record {
+                Some(swap_record) => Some(Cid::from_str(&swap_record)?),
+                None => None,
+            };
+
+            let write = prepare_delete(PrepareDeleteOpts {
+                did: did.clone(),
+                collection,
+                rkey,
+                swap_cid: swap_record_cid,
+            })?;
+            let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await;
+            let write_at_uri: AtUri = write.uri.clone().try_into()?;
+            let record = actor_store
+                .record
+                .get_record(&write_at_uri, None, Some(true))
+                .await?;
+            let commit = match record {
+                None => return Ok(()), // No-op if record already doesn't exist
+                Some(_) => {
+                    actor_store
+                        .process_writes(vec![PreparedWrite::Delete(write.clone())], swap_commit_cid)
+                        .await?
+                }
+            };
+
+            _ = sequencer
+                .write()
+                .await
+                .sequence_commit(did.clone(), commit.clone())
+                .await?;
+            account_manager
+                .write()
+                .await
+                .update_repo_root(
+                    did,
+                    commit.commit_data.cid,
+                    commit.commit_data.rev,
+                    &actor_pools,
+                )
+                .await?;
+
+            Ok(())
+        }
+    }
+}
+
+/// Delete a repository record, or ensure it doesn't exist. Requires auth, implemented by PDS.
+/// - POST /xrpc/com.atproto.repo.deleteRecord
+/// ### Request Body
+/// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account).
+/// - `collection`: `nsid` // The NSID of the record collection.
+/// - `rkey`: `string` // The record key. <= 512 characters.
+/// - `swap_record`: `cid` // Compare and swap with the previous record by CID.
+/// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID.
+/// ### Responses
+/// - 200 OK: {"commit": {"cid": "string","rev": "string"}}
+/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidSwap`]}
+/// - 401 Unauthorized
+#[axum::debug_handler(state = AppState)]
+pub async fn delete_record(
+    user: AuthenticatedUser,
+    State(db_actors): State<HashMap<String, ActorStorage, RandomState>>,
+    State(account_manager): State<Arc<RwLock<AccountManager>>>,
+    State(sequencer): State<Arc<RwLock<Sequencer>>>,
+    Json(body): Json<DeleteRecordInput>,
+) -> Result<(), ApiError> {
+    match inner_delete_record(body, user, sequencer, db_actors, account_manager).await {
+        Ok(()) => Ok(()),
+        Err(error) => {
+            tracing::error!("@LOG: ERROR: {error}");
+            Err(ApiError::RuntimeError)
+        }
+    }
+}
src/apis/com/atproto/repo/describe_repo.rs (+70, new file)
···
+//! Get information about an account and repository, including the list of collections. Does not require auth.
+use super::*;
+
+async fn inner_describe_repo(
+    repo: String,
+    id_resolver: Arc<RwLock<IdResolver>>,
+    actor_pools: HashMap<String, ActorStorage>,
+    account_manager: Arc<RwLock<AccountManager>>,
+) -> Result<DescribeRepoOutput> {
+    let account = account_manager
+        .read()
+        .await
+        .get_account(&repo, None)
+        .await?;
+    match account {
+        None => bail!("Could not find user: `{repo}`"),
+        Some(account) => {
+            let did_doc: DidDocument = match id_resolver
+                .write()
+                .await
+                .did
+                .ensure_resolve(&account.did, None)
+                .await
+            {
+                Err(err) => bail!("Could not resolve DID: `{err}`"),
+                Ok(res) => res,
+            };
+            let handle = rsky_common::get_handle(&did_doc);
+            let handle_is_correct = handle == account.handle;
+
+            let actor_store =
+                ActorStore::from_actor_pools(&account.did.clone(), &actor_pools).await;
+            let collections = actor_store.record.list_collections().await?;
+
+            Ok(DescribeRepoOutput {
+                handle: account.handle.unwrap_or_else(|| INVALID_HANDLE.to_owned()),
+                did: account.did,
+                did_doc: serde_json::to_value(did_doc)?,
+                collections,
+                handle_is_correct,
+            })
+        }
+    }
+}
+
+/// Get information about an account and repository, including the list of collections. Does not require auth.
+/// - GET /xrpc/com.atproto.repo.describeRepo
+/// ### Query Parameters
+/// - `repo`: `at-identifier` // The handle or DID of the repo.
+/// ### Responses
+/// - 200 OK: {"handle": "string","did": "string","didDoc": {},"collections": [string],"handleIsCorrect": true} \
+///   handleIsCorrect - boolean - Indicates if handle is currently valid (resolves bi-directionally)
+/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]}
+/// - 401 Unauthorized
+#[tracing::instrument(skip_all)]
+#[axum::debug_handler(state = AppState)]
+pub async fn describe_repo(
+    Query(input): Query<atrium_repo::describe_repo::ParametersData>,
+    State(db_actors): State<HashMap<String, ActorStorage, RandomState>>,
+    State(account_manager): State<Arc<RwLock<AccountManager>>>,
+    State(id_resolver): State<Arc<RwLock<IdResolver>>>,
+) -> Result<Json<DescribeRepoOutput>, ApiError> {
+    match inner_describe_repo(input.repo.into(), id_resolver, db_actors, account_manager).await {
+        Ok(res) => Ok(Json(res)),
+        Err(error) => {
+            tracing::error!("{error:?}");
+            Err(ApiError::RuntimeError)
+        }
+    }
+}
src/apis/com/atproto/repo/ex.rs (+37, new file)
···
+//!
+use crate::account_manager::AccountManager;
+use crate::serve::ActorStorage;
+use crate::{actor_store::ActorStore, error::ApiError, serve::AppState};
+use anyhow::{Result, bail};
+use axum::extract::Query;
+use axum::{Json, extract::State};
+use rsky_identity::IdResolver;
+use rsky_pds::sequencer::Sequencer;
+use std::collections::HashMap;
+use std::hash::RandomState;
+use std::sync::Arc;
+use tokio::sync::RwLock;
+
+async fn fun(
+    actor_pools: HashMap<String, ActorStorage>,
+    account_manager: Arc<RwLock<AccountManager>>,
+    id_resolver: Arc<RwLock<IdResolver>>,
+    sequencer: Arc<RwLock<Sequencer>>,
+) -> Result<_> {
+    todo!();
+}
+
+///
+#[tracing::instrument(skip_all)]
+#[axum::debug_handler(state = AppState)]
+pub async fn fun(
+    auth: AuthenticatedUser,
+    Query(input): Query<atrium_api::com::atproto::repo::describe_repo::ParametersData>,
+    State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>,
+    State(account_manager): State<Arc<RwLock<AccountManager>>>,
+    State(id_resolver): State<Arc<RwLock<IdResolver>>>,
+    State(sequencer): State<Arc<RwLock<Sequencer>>>,
+    Json(body): Json<ApplyWritesInput>,
+) -> Result<Json<_>, ApiError> {
+    todo!();
+}
src/apis/com/atproto/repo/get_record.rs (+102, new file)

```rust
//! Get a single record from a repository. Does not require auth.

use crate::pipethrough::{ProxyRequest, pipethrough};

use super::*;

use rsky_pds::pipethrough::OverrideOpts;

async fn inner_get_record(
    repo: String,
    collection: String,
    rkey: String,
    cid: Option<String>,
    req: ProxyRequest,
    actor_pools: HashMap<String, ActorStorage>,
    account_manager: Arc<RwLock<AccountManager>>,
) -> Result<GetRecordOutput> {
    let did = account_manager
        .read()
        .await
        .get_did_for_actor(&repo, None)
        .await?;

    // fetch from pds if available, if not then fetch from appview
    if let Some(did) = did {
        let uri = AtUri::make(did.clone(), Some(collection), Some(rkey))?;

        let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await;

        match actor_store.record.get_record(&uri, cid, None).await {
            Ok(Some(record)) if record.takedown_ref.is_none() => Ok(GetRecordOutput {
                uri: uri.to_string(),
                cid: Some(record.cid),
                value: serde_json::to_value(record.value)?,
            }),
            _ => bail!("Could not locate record: `{uri}`"),
        }
    } else {
        match req.cfg.bsky_app_view {
            None => bail!("Could not locate record"),
            Some(_) => match pipethrough(
                &req,
                None,
                OverrideOpts {
                    aud: None,
                    lxm: None,
                },
            )
            .await
            {
                Err(error) => {
                    tracing::error!("@LOG: ERROR: {error}");
                    bail!("Could not locate record")
                }
                Ok(res) => {
                    let output: GetRecordOutput = serde_json::from_slice(res.buffer.as_slice())?;
                    Ok(output)
                }
            },
        }
    }
}

/// Get a single record from a repository. Does not require auth.
/// - GET /xrpc/com.atproto.repo.getRecord
/// ### Query Parameters
/// - `repo`: `at-identifier` // The handle or DID of the repo.
/// - `collection`: `nsid` // The NSID of the record collection.
/// - `rkey`: `string` // The record key. <= 512 characters.
/// - `cid`: `cid` // The CID of the version of the record. If not specified, then return the most recent version.
/// ### Responses
/// - 200 OK: {"uri": "string","cid": "string","value": {}}
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RecordNotFound`]}
/// - 401 Unauthorized
#[tracing::instrument(skip_all)]
#[axum::debug_handler(state = AppState)]
pub async fn get_record(
    Query(input): Query<ParametersData>,
    State(db_actors): State<HashMap<String, ActorStorage, RandomState>>,
    State(account_manager): State<Arc<RwLock<AccountManager>>>,
    req: ProxyRequest,
) -> Result<Json<GetRecordOutput>, ApiError> {
    let repo = input.repo;
    let collection = input.collection;
    let rkey = input.rkey;
    let cid = input.cid;
    match inner_get_record(repo, collection, rkey, cid, req, db_actors, account_manager).await {
        Ok(res) => Ok(Json(res)),
        Err(error) => {
            tracing::error!("@LOG: ERROR: {error}");
            Err(ApiError::RecordNotFound)
        }
    }
}

#[derive(serde::Deserialize, Debug)]
pub struct ParametersData {
    pub cid: Option<String>,
    pub collection: String,
    pub repo: String,
    pub rkey: String,
}
```
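For reference, `ParametersData` above is what axum's `Query` extractor deserializes the query string into. A minimal sketch of that round-trip, using `serde_urlencoded` directly (assumed here as a dev-dependency for illustration; the rkey value is made up):

```rust
// Sketch only: mirrors how the query string becomes get_record's ParametersData.
#[derive(serde::Deserialize, Debug)]
struct ParametersData {
    cid: Option<String>,
    collection: String,
    repo: String,
    rkey: String,
}

fn main() -> Result<(), serde_urlencoded::de::Error> {
    let params: ParametersData = serde_urlencoded::from_str(
        "repo=teq.shatteredsky.net&collection=app.bsky.feed.post&rkey=3jxyz",
    )?;
    assert_eq!(params.repo, "teq.shatteredsky.net");
    assert_eq!(params.collection, "app.bsky.feed.post");
    assert_eq!(params.rkey, "3jxyz");
    // An omitted `cid` means "return the most recent version".
    assert!(params.cid.is_none());
    Ok(())
}
```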
src/apis/com/atproto/repo/import_repo.rs (+183, new file)

```rust
use axum::{body::Bytes, http::HeaderMap};
use reqwest::header;
use rsky_common::env::env_int;
use rsky_repo::block_map::BlockMap;
use rsky_repo::car::{CarWithRoot, read_stream_car_with_root};
use rsky_repo::parse::get_and_parse_record;
use rsky_repo::repo::Repo;
use rsky_repo::sync::consumer::{VerifyRepoInput, verify_diff};
use rsky_repo::types::{RecordWriteDescript, VerifiedDiff};
use ubyte::ToByteUnit;

use super::*;

async fn from_data(bytes: Bytes) -> Result<CarWithRoot, ApiError> {
    let max_import_size = env_int("IMPORT_REPO_LIMIT").unwrap_or(100).megabytes();
    if bytes.len() > max_import_size {
        return Err(ApiError::InvalidRequest(format!(
            "Content-Length is greater than maximum of {max_import_size}"
        )));
    }

    let mut cursor = std::io::Cursor::new(bytes);
    match read_stream_car_with_root(&mut cursor).await {
        Ok(car_with_root) => Ok(car_with_root),
        Err(error) => {
            tracing::error!("Error reading stream car with root\n{error}");
            Err(ApiError::InvalidRequest("Invalid CAR file".to_owned()))
        }
    }
}

#[tracing::instrument(skip_all)]
#[axum::debug_handler(state = AppState)]
/// Import a repo in the form of a CAR file. Requires Content-Length HTTP header to be set.
/// Request
/// mime application/vnd.ipld.car
/// Body - required
pub async fn import_repo(
    // auth: AccessFullImport,
    auth: AuthenticatedUser,
    headers: HeaderMap,
    State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>,
    body: Bytes,
) -> Result<(), ApiError> {
    // let requester = auth.access.credentials.unwrap().did.unwrap();
    let requester = auth.did();
    let mut actor_store = ActorStore::from_actor_pools(&requester, &actor_pools).await;

    // Check headers
    let content_length = headers
        .get(header::CONTENT_LENGTH)
        .expect("no content length provided")
        .to_str()
        .map_err(anyhow::Error::from)
        .and_then(|content_length| content_length.parse::<u64>().map_err(anyhow::Error::from))
        .expect("invalid content-length header");
    if content_length > env_int("IMPORT_REPO_LIMIT").unwrap_or(100).megabytes() {
        return Err(ApiError::InvalidRequest(format!(
            "Content-Length is greater than maximum of {}",
            env_int("IMPORT_REPO_LIMIT").unwrap_or(100).megabytes()
        )));
    };

    // Get current repo if it exists
    let curr_root: Option<Cid> = actor_store.get_repo_root().await;
    let curr_repo: Option<Repo> = match curr_root {
        None => None,
        Some(_root) => Some(Repo::load(actor_store.storage.clone(), curr_root).await?),
    };

    // Process imported car
    // let car_with_root = import_repo_input.car_with_root;
    let car_with_root: CarWithRoot = match from_data(body).await {
        Ok(car) => car,
        Err(error) => {
            tracing::error!("Error importing repo\n{error:?}");
            return Err(ApiError::InvalidRequest("Invalid CAR file".to_owned()));
        }
    };

    // Get verified difference from current repo and imported repo
    let mut imported_blocks: BlockMap = car_with_root.blocks;
    let imported_root: Cid = car_with_root.root;
    let opts = VerifyRepoInput {
        ensure_leaves: Some(false),
    };

    let diff: VerifiedDiff = match verify_diff(
        curr_repo,
        &mut imported_blocks,
        imported_root,
        None,
        None,
        Some(opts),
    )
    .await
    {
        Ok(res) => res,
        Err(error) => {
            tracing::error!("{:?}", error);
            return Err(ApiError::RuntimeError);
        }
    };

    let commit_data = diff.commit;
    let prepared_writes: Vec<PreparedWrite> =
        prepare_import_repo_writes(requester, diff.writes, &imported_blocks).await?;
    match actor_store
        .process_import_repo(commit_data, prepared_writes)
        .await
    {
        Ok(_res) => {}
        Err(error) => {
            tracing::error!("Error importing repo\n{error}");
            return Err(ApiError::RuntimeError);
        }
    }

    Ok(())
}

/// Converts list of RecordWriteDescripts into a list of PreparedWrites
async fn prepare_import_repo_writes(
    did: String,
    writes: Vec<RecordWriteDescript>,
    blocks: &BlockMap,
) -> Result<Vec<PreparedWrite>, ApiError> {
    match stream::iter(writes)
        .then(|write| {
            let did = did.clone();
            async move {
                Ok::<PreparedWrite, anyhow::Error>(match write {
                    RecordWriteDescript::Create(write) => {
                        let parsed_record = get_and_parse_record(blocks, write.cid)?;
                        PreparedWrite::Create(
                            prepare_create(PrepareCreateOpts {
                                did: did.clone(),
                                collection: write.collection,
                                rkey: Some(write.rkey),
                                swap_cid: None,
                                record: parsed_record.record,
                                validate: Some(true),
                            })
                            .await?,
                        )
                    }
                    RecordWriteDescript::Update(write) => {
                        let parsed_record = get_and_parse_record(blocks, write.cid)?;
                        PreparedWrite::Update(
                            prepare_update(PrepareUpdateOpts {
                                did: did.clone(),
                                collection: write.collection,
                                rkey: write.rkey,
                                swap_cid: None,
                                record: parsed_record.record,
                                validate: Some(true),
                            })
                            .await?,
                        )
                    }
                    RecordWriteDescript::Delete(write) => {
                        PreparedWrite::Delete(prepare_delete(PrepareDeleteOpts {
                            did: did.clone(),
                            collection: write.collection,
                            rkey: write.rkey,
                            swap_cid: None,
                        })?)
                    }
                })
            }
        })
        .collect::<Vec<_>>()
        .await
        .into_iter()
        .collect::<Result<Vec<PreparedWrite>, _>>()
    {
        Ok(res) => Ok(res),
        Err(error) => {
            tracing::error!("Error preparing import repo writes\n{error}");
            Err(ApiError::RuntimeError)
        }
    }
}
```
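The size guard above does its arithmetic through `ubyte`: `IMPORT_REPO_LIMIT` is read as a count of megabytes and compared against byte lengths. A small sketch of those semantics (assuming ubyte's SI units and its integer comparisons, which the guard itself relies on):

```rust
use ubyte::ToByteUnit;

fn main() {
    // `.megabytes()` is the SI unit (10^6 bytes), not mebibytes (2^20).
    let limit = 100.megabytes();
    assert_eq!(limit.as_u64(), 100_000_000);

    // Mirrors the check in `from_data`: ubyte lets a plain byte count be
    // compared against a ByteUnit directly.
    let payload_len: usize = 150_000_000;
    assert!(payload_len > limit);
}
```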
src/apis/com/atproto/repo/list_missing_blobs.rs (+48, new file)

```rust
//! Returns a list of missing blobs for the requesting account. Intended to be used in the account migration flow.
use rsky_lexicon::com::atproto::repo::ListMissingBlobsOutput;
use rsky_pds::actor_store::blob::ListMissingBlobsOpts;

use super::*;

/// Returns a list of missing blobs for the requesting account. Intended to be used in the account migration flow.
/// Request
/// Query Parameters
/// limit integer
/// Possible values: >= 1 and <= 1000
/// Default value: 500
/// cursor string
/// Responses
/// cursor string
/// blobs object[]
#[tracing::instrument(skip_all)]
#[axum::debug_handler(state = AppState)]
pub async fn list_missing_blobs(
    user: AuthenticatedUser,
    Query(input): Query<atrium_repo::list_missing_blobs::ParametersData>,
    State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>,
) -> Result<Json<ListMissingBlobsOutput>, ApiError> {
    let cursor = input.cursor;
    let limit = input.limit;
    let default_limit: atrium_api::types::LimitedNonZeroU16<1000> =
        atrium_api::types::LimitedNonZeroU16::try_from(500).expect("default limit");
    let limit: u16 = limit.unwrap_or(default_limit).into();
    // let did = auth.access.credentials.unwrap().did.unwrap();
    let did = user.did();

    let actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await;

    match actor_store
        .blob
        .list_missing_blobs(ListMissingBlobsOpts { cursor, limit })
        .await
    {
        Ok(blobs) => {
            let cursor = blobs.last().map(|last_blob| last_blob.cid.clone());
            Ok(Json(ListMissingBlobsOutput { cursor, blobs }))
        }
        Err(error) => {
            tracing::error!("{error:?}");
            Err(ApiError::RuntimeError)
        }
    }
}
```
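The limit handling above leans on atrium's range-checked integer: the query parameter deserializes into `LimitedNonZeroU16<1000>`, so the lexicon's 1..=1000 bound is enforced by the type rather than by handler code. A small sketch (assuming out-of-range values fail at `try_from` rather than clamping):

```rust
use atrium_api::types::LimitedNonZeroU16;

fn main() {
    // Same construction as the handler's default: 500 is within 1..=1000.
    let default_limit: LimitedNonZeroU16<1000> =
        LimitedNonZeroU16::try_from(500).expect("500 is within range");
    let limit: u16 = default_limit.into();
    assert_eq!(limit, 500);

    // Values outside the bound are rejected when constructed.
    assert!(LimitedNonZeroU16::<1000>::try_from(1001).is_err());
}
```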
src/apis/com/atproto/repo/list_records.rs (+146, new file)

```rust
//! List a range of records in a repository, matching a specific collection. Does not require auth.
use super::*;

// #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
// #[serde(rename_all = "camelCase")]
// /// Parameters for [`list_records`].
// pub(super) struct ListRecordsParameters {
//     ///The NSID of the record type.
//     pub collection: Nsid,
//     /// The cursor to start from.
//     #[serde(skip_serializing_if = "core::option::Option::is_none")]
//     pub cursor: Option<String>,
//     ///The number of records to return.
//     #[serde(skip_serializing_if = "core::option::Option::is_none")]
//     pub limit: Option<String>,
//     ///The handle or DID of the repo.
//     pub repo: AtIdentifier,
//     ///Flag to reverse the order of the returned records.
//     #[serde(skip_serializing_if = "core::option::Option::is_none")]
//     pub reverse: Option<bool>,
//     ///DEPRECATED: The highest sort-ordered rkey to stop at (exclusive)
//     #[serde(skip_serializing_if = "core::option::Option::is_none")]
//     pub rkey_end: Option<String>,
//     ///DEPRECATED: The lowest sort-ordered rkey to start from (exclusive)
//     #[serde(skip_serializing_if = "core::option::Option::is_none")]
//     pub rkey_start: Option<String>,
// }

#[expect(non_snake_case, clippy::too_many_arguments)]
async fn inner_list_records(
    // The handle or DID of the repo.
    repo: String,
    // The NSID of the record type.
    collection: String,
    // The number of records to return.
    limit: u16,
    cursor: Option<String>,
    // DEPRECATED: The lowest sort-ordered rkey to start from (exclusive)
    rkeyStart: Option<String>,
    // DEPRECATED: The highest sort-ordered rkey to stop at (exclusive)
    rkeyEnd: Option<String>,
    // Flag to reverse the order of the returned records.
    reverse: bool,
    // The actor pools
    actor_pools: HashMap<String, ActorStorage>,
    account_manager: Arc<RwLock<AccountManager>>,
) -> Result<ListRecordsOutput> {
    if limit > 100 {
        bail!("Error: limit can not be greater than 100")
    }
    let did = account_manager
        .read()
        .await
        .get_did_for_actor(&repo, None)
        .await?;
    if let Some(did) = did {
        let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await;

        let records: Vec<Record> = actor_store
            .record
            .list_records_for_collection(
                collection,
                limit as i64,
                reverse,
                cursor,
                rkeyStart,
                rkeyEnd,
                None,
            )
            .await?
            .into_iter()
            .map(|record| {
                Ok(Record {
                    uri: record.uri.clone(),
                    cid: record.cid.clone(),
                    value: serde_json::to_value(record)?,
                })
            })
            .collect::<Result<Vec<Record>>>()?;

        let last_record = records.last();
        let cursor: Option<String>;
        if let Some(last_record) = last_record {
            let last_at_uri: AtUri = last_record.uri.clone().try_into()?;
            cursor = Some(last_at_uri.get_rkey());
        } else {
            cursor = None;
        }
        Ok(ListRecordsOutput { records, cursor })
    } else {
        bail!("Could not find repo: {repo}")
    }
}

/// List a range of records in a repository, matching a specific collection. Does not require auth.
/// - GET /xrpc/com.atproto.repo.listRecords
/// ### Query Parameters
/// - `repo`: `at-identifier` // The handle or DID of the repo.
/// - `collection`: `nsid` // The NSID of the record type.
/// - `limit`: `integer` // The maximum number of records to return. Default 50, >=1 and <=100.
/// - `cursor`: `string`
/// - `reverse`: `boolean` // Flag to reverse the order of the returned records.
/// ### Responses
/// - 200 OK: {"cursor": "string","records": [{"uri": "string","cid": "string","value": {}}]}
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]}
/// - 401 Unauthorized
#[tracing::instrument(skip_all)]
#[allow(non_snake_case)]
#[axum::debug_handler(state = AppState)]
pub async fn list_records(
    Query(input): Query<atrium_repo::list_records::ParametersData>,
    State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>,
    State(account_manager): State<Arc<RwLock<AccountManager>>>,
) -> Result<Json<ListRecordsOutput>, ApiError> {
    let repo = input.repo;
    let collection = input.collection;
    let limit: Option<u8> = input.limit.map(u8::from);
    let limit: Option<u16> = limit.map(|x| x.into());
    let cursor = input.cursor;
    let reverse = input.reverse;
    let rkeyStart = None;
    let rkeyEnd = None;

    let limit = limit.unwrap_or(50);
    let reverse = reverse.unwrap_or(false);

    match inner_list_records(
        repo.into(),
        collection.into(),
        limit,
        cursor,
        rkeyStart,
        rkeyEnd,
        reverse,
        actor_pools,
        account_manager,
    )
    .await
    {
        Ok(res) => Ok(Json(res)),
        Err(error) => {
            tracing::error!("@LOG: ERROR: {error}");
            Err(ApiError::RuntimeError)
        }
    }
}
```
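The cursor returned by `inner_list_records` is just the rkey of the last record in the page, recovered by parsing its AT-URI. A minimal sketch of that conversion (the DID and rkey are made up):

```rust
use rsky_syntax::aturi::AtUri;

fn main() -> anyhow::Result<()> {
    // Record URIs are plain Strings; AtUri is built from one via TryInto,
    // exactly as in inner_list_records.
    let uri: AtUri =
        "at://did:plc:ab12cd34ef56/app.bsky.feed.post/3jxyzkey42".to_string().try_into()?;
    // The rkey becomes the next page's cursor.
    assert_eq!(uri.get_rkey(), "3jxyzkey42");
    Ok(())
}
```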
src/apis/com/atproto/repo/mod.rs (+92 -26)

```diff
-use atrium_api::com::atproto::repo;
-use axum::{Router, routing::post};
+use atrium_api::com::atproto::repo as atrium_repo;
+use axum::{
+    Router,
+    routing::{get, post},
+};
 use constcat::concat;
 
-use crate::AppState;
-
 pub mod apply_writes;
-// pub mod create_record;
-// pub mod delete_record;
-// pub mod describe_repo;
-// pub mod get_record;
-// pub mod import_repo;
-// pub mod list_missing_blobs;
-// pub mod list_records;
-// pub mod put_record;
-// pub mod upload_blob;
+pub mod create_record;
+pub mod delete_record;
+pub mod describe_repo;
+pub mod get_record;
+pub mod import_repo;
+pub mod list_missing_blobs;
+pub mod list_records;
+pub mod put_record;
+pub mod upload_blob;
+
+use crate::account_manager::AccountManager;
+use crate::account_manager::helpers::account::AvailabilityFlags;
+use crate::{
+    actor_store::ActorStore,
+    auth::AuthenticatedUser,
+    error::ApiError,
+    serve::{ActorStorage, AppState},
+};
+use anyhow::{Result, bail};
+use axum::extract::Query;
+use axum::{Json, extract::State};
+use cidv10::Cid;
+use futures::stream::{self, StreamExt};
+use rsky_identity::IdResolver;
+use rsky_identity::types::DidDocument;
+use rsky_lexicon::com::atproto::repo::DeleteRecordInput;
+use rsky_lexicon::com::atproto::repo::DescribeRepoOutput;
+use rsky_lexicon::com::atproto::repo::GetRecordOutput;
+use rsky_lexicon::com::atproto::repo::{ApplyWritesInput, ApplyWritesInputRefWrite};
+use rsky_lexicon::com::atproto::repo::{CreateRecordInput, CreateRecordOutput};
+use rsky_lexicon::com::atproto::repo::{ListRecordsOutput, Record};
+// use rsky_pds::pipethrough::{OverrideOpts, ProxyRequest, pipethrough};
+use rsky_pds::repo::prepare::{
+    PrepareCreateOpts, PrepareDeleteOpts, PrepareUpdateOpts, prepare_create, prepare_delete,
+    prepare_update,
+};
+use rsky_pds::sequencer::Sequencer;
+use rsky_repo::types::PreparedDelete;
+use rsky_repo::types::PreparedWrite;
+use rsky_syntax::aturi::AtUri;
+use rsky_syntax::handle::INVALID_HANDLE;
+use std::collections::HashMap;
+use std::hash::RandomState;
+use std::str::FromStr;
+use std::sync::Arc;
+use tokio::sync::RwLock;
 
 /// These endpoints are part of the atproto PDS repository management APIs. \
 /// Requests usually require authentication (unlike the com.atproto.sync.* endpoints), and are made directly to the user's own PDS instance.
···
 /// - [ ] xx /xrpc/com.atproto.repo.importRepo
 // - [ ] xx /xrpc/com.atproto.repo.listMissingBlobs
 pub(crate) fn routes() -> Router<AppState> {
-    Router::new().route(
-        concat!("/", repo::apply_writes::NSID),
-        post(apply_writes::apply_writes),
-    )
-    // .route(concat!("/", repo::create_record::NSID), post(create_record))
-    // .route(concat!("/", repo::put_record::NSID), post(put_record))
-    // .route(concat!("/", repo::delete_record::NSID), post(delete_record))
-    // .route(concat!("/", repo::upload_blob::NSID), post(upload_blob))
-    // .route(concat!("/", repo::describe_repo::NSID), get(describe_repo))
-    // .route(concat!("/", repo::get_record::NSID), get(get_record))
-    // .route(concat!("/", repo::import_repo::NSID), post(todo))
-    // .route(concat!("/", repo::list_missing_blobs::NSID), get(todo))
-    // .route(concat!("/", repo::list_records::NSID), get(list_records))
+    Router::new()
+        .route(
+            concat!("/", atrium_repo::apply_writes::NSID),
+            post(apply_writes::apply_writes),
+        )
+        .route(
+            concat!("/", atrium_repo::create_record::NSID),
+            post(create_record::create_record),
+        )
+        .route(
+            concat!("/", atrium_repo::put_record::NSID),
+            post(put_record::put_record),
+        )
+        .route(
+            concat!("/", atrium_repo::delete_record::NSID),
+            post(delete_record::delete_record),
+        )
+        .route(
+            concat!("/", atrium_repo::upload_blob::NSID),
+            post(upload_blob::upload_blob),
+        )
+        .route(
+            concat!("/", atrium_repo::describe_repo::NSID),
+            get(describe_repo::describe_repo),
+        )
+        .route(
+            concat!("/", atrium_repo::get_record::NSID),
+            get(get_record::get_record),
+        )
+        .route(
+            concat!("/", atrium_repo::import_repo::NSID),
+            post(import_repo::import_repo),
+        )
+        .route(
+            concat!("/", atrium_repo::list_missing_blobs::NSID),
+            get(list_missing_blobs::list_missing_blobs),
+        )
+        .route(
+            concat!("/", atrium_repo::list_records::NSID),
+            get(list_records::list_records),
+        )
 }
```
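Since every `NSID` constant is a `&'static str`, `constcat::concat!` can build the route paths at compile time (std's `concat!` only accepts literals). A sketch with the getRecord NSID written out as a literal stand-in:

```rust
use constcat::concat;

// Stand-in for atrium_repo::get_record::NSID, spelled out for illustration.
const NSID: &str = "com.atproto.repo.getRecord";
// constcat, unlike core::concat!, accepts const items, so this is evaluated
// at compile time into the path the router registers.
const PATH: &str = concat!("/", NSID);

fn main() {
    assert_eq!(PATH, "/com.atproto.repo.getRecord");
}
```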
src/apis/com/atproto/repo/put_record.rs (+157, new file)

```rust
//! Write a repository record, creating or updating it as needed. Requires auth, implemented by PDS.
use anyhow::bail;
use rsky_lexicon::com::atproto::repo::{PutRecordInput, PutRecordOutput};
use rsky_repo::types::CommitDataWithOps;

use super::*;

#[tracing::instrument(skip_all)]
async fn inner_put_record(
    body: PutRecordInput,
    auth: AuthenticatedUser,
    sequencer: Arc<RwLock<Sequencer>>,
    actor_pools: HashMap<String, ActorStorage>,
    account_manager: Arc<RwLock<AccountManager>>,
) -> Result<PutRecordOutput> {
    let PutRecordInput {
        repo,
        collection,
        rkey,
        validate,
        record,
        swap_record,
        swap_commit,
    } = body;
    let account = account_manager
        .read()
        .await
        .get_account(
            &repo,
            Some(AvailabilityFlags {
                include_deactivated: Some(true),
                include_taken_down: None,
            }),
        )
        .await?;
    if let Some(account) = account {
        if account.deactivated_at.is_some() {
            bail!("Account is deactivated")
        }
        let did = account.did;
        // if did != auth.access.credentials.unwrap().did.unwrap() {
        if did != auth.did() {
            bail!("AuthRequiredError")
        }
        let uri = AtUri::make(did.clone(), Some(collection.clone()), Some(rkey.clone()))?;
        let swap_commit_cid = match swap_commit {
            Some(swap_commit) => Some(Cid::from_str(&swap_commit)?),
            None => None,
        };
        let swap_record_cid = match swap_record {
            Some(swap_record) => Some(Cid::from_str(&swap_record)?),
            None => None,
        };
        let (commit, write): (Option<CommitDataWithOps>, PreparedWrite) = {
            let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await;

            let current = actor_store
                .record
                .get_record(&uri, None, Some(true))
                .await?;
            tracing::debug!("@LOG: debug inner_put_record, current: {current:?}");
            let write: PreparedWrite = if current.is_some() {
                PreparedWrite::Update(
                    prepare_update(PrepareUpdateOpts {
                        did: did.clone(),
                        collection,
                        rkey,
                        swap_cid: swap_record_cid,
                        record: serde_json::from_value(record)?,
                        validate,
                    })
                    .await?,
                )
            } else {
                PreparedWrite::Create(
                    prepare_create(PrepareCreateOpts {
                        did: did.clone(),
                        collection,
                        rkey: Some(rkey),
                        swap_cid: swap_record_cid,
                        record: serde_json::from_value(record)?,
                        validate,
                    })
                    .await?,
                )
            };

            match current {
                Some(current) if current.cid == write.cid().expect("write cid").to_string() => {
                    (None, write)
                }
                _ => {
                    let commit = actor_store
                        .process_writes(vec![write.clone()], swap_commit_cid)
                        .await?;
                    (Some(commit), write)
                }
            }
        };

        if let Some(commit) = commit {
            _ = sequencer
                .write()
                .await
                .sequence_commit(did.clone(), commit.clone())
                .await?;
            account_manager
                .write()
                .await
                .update_repo_root(
                    did,
                    commit.commit_data.cid,
                    commit.commit_data.rev,
                    &actor_pools,
                )
                .await?;
        }
        Ok(PutRecordOutput {
            uri: write.uri().to_string(),
            cid: write.cid().expect("write cid").to_string(),
        })
    } else {
        bail!("Could not find repo: `{repo}`")
    }
}

/// Write a repository record, creating or updating it as needed. Requires auth, implemented by PDS.
/// - POST /xrpc/com.atproto.repo.putRecord
/// ### Request Body
/// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account).
/// - `collection`: `nsid` // The NSID of the record collection.
/// - `rkey`: `string` // The record key. <= 512 characters.
/// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data, 'true' to require it, or leave unset to validate only for known Lexicons.
/// - `record`
/// - `swap_record`: `boolean` // Compare and swap with the previous record by CID. WARNING: nullable and optional field; may cause problems with golang implementation
/// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID.
/// ### Responses
/// - 200 OK: {"uri": "string","cid": "string","commit": {"cid": "string","rev": "string"},"validationStatus": "valid | unknown"}
/// - 400 Bad Request: {error:"`InvalidRequest` | `ExpiredToken` | `InvalidToken` | `InvalidSwap`"}
/// - 401 Unauthorized
#[tracing::instrument(skip_all)]
pub async fn put_record(
    auth: AuthenticatedUser,
    State(sequencer): State<Arc<RwLock<Sequencer>>>,
    State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>,
    State(account_manager): State<Arc<RwLock<AccountManager>>>,
    Json(body): Json<PutRecordInput>,
) -> Result<Json<PutRecordOutput>, ApiError> {
    tracing::debug!("@LOG: debug put_record {body:#?}");
    match inner_put_record(body, auth, sequencer, actor_pools, account_manager).await {
        Ok(res) => Ok(Json(res)),
        Err(error) => {
            tracing::error!("@LOG: ERROR: {error}");
            Err(ApiError::RuntimeError)
        }
    }
}
```
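`put_record` is effectively an upsert with an optimistic-concurrency check: if the stored record already carries the CID being written, no commit is processed or sequenced, but the uri/cid are still returned. A condensed, hypothetical mirror of just that decision (types are simplified stand-ins, not code from the PR):

```rust
// Simplified stand-in for the record row fetched from the actor store.
struct CurrentRecord {
    cid: String,
}

// Mirrors the match in inner_put_record: only commit when content changed.
fn needs_commit(current: Option<&CurrentRecord>, write_cid: &str) -> bool {
    match current {
        Some(current) if current.cid == write_cid => false, // no-op put
        _ => true, // create, or update with new content
    }
}

fn main() {
    let existing = CurrentRecord { cid: "bafy-old".to_owned() };
    assert!(!needs_commit(Some(&existing), "bafy-old")); // identical: skip commit
    assert!(needs_commit(Some(&existing), "bafy-new")); // changed: commit
    assert!(needs_commit(None, "bafy-new")); // absent: create + commit
}
```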
src/apis/com/atproto/repo/repo.rs (-514, file deleted)

```rust
//! PDS repository endpoints /xrpc/com.atproto.repo.*)
mod apply_writes;
pub(crate) use apply_writes::apply_writes;

use std::{collections::HashSet, str::FromStr};

use anyhow::{Context as _, anyhow};
use atrium_api::com::atproto::repo::apply_writes::{
    self as atrium_apply_writes, InputWritesItem, OutputResultsItem,
};
use atrium_api::{
    com::atproto::repo::{self, defs::CommitMetaData},
    types::{
        LimitedU32, Object, TryFromUnknown as _, TryIntoUnknown as _, Unknown,
        string::{AtIdentifier, Nsid, Tid},
    },
};
use atrium_repo::{Cid, blockstore::CarStore};
use axum::{
    Json, Router,
    body::Body,
    extract::{Query, Request, State},
    http::{self, StatusCode},
    routing::{get, post},
};
use constcat::concat;
use futures::TryStreamExt as _;
use metrics::counter;
use rsky_syntax::aturi::AtUri;
use serde::Deserialize;
use tokio::io::AsyncWriteExt as _;

use crate::repo::block_map::cid_for_cbor;
use crate::repo::types::PreparedCreateOrUpdate;
use crate::{
    AppState, Db, Error, Result, SigningKey,
    actor_store::{ActorStoreTransactor, ActorStoreWriter},
    auth::AuthenticatedUser,
    config::AppConfig,
    error::ErrorMessage,
    firehose::{self, FirehoseProducer, RepoOp},
    metrics::{REPO_COMMITS, REPO_OP_CREATE, REPO_OP_DELETE, REPO_OP_UPDATE},
    repo::types::{PreparedWrite, WriteOpAction},
    storage,
};

#[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(rename_all = "camelCase")]
/// Parameters for [`list_records`].
pub(super) struct ListRecordsParameters {
    ///The NSID of the record type.
    pub collection: Nsid,
    /// The cursor to start from.
    #[serde(skip_serializing_if = "core::option::Option::is_none")]
    pub cursor: Option<String>,
    ///The number of records to return.
    #[serde(skip_serializing_if = "core::option::Option::is_none")]
    pub limit: Option<String>,
    ///The handle or DID of the repo.
    pub repo: AtIdentifier,
    ///Flag to reverse the order of the returned records.
    #[serde(skip_serializing_if = "core::option::Option::is_none")]
    pub reverse: Option<bool>,
    ///DEPRECATED: The highest sort-ordered rkey to stop at (exclusive)
    #[serde(skip_serializing_if = "core::option::Option::is_none")]
    pub rkey_end: Option<String>,
    ///DEPRECATED: The lowest sort-ordered rkey to start from (exclusive)
    #[serde(skip_serializing_if = "core::option::Option::is_none")]
    pub rkey_start: Option<String>,
}

/// Resolve DID to DID document. Does not bi-directionally verify handle.
/// - GET /xrpc/com.atproto.repo.resolveDid
/// ### Query Parameters
/// - `did`: DID to resolve.
/// ### Responses
/// - 200 OK: {`did_doc`: `did_doc`}
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `DidNotFound`, `DidDeactivated`]}
async fn resolve_did(
    db: &Db,
    identifier: &AtIdentifier,
) -> anyhow::Result<(
    atrium_api::types::string::Did,
    atrium_api::types::string::Handle,
)> {
    let (handle, did) = match *identifier {
        AtIdentifier::Handle(ref handle) => {
            let handle_as_str = &handle.as_str();
            (
                &handle.to_owned(),
                &atrium_api::types::string::Did::new(
                    sqlx::query_scalar!(
                        r#"SELECT did FROM handles WHERE handle = ?"#,
                        handle_as_str
                    )
                    .fetch_one(db)
                    .await
                    .context("failed to query did")?,
                )
                .expect("should be valid DID"),
            )
        }
        AtIdentifier::Did(ref did) => {
            let did_as_str = &did.as_str();
            (
                &atrium_api::types::string::Handle::new(
                    sqlx::query_scalar!(r#"SELECT handle FROM handles WHERE did = ?"#, did_as_str)
                        .fetch_one(db)
                        .await
                        .context("failed to query did")?,
                )
                .expect("should be valid handle"),
                &did.to_owned(),
            )
        }
    };

    Ok((did.to_owned(), handle.to_owned()))
}

/// Create a single new repository record. Requires auth, implemented by PDS.
/// - POST /xrpc/com.atproto.repo.createRecord
/// ### Request Body
/// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account).
/// - `collection`: `nsid` // The NSID of the record collection.
/// - `rkey`: `string` // The record key. <= 512 characters.
/// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data, 'true' to require it, or leave unset to validate only for known Lexicons.
/// - `record`
/// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID.
/// ### Responses
/// - 200 OK: {`cid`: `cid`, `uri`: `at-uri`, `commit`: {`cid`: `cid`, `rev`: `tid`}, `validation_status`: [`valid`, `unknown`]}
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidSwap`]}
/// - 401 Unauthorized
async fn create_record(
    user: AuthenticatedUser,
    State(actor_store): State<ActorStore>,
    State(skey): State<SigningKey>,
    State(config): State<AppConfig>,
    State(db): State<Db>,
    State(fhp): State<FirehoseProducer>,
    Json(input): Json<repo::create_record::Input>,
) -> Result<Json<repo::create_record::Output>> {
    todo!();
    // let write_result = apply_writes::apply_writes(
    //     user,
    //     State(actor_store),
    //     State(skey),
    //     State(config),
    //     State(db),
    //     State(fhp),
    //     Json(
    //         repo::apply_writes::InputData {
    //             repo: input.repo.clone(),
    //             validate: input.validate,
    //             swap_commit: input.swap_commit.clone(),
    //             writes: vec![repo::apply_writes::InputWritesItem::Create(Box::new(
    //                 repo::apply_writes::CreateData {
    //                     collection: input.collection.clone(),
    //                     rkey: input.rkey.clone(),
    //                     value: input.record.clone(),
    //                 }
    //                 .into(),
    //             ))],
    //         }
    //         .into(),
    //     ),
    // )
    // .await
    // .context("failed to apply writes")?;

    // let create_result = if let repo::apply_writes::OutputResultsItem::CreateResult(create_result) =
    //     write_result
    //         .results
    //         .clone()
    //         .and_then(|result| result.first().cloned())
    //         .context("unexpected output from apply_writes")?
    // {
    //     Some(create_result)
    // } else {
    //     None
    // }
    // .context("unexpected result from apply_writes")?;

    // Ok(Json(
    //     repo::create_record::OutputData {
    //         cid: create_result.cid.clone(),
    //         commit: write_result.commit.clone(),
    //         uri: create_result.uri.clone(),
    //         validation_status: Some("unknown".to_owned()),
    //     }
    //     .into(),
    // ))
}

/// Write a repository record, creating or updating it as needed. Requires auth, implemented by PDS.
/// - POST /xrpc/com.atproto.repo.putRecord
/// ### Request Body
/// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account).
/// - `collection`: `nsid` // The NSID of the record collection.
/// - `rkey`: `string` // The record key. <= 512 characters.
/// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data, 'true' to require it, or leave unset to validate only for known Lexicons.
/// - `record`
/// - `swap_record`: `boolean` // Compare and swap with the previous record by CID. WARNING: nullable and optional field; may cause problems with golang implementation
/// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID.
/// ### Responses
/// - 200 OK: {"uri": "string","cid": "string","commit": {"cid": "string","rev": "string"},"validationStatus": "valid | unknown"}
/// - 400 Bad Request: {error:"`InvalidRequest` | `ExpiredToken` | `InvalidToken` | `InvalidSwap`"}
/// - 401 Unauthorized
async fn put_record(
    user: AuthenticatedUser,
    State(actor_store): State<ActorStore>,
    State(skey): State<SigningKey>,
    State(config): State<AppConfig>,
    State(db): State<Db>,
    State(fhp): State<FirehoseProducer>,
    Json(input): Json<repo::put_record::Input>,
) -> Result<Json<repo::put_record::Output>> {
    todo!();
    // // TODO: `input.swap_record`
    // // FIXME: "put" implies that we will create the record if it does not exist.
    // // We currently only update existing records and/or throw an error if one doesn't exist.
    // let input = (*input).clone();
    // let input = repo::apply_writes::InputData {
    //     repo: input.repo,
    //     validate: input.validate,
    //     swap_commit: input.swap_commit,
    //     writes: vec![repo::apply_writes::InputWritesItem::Update(Box::new(
    //         repo::apply_writes::UpdateData {
    //             collection: input.collection,
    //             rkey: input.rkey,
    //             value: input.record,
    //         }
    //         .into(),
    //     ))],
    // }
    // .into();

    // let write_result = apply_writes::apply_writes(
    //     user,
    //     State(actor_store),
    //     State(skey),
    //     State(config),
    //     State(db),
    //     State(fhp),
    //     Json(input),
    // )
    // .await
    // .context("failed to apply writes")?;

    // let update_result = write_result
    //     .results
    //     .clone()
    //     .and_then(|result| result.first().cloned())
    //     .context("unexpected output from apply_writes")?;
    // let (cid, uri) = match update_result {
    //     repo::apply_writes::OutputResultsItem::CreateResult(create_result) => (
    //         Some(create_result.cid.clone()),
    //         Some(create_result.uri.clone()),
    //     ),
    //     repo::apply_writes::OutputResultsItem::UpdateResult(update_result) => (
    //         Some(update_result.cid.clone()),
    //         Some(update_result.uri.clone()),
    //     ),
    //     repo::apply_writes::OutputResultsItem::DeleteResult(_) => (None, None),
    // };
    // Ok(Json(
    //     repo::put_record::OutputData {
    //         cid: cid.context("missing cid")?,
    //         commit: write_result.commit.clone(),
    //         uri: uri.context("missing uri")?,
    //         validation_status: Some("unknown".to_owned()),
    //     }
    //     .into(),
    // ))
}

/// Delete a repository record, or ensure it doesn't exist. Requires auth, implemented by PDS.
/// - POST /xrpc/com.atproto.repo.deleteRecord
/// ### Request Body
/// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account).
/// - `collection`: `nsid` // The NSID of the record collection.
/// - `rkey`: `string` // The record key. <= 512 characters.
/// - `swap_record`: `boolean` // Compare and swap with the previous record by CID.
/// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID.
/// ### Responses
/// - 200 OK: {"commit": {"cid": "string","rev": "string"}}
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidSwap`]}
/// - 401 Unauthorized
async fn delete_record(
    user: AuthenticatedUser,
    State(actor_store): State<ActorStore>,
    State(skey): State<SigningKey>,
    State(config): State<AppConfig>,
    State(db): State<Db>,
    State(fhp): State<FirehoseProducer>,
    Json(input): Json<repo::delete_record::Input>,
) -> Result<Json<repo::delete_record::Output>> {
    todo!();
    // // TODO: `input.swap_record`

    // Ok(Json(
    //     repo::delete_record::OutputData {
    //         commit: apply_writes::apply_writes(
    //             user,
    //             State(actor_store),
    //             State(skey),
    //             State(config),
    //             State(db),
    //             State(fhp),
    //             Json(
    //                 repo::apply_writes::InputData {
    //                     repo: input.repo.clone(),
    //                     swap_commit: input.swap_commit.clone(),
    //                     validate: None,
    //                     writes: vec![repo::apply_writes::InputWritesItem::Delete(Box::new(
    //                         repo::apply_writes::DeleteData {
    //                             collection: input.collection.clone(),
    //                             rkey: input.rkey.clone(),
    //                         }
    //                         .into(),
    //                     ))],
    //                 }
    //                 .into(),
    //             ),
    //         )
    //         .await
    //         .context("failed to apply writes")?
    //         .commit
    //         .clone(),
    //     }
    //     .into(),
    // ))
}

/// Get information about an account and repository, including the list of collections. Does not require auth.
/// - GET /xrpc/com.atproto.repo.describeRepo
/// ### Query Parameters
/// - `repo`: `at-identifier` // The handle or DID of the repo.
/// ### Responses
/// - 200 OK: {"handle": "string","did": "string","didDoc": {},"collections": [string],"handleIsCorrect": true} \
///   handeIsCorrect - boolean - Indicates if handle is currently valid (resolves bi-directionally)
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]}
/// - 401 Unauthorized
async fn describe_repo(
    State(actor_store): State<ActorStore>,
    State(config): State<AppConfig>,
    State(db): State<Db>,
    Query(input): Query<repo::describe_repo::ParametersData>,
) -> Result<Json<repo::describe_repo::Output>> {
    // Lookup the DID by the provided handle.
    let (did, handle) = resolve_did(&db, &input.repo)
        .await
        .context("failed to resolve handle")?;

    // Use Actor Store to get the collections
    todo!();
}

/// Get a single record from a repository. Does not require auth.
/// - GET /xrpc/com.atproto.repo.getRecord
/// ### Query Parameters
/// - `repo`: `at-identifier` // The handle or DID of the repo.
/// - `collection`: `nsid` // The NSID of the record collection.
/// - `rkey`: `string` // The record key. <= 512 characters.
/// - `cid`: `cid` // The CID of the version of the record. If not specified, then return the most recent version.
/// ### Responses
/// - 200 OK: {"uri": "string","cid": "string","value": {}}
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RecordNotFound`]}
/// - 401 Unauthorized
async fn get_record(
    State(actor_store): State<ActorStore>,
    State(config): State<AppConfig>,
    State(db): State<Db>,
    Query(input): Query<repo::get_record::ParametersData>,
) -> Result<Json<repo::get_record::Output>> {
    if input.cid.is_some() {
        return Err(Error::unimplemented(anyhow!(
            "looking up old records is unsupported"
        )));
    }

    // Lookup the DID by the provided handle.
    let (did, _handle) = resolve_did(&db, &input.repo)
        .await
        .context("failed to resolve handle")?;

    // Create a URI from the parameters
    let uri = format!(
        "at://{}/{}/{}",
        did.as_str(),
        input.collection.as_str(),
        input.rkey.as_str()
    );

    // Use Actor Store to get the record
    todo!();
}

/// List a range of records in a repository, matching a specific collection. Does not require auth.
/// - GET /xrpc/com.atproto.repo.listRecords
/// ### Query Parameters
/// - `repo`: `at-identifier` // The handle or DID of the repo.
/// - `collection`: `nsid` // The NSID of the record type.
/// - `limit`: `integer` // The maximum number of records to return. Default 50, >=1 and <=100.
/// - `cursor`: `string`
/// - `reverse`: `boolean` // Flag to reverse the order of the returned records.
/// ### Responses
/// - 200 OK: {"cursor": "string","records": [{"uri": "string","cid": "string","value": {}}]}
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]}
/// - 401 Unauthorized
async fn list_records(
    State(actor_store): State<ActorStore>,
    State(config): State<AppConfig>,
    State(db): State<Db>,
    Query(input): Query<Object<ListRecordsParameters>>,
) -> Result<Json<repo::list_records::Output>> {
    // Lookup the DID by the provided handle.
    let (did, _handle) = resolve_did(&db, &input.repo)
        .await
        .context("failed to resolve handle")?;

    // Use Actor Store to list records for the collection
    todo!();
}

/// Upload a new blob, to be referenced from a repository record. \
/// The blob will be deleted if it is not referenced within a time window (eg, minutes). \
/// Blob restrictions (mimetype, size, etc) are enforced when the reference is created. \
/// Requires auth, implemented by PDS.
/// - POST /xrpc/com.atproto.repo.uploadBlob
/// ### Request Body
/// ### Responses
/// - 200 OK: {"blob": "binary"}
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]}
/// - 401 Unauthorized
async fn upload_blob(
    user: AuthenticatedUser,
    State(actor_store): State<ActorStore>,
    State(config): State<AppConfig>,
    State(db): State<Db>,
    request: Request<Body>,
) -> Result<Json<repo::upload_blob::Output>> {
    let length = request
        .headers()
        .get(http::header::CONTENT_LENGTH)
        .context("no content length provided")?
        .to_str()
        .map_err(anyhow::Error::from)
        .and_then(|content_length| content_length.parse::<u64>().map_err(anyhow::Error::from))
        .context("invalid content-length header")?;
    let mime = request
        .headers()
        .get(http::header::CONTENT_TYPE)
        .context("no content-type provided")?
        .to_str()
        .context("invalid content-type provided")?
        .to_owned();

    if length > config.blob.limit {
        return Err(Error::with_status(
            StatusCode::PAYLOAD_TOO_LARGE,
            anyhow!("size {} above limit {}", length, config.blob.limit),
        ));
    }

    // Read the blob data
    let mut body_data = Vec::new();
    let mut stream = request.into_body().into_data_stream();
    while let Some(bytes) = stream.try_next().await.context("failed to receive file")? {
        body_data.extend_from_slice(&bytes);

        // Check size limit incrementally
        if body_data.len() as u64 > config.blob.limit {
            return Err(Error::with_status(
                StatusCode::PAYLOAD_TOO_LARGE,
                anyhow!("size above limit and content-length header was wrong"),
            ));
        }
    }

    // Use Actor Store to upload the blob
    todo!();
}

async fn todo() -> Result<()> {
    Err(Error::unimplemented(anyhow!("not implemented")))
}

/// These endpoints are part of the atproto PDS repository management APIs. \
/// Requests usually require authentication (unlike the com.atproto.sync.* endpoints), and are made directly to the user's own PDS instance.
/// ### Routes
/// - AP /xrpc/com.atproto.repo.applyWrites -> [`apply_writes`]
/// - AP /xrpc/com.atproto.repo.createRecord -> [`create_record`]
/// - AP /xrpc/com.atproto.repo.putRecord -> [`put_record`]
/// - AP /xrpc/com.atproto.repo.deleteRecord -> [`delete_record`]
/// - AP /xrpc/com.atproto.repo.uploadBlob -> [`upload_blob`]
/// - UG /xrpc/com.atproto.repo.describeRepo -> [`describe_repo`]
/// - UG /xrpc/com.atproto.repo.getRecord -> [`get_record`]
/// - UG /xrpc/com.atproto.repo.listRecords -> [`list_records`]
/// - [ ] xx /xrpc/com.atproto.repo.importRepo
// - [ ] xx /xrpc/com.atproto.repo.listMissingBlobs
pub(super) fn routes() -> Router<AppState> {
    Router::new()
        .route(concat!("/", repo::apply_writes::NSID), post(apply_writes))
        // .route(concat!("/", repo::create_record::NSID), post(create_record))
        // .route(concat!("/", repo::put_record::NSID), post(put_record))
        // .route(concat!("/", repo::delete_record::NSID), post(delete_record))
        // .route(concat!("/", repo::upload_blob::NSID), post(upload_blob))
        // .route(concat!("/", repo::describe_repo::NSID), get(describe_repo))
        // .route(concat!("/", repo::get_record::NSID), get(get_record))
        .route(concat!("/", repo::import_repo::NSID), post(todo))
        .route(concat!("/", repo::list_missing_blobs::NSID), get(todo))
        // .route(concat!("/", repo::list_records::NSID), get(list_records))
}
```
src/apis/com/atproto/repo/upload_blob.rs (+117, new file)

```rust
//! Upload a new blob, to be referenced from a repository record.
use crate::config::AppConfig;
use anyhow::Context as _;
use axum::{
    body::Bytes,
    http::{self, HeaderMap},
};
use rsky_lexicon::com::atproto::repo::{Blob, BlobOutput};
use rsky_repo::types::{BlobConstraint, PreparedBlobRef};
// use rsky_common::BadContentTypeError;

use super::*;

async fn inner_upload_blob(
    auth: AuthenticatedUser,
    blob: Bytes,
    content_type: String,
    actor_pools: HashMap<String, ActorStorage>,
) -> Result<BlobOutput> {
    // let requester = auth.access.credentials.unwrap().did.unwrap();
    let requester = auth.did();

    let actor_store = ActorStore::from_actor_pools(&requester, &actor_pools).await;

    let metadata = actor_store
        .blob
        .upload_blob_and_get_metadata(content_type, blob)
        .await?;
    let blobref = actor_store.blob.track_untethered_blob(metadata).await?;

    // make the blob permanent if an associated record is already indexed
    let records_for_blob = actor_store
        .blob
        .get_records_for_blob(blobref.get_cid()?)
        .await?;

    if !records_for_blob.is_empty() {
        actor_store
            .blob
            .verify_blob_and_make_permanent(PreparedBlobRef {
                cid: blobref.get_cid()?,
                mime_type: blobref.get_mime_type().to_string(),
                constraints: BlobConstraint {
                    max_size: None,
                    accept: None,
                },
            })
            .await?;
    }

    Ok(BlobOutput {
        blob: Blob {
            r#type: Some("blob".to_owned()),
            r#ref: Some(blobref.get_cid()?),
            cid: None,
            mime_type: blobref.get_mime_type().to_string(),
            size: blobref.get_size(),
            original: None,
        },
    })
}

/// Upload a new blob, to be referenced from a repository record. \
/// The blob will be deleted if it is not referenced within a time window (eg, minutes). \
/// Blob restrictions (mimetype, size, etc) are enforced when the reference is created. \
/// Requires auth, implemented by PDS.
/// - POST /xrpc/com.atproto.repo.uploadBlob
/// ### Request Body
/// ### Responses
/// - 200 OK: {"blob": "binary"}
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]}
/// - 401 Unauthorized
#[tracing::instrument(skip_all)]
#[axum::debug_handler(state = AppState)]
pub async fn upload_blob(
    auth: AuthenticatedUser,
    headers: HeaderMap,
    State(config): State<AppConfig>,
    State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>,
    blob: Bytes,
) -> Result<Json<BlobOutput>, ApiError> {
    let content_length = headers
        .get(http::header::CONTENT_LENGTH)
        .context("no content length provided")?
        .to_str()
        .map_err(anyhow::Error::from)
        .and_then(|content_length| content_length.parse::<u64>().map_err(anyhow::Error::from))
        .context("invalid content-length header")?;
    let content_type = headers
        .get(http::header::CONTENT_TYPE)
        .context("no content-type provided")?
        .to_str()
        // .map_err(BadContentTypeError::MissingType)
        .context("invalid content-type provided")?
        .to_owned();

    if content_length > config.blob.limit {
        return Err(ApiError::InvalidRequest(format!(
            "Content-Length is greater than maximum of {}",
            config.blob.limit
        )));
    };
    if blob.len() as u64 > config.blob.limit {
        return Err(ApiError::InvalidRequest(format!(
            "Blob size is greater than maximum of {} despite content-length header",
            config.blob.limit
        )));
    };

    match inner_upload_blob(auth, blob, content_type, actor_pools).await {
        Ok(res) => Ok(Json(res)),
        Err(error) => {
            tracing::error!("{error:?}");
            Err(ApiError::RuntimeError)
        }
    }
}
```
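`upload_blob` guards the size twice: once against the declared Content-Length and once against the bytes actually received, since the header can simply lie. A stand-alone sketch of that double check (hypothetical helper, not code from the PR):

```rust
// Hypothetical condensation of the two guards in upload_blob.
fn check_blob_size(content_length: u64, body_len: usize, limit: u64) -> Result<(), String> {
    if content_length > limit {
        return Err(format!("Content-Length is greater than maximum of {limit}"));
    }
    if body_len as u64 > limit {
        // The header passed, but the body itself is oversized.
        return Err(format!(
            "Blob size is greater than maximum of {limit} despite content-length header"
        ));
    }
    Ok(())
}

fn main() {
    assert!(check_blob_size(10, 10, 100).is_ok());
    assert!(check_blob_size(200, 10, 100).is_err()); // dishonest header
    assert!(check_blob_size(10, 200, 100).is_err()); // oversized body, honest header
}
```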
src/apis/mod.rs (+1 -1)

```diff
 use axum::{Json, Router, routing::get};
 use serde_json::json;
 
-use crate::{AppState, Result};
+use crate::serve::{AppState, Result};
 
 /// Health check endpoint. Returns name and version of the service.
 pub(crate) async fn health() -> Result<Json<serde_json::Value>> {
```
src/auth.rs (+4 -1)

```diff
 use diesel::prelude::*;
 use sha2::{Digest as _, Sha256};
 
-use crate::{AppState, Error, error::ErrorMessage};
+use crate::{
+    error::{Error, ErrorMessage},
+    serve::AppState,
+};
 
 /// Request extractor for authenticated users.
 /// If specified in an API endpoint, this guarantees the API can only be called
```
src/did.rs (+1 -1)
src/error.rs (+18 -11)

```diff
 
 impl ApiError {
     /// Get the appropriate HTTP status code for this error
-    fn status_code(&self) -> StatusCode {
+    const fn status_code(&self) -> StatusCode {
         match self {
             Self::RuntimeError => StatusCode::INTERNAL_SERVER_ERROR,
             Self::InvalidLogin
···
             Self::BadRequest(error, _) => error,
             Self::AuthRequiredError(_) => "AuthRequiredError",
         }
-        .to_string()
+        .to_owned()
     }
 
     /// Get the user-facing error message
···
             Self::BadRequest(_, msg) => msg,
             Self::AuthRequiredError(msg) => msg,
         }
-        .to_string()
+        .to_owned()
     }
 }
 
 impl From<Error> for ApiError {
     fn from(_value: Error) -> Self {
-        ApiError::RuntimeError
+        Self::RuntimeError
+    }
+}
+
+impl From<anyhow::Error> for ApiError {
+    fn from(_value: anyhow::Error) -> Self {
+        Self::RuntimeError
     }
 }
 
 impl From<handle::errors::Error> for ApiError {
     fn from(value: handle::errors::Error) -> Self {
         match value.kind {
-            ErrorKind::InvalidHandle => ApiError::InvalidHandle,
-            ErrorKind::HandleNotAvailable => ApiError::HandleNotAvailable,
-            ErrorKind::UnsupportedDomain => ApiError::UnsupportedDomain,
-            ErrorKind::InternalError => ApiError::RuntimeError,
+            ErrorKind::InvalidHandle => Self::InvalidHandle,
+            ErrorKind::HandleNotAvailable => Self::HandleNotAvailable,
+            ErrorKind::UnsupportedDomain => Self::UnsupportedDomain,
+            ErrorKind::InternalError => Self::RuntimeError,
         }
     }
 }
···
         let error_type = self.error_type();
         let message = self.message();
 
-        // Log the error for debugging
-        error!("API Error: {}: {}", error_type, message);
+        if cfg!(debug_assertions) {
+            error!("API Error: {}: {}", error_type, message);
+        }
 
         // Create the error message and serialize to JSON
         let error_message = ErrorMessage::new(error_type, message);
         let body = serde_json::to_string(&error_message).unwrap_or_else(|_| {
-            r#"{"error":"InternalServerError","message":"Error serializing response"}"#.to_string()
+            r#"{"error":"InternalServerError","message":"Error serializing response"}"#.to_owned()
         });
 
         // Build the response
```
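The new `From<anyhow::Error>` impl is what lets handlers returning `Result<_, ApiError>` use `?` directly on anyhow-fallible calls. A self-contained sketch with a stand-in `ApiError` (the real enum lives in src/error.rs):

```rust
// Stand-in ApiError so the sketch compiles on its own.
#[derive(Debug)]
enum ApiError {
    RuntimeError,
}

impl From<anyhow::Error> for ApiError {
    fn from(_value: anyhow::Error) -> Self {
        Self::RuntimeError
    }
}

fn fallible() -> anyhow::Result<u32> {
    anyhow::bail!("boom")
}

fn handler() -> Result<u32, ApiError> {
    // `?` converts anyhow::Error into ApiError via the From impl above.
    Ok(fallible()? + 1)
}

fn main() {
    assert!(matches!(handler(), Err(ApiError::RuntimeError)));
}
```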
src/firehose.rs (-426, file deleted; diff truncated below)

```rust
//! The firehose module.
use std::{collections::VecDeque, time::Duration};

use anyhow::{Result, bail};
use atrium_api::{
    com::atproto::sync::{self},
    types::string::{Datetime, Did, Tid},
};
use atrium_repo::Cid;
use axum::extract::ws::{Message, WebSocket};
use metrics::{counter, gauge};
use rand::Rng as _;
use serde::{Serialize, ser::SerializeMap as _};
use tracing::{debug, error, info, warn};

use crate::{
    Client,
    config::AppConfig,
    metrics::{FIREHOSE_HISTORY, FIREHOSE_LISTENERS, FIREHOSE_MESSAGES, FIREHOSE_SEQUENCE},
};

enum FirehoseMessage {
    Broadcast(sync::subscribe_repos::Message),
    Connect(Box<(WebSocket, Option<i64>)>),
}

enum FrameHeader {
    Error,
    Message(String),
}

impl Serialize for FrameHeader {
    #[expect(clippy::question_mark_used, reason = "returns a Result")]
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let mut map = serializer.serialize_map(None)?;

        match *self {
            Self::Message(ref s) => {
                map.serialize_key("op")?;
                map.serialize_value(&1_i32)?;
                map.serialize_key("t")?;
                map.serialize_value(s.as_str())?;
            }
            Self::Error => {
                map.serialize_key("op")?;
                map.serialize_value(&-1_i32)?;
            }
        }

        map.end()
    }
}

/// A repository operation.
pub(crate) enum RepoOp {
    /// Create a new record.
    Create {
        /// The CID of the record.
        cid: Cid,
        /// The path of the record.
        path: String,
    },
    /// Delete an existing record.
    Delete {
        /// The path of the record.
        path: String,
        /// The previous CID of the record.
        prev: Cid,
    },
    /// Update an existing record.
    Update {
        /// The CID of the record.
        cid: Cid,
        /// The path of the record.
        path: String,
        /// The previous CID of the record.
        prev: Cid,
    },
}

impl From<RepoOp> for sync::subscribe_repos::RepoOp {
    fn from(val: RepoOp) -> Self {
        let (action, cid, prev, path) = match val {
            RepoOp::Create { cid, path } => ("create", Some(cid), None, path),
            RepoOp::Update { cid, path, prev } => ("update", Some(cid), Some(prev), path),
            RepoOp::Delete { path, prev } => ("delete", None, Some(prev), path),
        };

        sync::subscribe_repos::RepoOpData {
            action: action.to_owned(),
            cid: cid.map(atrium_api::types::CidLink),
            prev: prev.map(atrium_api::types::CidLink),
            path,
        }
        .into()
    }
}

/// A commit to the repository.
pub(crate) struct Commit {
    /// Blobs that were created in this commit.
    pub blobs: Vec<Cid>,
    /// The car file containing the commit blocks.
    pub car: Vec<u8>,
    /// The CID of the commit.
    pub cid: Cid,
    /// The DID of the repository changed.
    pub did: Did,
    /// The operations performed in this commit.
    pub ops: Vec<RepoOp>,
    /// The previous commit's CID (if applicable).
    pub pcid: Option<Cid>,
    /// The revision of the commit.
    pub rev: String,
}

impl From<Commit> for sync::subscribe_repos::Commit {
    fn from(val: Commit) -> Self {
        sync::subscribe_repos::CommitData {
            blobs: val
                .blobs
                .into_iter()
                .map(atrium_api::types::CidLink)
                .collect::<Vec<_>>(),
            blocks: val.car,
            commit: atrium_api::types::CidLink(val.cid),
            ops: val.ops.into_iter().map(Into::into).collect::<Vec<_>>(),
            prev_data: val.pcid.map(atrium_api::types::CidLink),
            rebase: false,
            repo: val.did,
            rev: Tid::new(val.rev).expect("should be valid revision"),
            seq: 0,
            since: None,
            time: Datetime::now(),
            too_big: false,
        }
        .into()
    }
}

/// A firehose producer. This is used to transmit messages to the firehose for broadcast.
#[derive(Clone, Debug)]
pub(crate) struct FirehoseProducer {
    /// The channel to send messages to the firehose.
    tx: tokio::sync::mpsc::Sender<FirehoseMessage>,
}

impl FirehoseProducer {
    /// Broadcast an `#account` event.
    pub(crate) async fn account(&self, account: impl Into<sync::subscribe_repos::Account>) {
        drop(
            self.tx
                .send(FirehoseMessage::Broadcast(
                    sync::subscribe_repos::Message::Account(Box::new(account.into())),
                ))
                .await,
        );
    }
    /// Handle client connection.
    pub(crate) async fn client_connection(&self, ws: WebSocket, cursor: Option<i64>) {
        drop(
            self.tx
                .send(FirehoseMessage::Connect(Box::new((ws, cursor))))
                .await,
        );
    }
    /// Broadcast a `#commit` event.
    pub(crate) async fn commit(&self, commit: impl Into<sync::subscribe_repos::Commit>) {
        drop(
            self.tx
                .send(FirehoseMessage::Broadcast(
                    sync::subscribe_repos::Message::Commit(Box::new(commit.into())),
                ))
                .await,
        );
    }
    /// Broadcast an `#identity` event.
    pub(crate) async fn identity(&self, identity: impl Into<sync::subscribe_repos::Identity>) {
        drop(
            self.tx
                .send(FirehoseMessage::Broadcast(
                    sync::subscribe_repos::Message::Identity(Box::new(identity.into())),
                ))
                .await,
        );
    }
}

#[expect(
    clippy::as_conversions,
    clippy::cast_possible_truncation,
    clippy::cast_sign_loss,
    clippy::cast_precision_loss,
    clippy::arithmetic_side_effects
)]
/// Convert a `usize` to a `f64`.
const fn convert_usize_f64(x: usize) -> Result<f64, &'static str> {
    let result = x as f64;
    if result as usize - x > 0 {
        return Err("cannot convert");
    }
    Ok(result)
}

/// Serialize a message.
fn serialize_message(seq: u64, mut msg: sync::subscribe_repos::Message) -> (&'static str, Vec<u8>) {
    let mut dummy_seq = 0_i64;
    #[expect(clippy::pattern_type_mismatch)]
    let (ty, nseq) = match &mut msg {
        sync::subscribe_repos::Message::Account(m) => ("#account", &mut m.seq),
        sync::subscribe_repos::Message::Commit(m) => ("#commit", &mut m.seq),
        sync::subscribe_repos::Message::Identity(m) => ("#identity", &mut m.seq),
        sync::subscribe_repos::Message::Sync(m) => ("#sync", &mut m.seq),
        sync::subscribe_repos::Message::Info(_m) => ("#info", &mut dummy_seq),
    };
    // Set the sequence number.
    *nseq = i64::try_from(seq).expect("should find seq");

    let hdr = FrameHeader::Message(ty.to_owned());

    let mut frame = Vec::new();
    serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header");
    serde_ipld_dagcbor::to_writer(&mut frame, &msg).expect("should serialize message");

    (ty, frame)
}

/// Broadcast a message out to all clients.
async fn broadcast_message(clients: &mut Vec<WebSocket>, msg: Message) -> Result<()> {
    counter!(FIREHOSE_MESSAGES).increment(1);

    for i in (0..clients.len()).rev() {
        let client = clients.get_mut(i).expect("should find client");
        if let Err(e) = client.send(msg.clone()).await {
            debug!("Firehose client disconnected: {e}");
            drop(clients.remove(i));
        }
    }

    gauge!(FIREHOSE_LISTENERS)
        .set(convert_usize_f64(clients.len()).expect("should find clients length"));
    Ok(())
}

/// Handle a new connection from a websocket client created by subscribeRepos.
async fn handle_connect(
    mut ws: WebSocket,
    seq: u64,
    history: &VecDeque<(u64, &str, sync::subscribe_repos::Message)>,
    cursor: Option<i64>,
) -> Result<WebSocket> {
    if let Some(cursor) = cursor {
        let mut frame = Vec::new();
        let cursor = u64::try_from(cursor);
        if cursor.is_err() {
            tracing::warn!("cursor is not a valid u64");
            return Ok(ws);
        }
        let cursor = cursor.expect("should be valid u64");
        // Cursor specified; attempt to backfill the consumer.
        if cursor > seq {
            let hdr = FrameHeader::Error;
            let msg = sync::subscribe_repos::Error::FutureCursor(Some(format!(
                "cursor {cursor} is greater than the current sequence number {seq}"
            )));
            serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header");
            serde_ipld_dagcbor::to_writer(&mut frame, &msg).expect("should serialize message");
            // Drop the connection.
            drop(ws.send(Message::binary(frame)).await);
```
273
-
bail!(
274
-
"connection dropped: cursor {cursor} is greater than the current sequence number {seq}"
275
-
);
276
-
}
277
-
278
-
for &(historical_seq, ty, ref msg) in history {
279
-
if cursor > historical_seq {
280
-
continue;
281
-
}
282
-
let hdr = FrameHeader::Message(ty.to_owned());
283
-
serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header");
284
-
serde_ipld_dagcbor::to_writer(&mut frame, msg).expect("should serialize message");
285
-
if let Err(e) = ws.send(Message::binary(frame.clone())).await {
286
-
debug!("Firehose client disconnected during backfill: {e}");
287
-
break;
288
-
}
289
-
// Clear out the frame to begin a new one.
290
-
frame.clear();
291
-
}
292
-
}
293
-
294
-
Ok(ws)
295
-
}
296
-
297
-
/// Reconnect to upstream relays.
298
-
pub(crate) async fn reconnect_relays(client: &Client, config: &AppConfig) {
299
-
// Avoid connecting to upstream relays in test mode.
300
-
if config.test {
301
-
return;
302
-
}
303
-
304
-
info!("attempting to reconnect to upstream relays");
305
-
for relay in &config.firehose.relays {
306
-
let Some(host) = relay.host_str() else {
307
-
warn!("relay {} has no host specified", relay);
308
-
continue;
309
-
};
310
-
311
-
let r = client
312
-
.post(format!("https://{host}/xrpc/com.atproto.sync.requestCrawl"))
313
-
.json(&serde_json::json!({
314
-
"hostname": format!("https://{}", config.host_name)
315
-
}))
316
-
.send()
317
-
.await;
318
-
319
-
let r = match r {
320
-
Ok(r) => r,
321
-
Err(e) => {
322
-
error!("failed to hit upstream relay {host}: {e}");
323
-
continue;
324
-
}
325
-
};
326
-
327
-
let s = r.status();
328
-
if let Err(e) = r.error_for_status_ref() {
329
-
error!("failed to hit upstream relay {host}: {e}");
330
-
}
331
-
332
-
let b = r.json::<serde_json::Value>().await;
333
-
if let Ok(b) = b {
334
-
info!("relay {host}: {} {}", s, b);
335
-
} else {
336
-
info!("relay {host}: {}", s);
337
-
}
338
-
}
339
-
}
340
-
341
-
/// The main entrypoint for the firehose.
342
-
///
343
-
/// This will broadcast all updates in this PDS out to anyone who is listening.
344
-
///
345
-
/// Reference: <https://atproto.com/specs/sync>
346
-
pub(crate) fn spawn(
347
-
client: Client,
348
-
config: AppConfig,
349
-
) -> (tokio::task::JoinHandle<()>, FirehoseProducer) {
350
-
let (tx, mut rx) = tokio::sync::mpsc::channel(1000);
351
-
let handle = tokio::spawn(async move {
352
-
fn time_since_inception() -> u64 {
353
-
chrono::Utc::now()
354
-
.timestamp_micros()
355
-
.checked_sub(1_743_442_000_000_000)
356
-
.expect("should not wrap")
357
-
.unsigned_abs()
358
-
}
359
-
let mut clients: Vec<WebSocket> = Vec::new();
360
-
let mut history = VecDeque::with_capacity(1000);
361
-
let mut seq = time_since_inception();
362
-
363
-
loop {
364
-
if let Ok(msg) = tokio::time::timeout(Duration::from_secs(30), rx.recv()).await {
365
-
match msg {
366
-
Some(FirehoseMessage::Broadcast(msg)) => {
367
-
let (ty, by) = serialize_message(seq, msg.clone());
368
-
369
-
history.push_back((seq, ty, msg));
370
-
gauge!(FIREHOSE_HISTORY).set(
371
-
convert_usize_f64(history.len()).expect("should find history length"),
372
-
);
373
-
374
-
info!(
375
-
"Broadcasting message {} {} to {} clients",
376
-
seq,
377
-
ty,
378
-
clients.len()
379
-
);
380
-
381
-
counter!(FIREHOSE_SEQUENCE).absolute(seq);
382
-
let now = time_since_inception();
383
-
if now > seq {
384
-
seq = now;
385
-
} else {
386
-
seq = seq.checked_add(1).expect("should not wrap");
387
-
}
388
-
389
-
drop(broadcast_message(&mut clients, Message::binary(by)).await);
390
-
}
391
-
Some(FirehoseMessage::Connect(ws_cursor)) => {
392
-
let (ws, cursor) = *ws_cursor;
393
-
match handle_connect(ws, seq, &history, cursor).await {
394
-
Ok(r) => {
395
-
gauge!(FIREHOSE_LISTENERS).increment(1_i32);
396
-
clients.push(r);
397
-
}
398
-
Err(e) => {
399
-
error!("failed to connect new client: {e}");
400
-
}
401
-
}
402
-
}
403
-
// All producers have been destroyed.
404
-
None => break,
405
-
}
406
-
} else {
407
-
if clients.is_empty() {
408
-
reconnect_relays(&client, &config).await;
409
-
}
410
-
411
-
let contents = rand::thread_rng()
412
-
.sample_iter(rand::distributions::Alphanumeric)
413
-
.take(15)
414
-
.map(char::from)
415
-
.collect::<String>();
416
-
417
-
// Send a websocket ping message.
418
-
// Reference: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers#pings_and_pongs_the_heartbeat_of_websockets
419
-
let message = Message::Ping(axum::body::Bytes::from_owner(contents));
420
-
drop(broadcast_message(&mut clients, message).await);
421
-
}
422
-
}
423
-
});
424
-
425
-
(handle, FirehoseProducer { tx })
426
-
}
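The deleted module above framed every firehose event as two back-to-back DAG-CBOR values: a header carrying the `#`-prefixed message type, then the message body. This matches the atproto event-stream framing, which the rsky `Sequencer` replacing this module also emits. A minimal, self-contained sketch of that layout, with hypothetical `Header`/`Body` stand-ins for the crate's `FrameHeader` and `subscribe_repos::Message` types:

```rust
use serde::Serialize;

// Hypothetical stand-ins for FrameHeader and the message payload.
#[derive(Serialize)]
struct Header<'a> {
    op: i8,     // 1 = message, -1 = error (atproto event-stream convention)
    t: &'a str, // message type, e.g. "#commit"
}

#[derive(Serialize)]
struct Body {
    seq: i64, // sequence number, monotonically increasing per event
}

fn build_frame() -> Vec<u8> {
    // One frame = header then body, concatenated into a single binary
    // websocket message, exactly as serialize_message did above.
    let mut frame = Vec::new();
    serde_ipld_dagcbor::to_writer(&mut frame, &Header { op: 1, t: "#commit" })
        .expect("header should serialize");
    serde_ipld_dagcbor::to_writer(&mut frame, &Body { seq: 42 })
        .expect("body should serialize");
    frame
}
```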
+4 -438 src/lib.rs
···
 mod db;
 mod did;
 pub mod error;
-mod firehose;
 mod metrics;
-mod mmap;
 mod models;
 mod oauth;
-mod plc;
+mod pipethrough;
 mod schema;
+mod serve;
 mod service_proxy;
-#[cfg(test)]
-mod tests;
 
-use account_manager::{AccountManager, SharedAccountManager};
-use anyhow::{Context as _, anyhow};
-use atrium_api::types::string::Did;
-use atrium_crypto::keypair::{Export as _, Secp256k1Keypair};
-use auth::AuthenticatedUser;
-use axum::{
-    Router,
-    body::Body,
-    extract::{FromRef, Request, State},
-    http::{self, HeaderMap, Response, StatusCode, Uri},
-    response::IntoResponse,
-    routing::get,
-};
-use azure_core::credentials::TokenCredential;
-use clap::Parser;
-use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter};
-use config::AppConfig;
-use db::establish_pool;
-use deadpool_diesel::sqlite::Pool;
-use diesel::prelude::*;
-use diesel_migrations::{EmbeddedMigrations, embed_migrations};
-pub use error::Error;
-use figment::{Figment, providers::Format as _};
-use firehose::FirehoseProducer;
-use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager};
-use rand::Rng as _;
-use rsky_pds::{crawlers::Crawlers, sequencer::Sequencer};
-use serde::{Deserialize, Serialize};
-use service_proxy::service_proxy;
-use std::{
-    net::{IpAddr, Ipv4Addr, SocketAddr},
-    path::PathBuf,
-    str::FromStr as _,
-    sync::Arc,
-};
-use tokio::{net::TcpListener, sync::RwLock};
-use tower_http::{cors::CorsLayer, trace::TraceLayer};
-use tracing::{info, warn};
-use uuid::Uuid;
-
-/// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`.
-pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
-
-/// Embedded migrations
-pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
-
-/// The application-wide result type.
-pub type Result<T> = std::result::Result<T, Error>;
-/// The reqwest client type with middleware.
-pub type Client = reqwest_middleware::ClientWithMiddleware;
-
-/// The Shared Sequencer which requests crawls from upstream relays and emits events to the firehose.
-pub struct SharedSequencer {
-    /// The sequencer instance.
-    pub sequencer: RwLock<Sequencer>,
-}
-
-#[expect(
-    clippy::arbitrary_source_item_ordering,
-    reason = "serialized data might be structured"
-)]
-#[derive(Serialize, Deserialize, Debug, Clone)]
-/// The key data structure.
-struct KeyData {
-    /// Primary signing key for all repo operations.
-    skey: Vec<u8>,
-    /// Primary signing (rotation) key for all PLC operations.
-    rkey: Vec<u8>,
-}
-
-// FIXME: We should use P256Keypair instead. SecP256K1 is primarily used for cryptocurrencies,
-// and the implementations of this algorithm are much more limited as compared to P256.
-//
-// Reference: https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/
-#[derive(Clone)]
-/// The signing key for PLC/DID operations.
-pub struct SigningKey(Arc<Secp256k1Keypair>);
-#[derive(Clone)]
-/// The rotation key for PLC operations.
-pub struct RotationKey(Arc<Secp256k1Keypair>);
-
-impl std::ops::Deref for SigningKey {
-    type Target = Secp256k1Keypair;
-
-    fn deref(&self) -> &Self::Target {
-        &self.0
-    }
-}
-
-impl SigningKey {
-    /// Import from a private key.
-    pub fn import(key: &[u8]) -> Result<Self> {
-        let key = Secp256k1Keypair::import(key).context("failed to import signing key")?;
-        Ok(Self(Arc::new(key)))
-    }
-}
-
-impl std::ops::Deref for RotationKey {
-    type Target = Secp256k1Keypair;
-
-    fn deref(&self) -> &Self::Target {
-        &self.0
-    }
-}
-
-#[derive(Parser, Debug, Clone)]
-/// Command line arguments.
-pub struct Args {
-    /// Path to the configuration file
-    #[arg(short, long, default_value = "default.toml")]
-    pub config: PathBuf,
-    /// The verbosity level.
-    #[command(flatten)]
-    pub verbosity: Verbosity<InfoLevel>,
-}
-
-/// The actor pools for the database connections.
-pub struct ActorPools {
-    /// The database connection pool for the actor's repository.
-    pub repo: Pool,
-    /// The database connection pool for the actor's blobs.
-    pub blob: Pool,
-}
-
-impl Clone for ActorPools {
-    fn clone(&self) -> Self {
-        Self {
-            repo: self.repo.clone(),
-            blob: self.blob.clone(),
-        }
-    }
-}
-
-#[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")]
-#[derive(Clone, FromRef)]
-pub struct AppState {
-    /// The application configuration.
-    pub config: AppConfig,
-    /// The main database connection pool. Used for common PDS data, like invite codes.
-    pub db: Pool,
-    /// Actor-specific database connection pools. Hashed by DID.
-    pub db_actors: std::collections::HashMap<String, ActorPools>,
-
-    /// The HTTP client with middleware.
-    pub client: Client,
-    /// The simple HTTP client.
-    pub simple_client: reqwest::Client,
-    /// The firehose producer.
-    pub sequencer: Arc<SharedSequencer>,
-    /// The account manager.
-    pub account_manager: Arc<SharedAccountManager>,
-
-    /// The signing key.
-    pub signing_key: SigningKey,
-    /// The rotation key.
-    pub rotation_key: RotationKey,
-}
+pub use serve::run;
 
 /// The index (/) route.
-async fn index() -> impl IntoResponse {
+async fn index() -> impl axum::response::IntoResponse {
     r"
 __ __
 /\ \__ /\ \__
···
 Protocol: https://atproto.com
 "
 }
-
-/// The main application entry point.
-#[expect(
-    clippy::cognitive_complexity,
-    clippy::too_many_lines,
-    unused_qualifications,
-    reason = "main function has high complexity"
-)]
-pub async fn run() -> anyhow::Result<()> {
-    let args = Args::parse();
-
-    // Set up trace logging to console and account for the user-provided verbosity flag.
-    if args.verbosity.log_level_filter() != LevelFilter::Off {
-        let lvl = match args.verbosity.log_level_filter() {
-            LevelFilter::Error => tracing::Level::ERROR,
-            LevelFilter::Warn => tracing::Level::WARN,
-            LevelFilter::Info | LevelFilter::Off => tracing::Level::INFO,
-            LevelFilter::Debug => tracing::Level::DEBUG,
-            LevelFilter::Trace => tracing::Level::TRACE,
-        };
-        tracing_subscriber::fmt().with_max_level(lvl).init();
-    }
-
-    if !args.config.exists() {
-        // Throw up a warning if the config file does not exist.
-        //
-        // This is not fatal because users can specify all configuration settings via
-        // the environment, but the most likely scenario here is that a user accidentally
-        // omitted the config file for some reason (e.g. forgot to mount it into Docker).
-        warn!(
-            "configuration file {} does not exist",
-            args.config.display()
-        );
-    }
-
-    // Read and parse the user-provided configuration.
-    let config: AppConfig = Figment::new()
-        .admerge(figment::providers::Toml::file(args.config))
-        .admerge(figment::providers::Env::prefixed("BLUEPDS_"))
-        .extract()
-        .context("failed to load configuration")?;
-
-    if config.test {
-        warn!("BluePDS starting up in TEST mode.");
-        warn!("This means the application will not federate with the rest of the network.");
-        warn!(
-            "If you want to turn this off, either set `test` to false in the config or define `BLUEPDS_TEST = false`"
-        );
-    }
-
-    // Initialize metrics reporting.
-    metrics::setup(config.metrics.as_ref()).context("failed to set up metrics exporter")?;
-
-    // Create a reqwest client that will be used for all outbound requests.
-    let simple_client = reqwest::Client::builder()
-        .user_agent(APP_USER_AGENT)
-        .build()
-        .context("failed to build requester client")?;
-    let client = reqwest_middleware::ClientBuilder::new(simple_client.clone())
-        .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache {
-            mode: CacheMode::Default,
-            manager: MokaManager::default(),
-            options: HttpCacheOptions::default(),
-        }))
-        .build();
-
-    tokio::fs::create_dir_all(&config.key.parent().context("should have parent")?)
-        .await
-        .context("failed to create key directory")?;
-
-    // Check if crypto keys exist. If not, create new ones.
-    let (skey, rkey) = if let Ok(f) = std::fs::File::open(&config.key) {
-        let keys: KeyData = serde_ipld_dagcbor::from_reader(std::io::BufReader::new(f))
-            .context("failed to deserialize crypto keys")?;
-
-        let skey = Secp256k1Keypair::import(&keys.skey).context("failed to import signing key")?;
-        let rkey = Secp256k1Keypair::import(&keys.rkey).context("failed to import rotation key")?;
-
-        (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
-    } else {
-        info!("signing keys not found, generating new ones");
-
-        let skey = Secp256k1Keypair::create(&mut rand::thread_rng());
-        let rkey = Secp256k1Keypair::create(&mut rand::thread_rng());
-
-        let keys = KeyData {
-            skey: skey.export(),
-            rkey: rkey.export(),
-        };
-
-        let mut f = std::fs::File::create(&config.key).context("failed to create key file")?;
-        serde_ipld_dagcbor::to_writer(&mut f, &keys).context("failed to serialize crypto keys")?;
-
-        (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
-    };
-
-    tokio::fs::create_dir_all(&config.repo.path).await?;
-    tokio::fs::create_dir_all(&config.plc.path).await?;
-    tokio::fs::create_dir_all(&config.blob.path).await?;
-
-    // Create a database connection manager and pool for the main database.
-    let pool =
-        establish_pool(&config.db).context("failed to establish database connection pool")?;
-    // Create a dictionary of database connection pools for each actor.
-    let mut actor_pools = std::collections::HashMap::new();
-    // let mut actor_blob_pools = std::collections::HashMap::new();
-    // We'll determine actors by looking in the data/repo dir for .db files.
-    let mut actor_dbs = tokio::fs::read_dir(&config.repo.path)
-        .await
-        .context("failed to read repo directory")?;
-    while let Some(entry) = actor_dbs
-        .next_entry()
-        .await
-        .context("failed to read repo dir")?
-    {
-        let path = entry.path();
-        if path.extension().and_then(|s| s.to_str()) == Some("db") {
-            let did_path = path
-                .file_stem()
-                .and_then(|s| s.to_str())
-                .context("failed to get actor DID")?;
-            let did = Did::from_str(&format!("did:plc:{}", did_path))
-                .expect("should be able to parse actor DID");
-
-            // Create a new database connection manager and pool for the actor.
-            // The path for the SQLite connection needs to look like "sqlite://data/repo/<actor>.db"
-            let path_repo = format!("sqlite://{}", did_path);
-            let actor_repo_pool =
-                establish_pool(&path_repo).context("failed to create database connection pool")?;
-            // Create a new database connection manager and pool for the actor blobs.
-            // The path for the SQLite connection needs to look like "sqlite://data/blob/<actor>.db"
-            let path_blob = path_repo.replace("repo", "blob");
-            let actor_blob_pool =
-                establish_pool(&path_blob).context("failed to create database connection pool")?;
-            drop(actor_pools.insert(
-                did.to_string(),
-                ActorPools {
-                    repo: actor_repo_pool,
-                    blob: actor_blob_pool,
-                },
-            ));
-        }
-    }
-    // Apply pending migrations
-    // let conn = pool.get().await?;
-    // conn.run_pending_migrations(MIGRATIONS)
-    //     .expect("should be able to run migrations");
-
-    let hostname = config.host_name.clone();
-    let crawlers: Vec<String> = config
-        .firehose
-        .relays
-        .iter()
-        .map(|s| s.to_string())
-        .collect();
-    let sequencer = Arc::new(SharedSequencer {
-        sequencer: RwLock::new(Sequencer::new(
-            Crawlers::new(hostname, crawlers.clone()),
-            None,
-        )),
-    });
-    let account_manager = SharedAccountManager {
-        account_manager: RwLock::new(AccountManager::new(pool.clone())),
-    };
-
-    let addr = config
-        .listen_address
-        .unwrap_or(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000));
-
-    let app = Router::new()
-        .route("/", get(index))
-        .merge(oauth::routes())
-        .nest(
-            "/xrpc",
-            apis::routes()
-                .merge(actor_endpoints::routes())
-                .fallback(service_proxy),
-        )
-        // .layer(RateLimitLayer::new(30, Duration::from_secs(30)))
-        .layer(CorsLayer::permissive())
-        .layer(TraceLayer::new_for_http())
-        .with_state(AppState {
-            config: config.clone(),
-            db: pool.clone(),
-            db_actors: actor_pools.clone(),
-            client: client.clone(),
-            simple_client,
-            sequencer: sequencer.clone(),
-            account_manager: Arc::new(account_manager),
-            signing_key: skey,
-            rotation_key: rkey,
-        });
-
-    info!("listening on {addr}");
-    info!("connect to: http://127.0.0.1:{}", addr.port());
-
-    // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created).
-    // If so, create an invite code and share it via the console.
-    let conn = pool.get().await.context("failed to get db connection")?;
-
-    #[derive(QueryableByName)]
-    struct TotalCount {
-        #[diesel(sql_type = diesel::sql_types::Integer)]
-        total_count: i32,
-    }
-
-    let result = conn.interact(move |conn| {
-        diesel::sql_query(
-            "SELECT (SELECT COUNT(*) FROM account) + (SELECT COUNT(*) FROM invite_code) AS total_count",
-        )
-        .get_result::<TotalCount>(conn)
-    })
-    .await
-    .expect("should be able to query database")?;
-
-    let c = result.total_count;
-
-    #[expect(clippy::print_stdout)]
-    if c == 0 {
-        let uuid = Uuid::new_v4().to_string();
-
-        use crate::models::pds as models;
-        use crate::schema::pds::invite_code::dsl as InviteCode;
-        let uuid_clone = uuid.clone();
-        drop(
-            conn.interact(move |conn| {
-                diesel::insert_into(InviteCode::invite_code)
-                    .values(models::InviteCode {
-                        code: uuid_clone,
-                        available_uses: 1,
-                        disabled: 0,
-                        for_account: "None".to_owned(),
-                        created_by: "None".to_owned(),
-                        created_at: "None".to_owned(),
-                    })
-                    .execute(conn)
-                    .context("failed to create new invite code")
-            })
-            .await
-            .expect("should be able to create invite code"),
-        );
-
-        // N.B: This is a sensitive message, so we're bypassing `tracing` here and
-        // logging it directly to console.
-        println!("=====================================");
-        println!("            FIRST STARTUP            ");
-        println!("=====================================");
-        println!("Use this code to create an account:");
-        println!("{uuid}");
-        println!("=====================================");
-    }
-
-    let listener = TcpListener::bind(&addr)
-        .await
-        .context("failed to bind address")?;
-
-    // Serve the app, and request crawling from upstream relays.
-    let serve = tokio::spawn(async move {
-        axum::serve(listener, app.into_make_service())
-            .await
-            .context("failed to serve app")
-    });
-
-    // Now that the app is live, request a crawl from upstream relays.
-    let mut background_sequencer = sequencer.sequencer.write().await.clone();
-    drop(tokio::spawn(
-        async move { background_sequencer.start().await },
-    ));
-
-    serve
-        .await
-        .map_err(Into::into)
-        .and_then(|r| r)
-        .context("failed to serve app")
-}
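The deleted `AppState` (recreated in `src/serve.rs` later in this diff) derives axum's `FromRef`, which is what lets a handler extract just the slice of state it needs rather than the whole struct. A minimal sketch of the pattern, assuming the crate's `SigningKey` (with its `Deref` to the atrium keypair) is in scope:

```rust
use atrium_crypto::keypair::Did as _;
use axum::extract::{FromRef, State};

// Hypothetical reduced state just to illustrate the derive.
#[derive(Clone, FromRef)]
struct MiniState {
    signing_key: SigningKey,
}

// FromRef generates the conversion, so this handler can take
// State<SigningKey> even though the router was built with MiniState.
async fn signing_did(State(key): State<SigningKey>) -> String {
    // Deref forwards to Secp256k1Keypair, whose Did impl renders a did:key.
    key.did()
}
```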
+1 -3 src/main.rs
···
 //! BluePDS binary entry point.
 
 use anyhow::Context as _;
-use clap::Parser;
 
 #[tokio::main(flavor = "multi_thread")]
 async fn main() -> anyhow::Result<()> {
-    // Parse command line arguments and call into the library's run function
     bluepds::run().await.context("failed to run application")
-}
+}
-274 src/mmap.rs
···
-#![allow(clippy::arbitrary_source_item_ordering)]
-use std::io::{ErrorKind, Read as _, Seek as _, Write as _};
-
-#[cfg(unix)]
-use std::os::fd::AsRawFd as _;
-#[cfg(windows)]
-use std::os::windows::io::AsRawHandle;
-
-use memmap2::{MmapMut, MmapOptions};
-
-pub(crate) struct MappedFile {
-    /// The underlying file handle.
-    file: std::fs::File,
-    /// The length of the file.
-    len: u64,
-    /// The mapped memory region.
-    map: MmapMut,
-    /// Our current offset into the file.
-    off: u64,
-}
-
-impl MappedFile {
-    pub(crate) fn new(mut f: std::fs::File) -> std::io::Result<Self> {
-        let len = f.seek(std::io::SeekFrom::End(0))?;
-
-        #[cfg(windows)]
-        let raw = f.as_raw_handle();
-        #[cfg(unix)]
-        let raw = f.as_raw_fd();
-
-        #[expect(unsafe_code)]
-        Ok(Self {
-            // SAFETY:
-            // All file-backed memory map constructors are marked \
-            // unsafe because of the potential for Undefined Behavior (UB) \
-            // using the map if the underlying file is subsequently modified, in or out of process.
-            map: unsafe { MmapOptions::new().map_mut(raw)? },
-            file: f,
-            len,
-            off: 0,
-        })
-    }
-
-    /// Resize the memory-mapped file. This will reallocate the memory mapping.
-    #[expect(unsafe_code)]
-    fn resize(&mut self, len: u64) -> std::io::Result<()> {
-        // Resize the file.
-        self.file.set_len(len)?;
-
-        #[cfg(windows)]
-        let raw = self.file.as_raw_handle();
-        #[cfg(unix)]
-        let raw = self.file.as_raw_fd();
-
-        // SAFETY:
-        // All file-backed memory map constructors are marked \
-        // unsafe because of the potential for Undefined Behavior (UB) \
-        // using the map if the underlying file is subsequently modified, in or out of process.
-        self.map = unsafe { MmapOptions::new().map_mut(raw)? };
-        self.len = len;
-
-        Ok(())
-    }
-}
-
-impl std::io::Read for MappedFile {
-    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
-        if self.off == self.len {
-            // If we're at EOF, return an EOF error code. `Ok(0)` tends to trip up some implementations.
-            return Err(std::io::Error::new(ErrorKind::UnexpectedEof, "eof"));
-        }
-
-        // Calculate the number of bytes we're going to read.
-        let remaining_bytes = self.len.saturating_sub(self.off);
-        let buf_len = u64::try_from(buf.len()).unwrap_or(u64::MAX);
-        let len = usize::try_from(std::cmp::min(remaining_bytes, buf_len)).unwrap_or(usize::MAX);
-
-        let off = usize::try_from(self.off).map_err(|e| {
-            std::io::Error::new(
-                ErrorKind::InvalidInput,
-                format!("offset too large for this platform: {e}"),
-            )
-        })?;
-
-        if let (Some(dest), Some(src)) = (
-            buf.get_mut(..len),
-            self.map.get(off..off.saturating_add(len)),
-        ) {
-            dest.copy_from_slice(src);
-            self.off = self.off.saturating_add(u64::try_from(len).unwrap_or(0));
-            Ok(len)
-        } else {
-            Err(std::io::Error::new(
-                ErrorKind::InvalidInput,
-                "invalid buffer range",
-            ))
-        }
-    }
-}
-
-impl std::io::Write for MappedFile {
-    fn flush(&mut self) -> std::io::Result<()> {
-        // This is done by the system.
-        Ok(())
-    }
-    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
-        // Determine if we need to resize the file.
-        let buf_len = u64::try_from(buf.len()).map_err(|e| {
-            std::io::Error::new(
-                ErrorKind::InvalidInput,
-                format!("buffer length too large for this platform: {e}"),
-            )
-        })?;
-
-        if self.off.saturating_add(buf_len) >= self.len {
-            self.resize(self.off.saturating_add(buf_len))?;
-        }
-
-        let off = usize::try_from(self.off).map_err(|e| {
-            std::io::Error::new(
-                ErrorKind::InvalidInput,
-                format!("offset too large for this platform: {e}"),
-            )
-        })?;
-        let len = buf.len();
-
-        if let Some(dest) = self.map.get_mut(off..off.saturating_add(len)) {
-            dest.copy_from_slice(buf);
-            self.off = self.off.saturating_add(buf_len);
-            Ok(len)
-        } else {
-            Err(std::io::Error::new(
-                ErrorKind::InvalidInput,
-                "invalid buffer range",
-            ))
-        }
-    }
-}
-
-impl std::io::Seek for MappedFile {
-    fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> {
-        let off = match pos {
-            std::io::SeekFrom::Start(i) => i,
-            std::io::SeekFrom::End(i) => {
-                if i <= 0 {
-                    // If i is negative or zero, we're seeking backwards from the end
-                    // or exactly at the end
-                    self.len.saturating_sub(i.unsigned_abs())
-                } else {
-                    // If i is positive, we're seeking beyond the end, which is allowed
-                    // but requires extending the file
-                    self.len.saturating_add(i.unsigned_abs())
-                }
-            }
-            std::io::SeekFrom::Current(i) => {
-                if i >= 0 {
-                    self.off.saturating_add(i.unsigned_abs())
-                } else {
-                    self.off.saturating_sub(i.unsigned_abs())
-                }
-            }
-        };
-
-        // If the offset is beyond EOF, extend the file to the new size.
-        if off > self.len {
-            self.resize(off)?;
-        }
-
-        self.off = off;
-        Ok(off)
-    }
-}
-
-impl tokio::io::AsyncRead for MappedFile {
-    fn poll_read(
-        mut self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-        buf: &mut tokio::io::ReadBuf<'_>,
-    ) -> std::task::Poll<std::io::Result<()>> {
-        let wbuf = buf.initialize_unfilled();
-        let len = wbuf.len();
-
-        std::task::Poll::Ready(match self.read(wbuf) {
-            Ok(_) => {
-                buf.advance(len);
-                Ok(())
-            }
-            Err(e) => Err(e),
-        })
-    }
-}
-
-impl tokio::io::AsyncWrite for MappedFile {
-    fn poll_flush(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Result<(), std::io::Error>> {
-        std::task::Poll::Ready(Ok(()))
-    }
-
-    fn poll_shutdown(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<Result<(), std::io::Error>> {
-        std::task::Poll::Ready(Ok(()))
-    }
-
-    fn poll_write(
-        mut self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-        buf: &[u8],
-    ) -> std::task::Poll<Result<usize, std::io::Error>> {
-        std::task::Poll::Ready(self.write(buf))
-    }
-}
-
-impl tokio::io::AsyncSeek for MappedFile {
-    fn poll_complete(
-        self: std::pin::Pin<&mut Self>,
-        _cx: &mut std::task::Context<'_>,
-    ) -> std::task::Poll<std::io::Result<u64>> {
-        std::task::Poll::Ready(Ok(self.off))
-    }
-
-    fn start_seek(
-        mut self: std::pin::Pin<&mut Self>,
-        position: std::io::SeekFrom,
-    ) -> std::io::Result<()> {
-        self.seek(position).map(|_p| ())
-    }
-}
-
-#[cfg(test)]
-mod test {
-    use rand::Rng as _;
-    use std::io::Write as _;
-
-    use super::*;
-
-    #[test]
-    fn basic_rw() {
-        let tmp = std::env::temp_dir().join(
-            rand::thread_rng()
-                .sample_iter(rand::distributions::Alphanumeric)
-                .take(10)
-                .map(char::from)
-                .collect::<String>(),
-        );
-
-        let mut m = MappedFile::new(
-            std::fs::File::options()
-                .create(true)
-                .truncate(true)
-                .read(true)
-                .write(true)
-                .open(&tmp)
-                .expect("Failed to open temporary file"),
-        )
-        .expect("Failed to create MappedFile");
-
-        m.write_all(b"abcd123").expect("Failed to write data");
-        let _: u64 = m
-            .seek(std::io::SeekFrom::Start(0))
-            .expect("Failed to seek to start");
-
-        let mut buf = [0_u8; 7];
-        m.read_exact(&mut buf).expect("Failed to read data");
-
-        assert_eq!(&buf, b"abcd123");
-
-        drop(m);
-        std::fs::remove_file(tmp).expect("Failed to remove temporary file");
-    }
-}
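The removed `MappedFile` acted like a growable file handle: a write past the end, or a seek beyond EOF, resized the backing file and remapped it. A minimal sketch of that behavior (assuming the type above were still in scope), beyond what the `basic_rw` test exercises:

```rust
use std::io::{Seek as _, SeekFrom, Write as _};

fn grow_and_write(file: std::fs::File) -> std::io::Result<()> {
    let mut m = MappedFile::new(file)?;
    // Seeking past EOF extends the file (see MappedFile::seek above).
    m.seek(SeekFrom::Start(4096))?;
    // Writing at the new offset resizes and remaps again if needed.
    m.write_all(b"tail")?;
    Ok(())
}
```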
+4 -10 src/oauth.rs
···
 //! OAuth endpoints
 #![allow(unnameable_types, unused_qualifications)]
+use crate::config::AppConfig;
+use crate::error::Error;
 use crate::metrics::AUTH_FAILED;
-use crate::{AppConfig, AppState, Client, Error, Result, SigningKey};
+use crate::serve::{AppState, Client, Result, SigningKey};
 use anyhow::{Context as _, anyhow};
 use argon2::{Argon2, PasswordHash, PasswordVerifier as _};
 use atrium_crypto::keypair::Did as _;
···
     let response_type = response_type.to_owned();
     let code_challenge = code_challenge.to_owned();
     let code_challenge_method = code_challenge_method.to_owned();
-    let state = state.map(|s| s.to_owned());
-    let login_hint = login_hint.map(|s| s.to_owned());
-    let scope = scope.map(|s| s.to_owned());
-    let redirect_uri = redirect_uri.map(|s| s.to_owned());
-    let response_mode = response_mode.map(|s| s.to_owned());
-    let display = display.map(|s| s.to_owned());
-    let created_at = created_at;
-    let expires_at = expires_at;
     _ = db
         .get()
         .await
···
         .interact(move |conn| {
             AccountSchema::account
                 .filter(AccountSchema::email.eq(username_clone))
-                .first::<rsky_pds::models::Account>(conn)
+                .first::<crate::models::pds::Account>(conn)
                 .optional()
         })
         .await
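The last hunk swaps the row type from `rsky_pds::models::Account` to the crate-local model while keeping the deadpool-diesel pattern used throughout this PR: `interact` runs the blocking diesel query on the pool's dedicated thread. A minimal sketch of the same lookup (the `crate::schema::pds::account` path is an assumption inferred from the `AccountSchema` alias above):

```rust
use anyhow::anyhow;
use diesel::prelude::*;

async fn find_account_by_email(
    db: &deadpool_diesel::sqlite::Pool,
    email: String,
) -> anyhow::Result<Option<crate::models::pds::Account>> {
    use crate::schema::pds::account::dsl as AccountSchema;

    let conn = db.get().await?;
    conn.interact(move |conn| {
        AccountSchema::account
            .filter(AccountSchema::email.eq(email))
            .first::<crate::models::pds::Account>(conn)
            .optional()
    })
    .await
    .map_err(|e| anyhow!("interact failed: {e}"))? // join/panic error
    .map_err(Into::into) // diesel error
}
```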
+606 src/pipethrough.rs
···
+//! Based on https://github.com/blacksky-algorithms/rsky/blob/main/rsky-pds/src/pipethrough.rs
+//! blacksky-algorithms/rsky is licensed under the Apache License 2.0
+//!
+//! Modified for Axum instead of Rocket
+
+use anyhow::{Result, bail};
+use axum::extract::{FromRequestParts, State};
+use rsky_identity::IdResolver;
+use rsky_pds::apis::ApiError;
+use rsky_pds::auth_verifier::{AccessOutput, AccessStandard};
+use rsky_pds::config::{ServerConfig, ServiceConfig, env_to_cfg};
+use rsky_pds::pipethrough::{OverrideOpts, ProxyHeader, UrlAndAud};
+use rsky_pds::xrpc_server::types::{HandlerPipeThrough, InvalidRequestError, XRPCError};
+use rsky_pds::{APP_USER_AGENT, SharedIdResolver, context};
+// use lazy_static::lazy_static;
+use reqwest::header::{CONTENT_TYPE, HeaderValue};
+use reqwest::{Client, Method, RequestBuilder, Response};
+// use rocket::data::ToByteUnit;
+// use rocket::http::{Method, Status};
+// use rocket::request::{FromRequest, Outcome, Request};
+// use rocket::{Data, State};
+use axum::{
+    body::Bytes,
+    http::{self, HeaderMap},
+};
+use rsky_common::{GetServiceEndpointOpts, get_service_endpoint};
+use rsky_repo::types::Ids;
+use serde::de::DeserializeOwned;
+use serde_json::Value as JsonValue;
+use std::collections::{BTreeMap, HashSet};
+use std::str::FromStr;
+use std::sync::Arc;
+use std::time::Duration;
+use ubyte::ToByteUnit as _;
+use url::Url;
+
+use crate::serve::AppState;
+
+// pub struct OverrideOpts {
+//     pub aud: Option<String>,
+//     pub lxm: Option<String>,
+// }
+
+// pub struct UrlAndAud {
+//     pub url: Url,
+//     pub aud: String,
+//     pub lxm: String,
+// }
+
+// pub struct ProxyHeader {
+//     pub did: String,
+//     pub service_url: String,
+// }
+
+pub struct ProxyRequest {
+    pub headers: BTreeMap<String, String>,
+    pub query: Option<String>,
+    pub path: String,
+    pub method: Method,
+    pub id_resolver: Arc<tokio::sync::RwLock<rsky_identity::IdResolver>>,
+    pub cfg: ServerConfig,
+}
+impl FromRequestParts<AppState> for ProxyRequest {
+    // type Rejection = ApiError;
+    type Rejection = axum::response::Response;
+
+    async fn from_request_parts(
+        parts: &mut axum::http::request::Parts,
+        state: &AppState,
+    ) -> Result<Self, Self::Rejection> {
+        let headers = parts
+            .headers
+            .iter()
+            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
+            .collect::<BTreeMap<String, String>>();
+        let query = parts.uri.query().map(|s| s.to_string());
+        let path = parts.uri.path().to_string();
+        let method = parts.method.clone();
+        let id_resolver = state.id_resolver.clone();
+        // let cfg = state.cfg.clone();
+        let cfg = env_to_cfg(); // TODO: use state.cfg.clone();
+
+        Ok(Self {
+            headers,
+            query,
+            path,
+            method,
+            id_resolver,
+            cfg,
+        })
+    }
+}
+
+// #[rocket::async_trait]
+// impl<'r> FromRequest<'r> for HandlerPipeThrough {
+//     type Error = anyhow::Error;
+
+//     #[tracing::instrument(skip_all)]
+//     async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+//         match AccessStandard::from_request(req).await {
+//             Outcome::Success(output) => {
+//                 let AccessOutput { credentials, .. } = output.access;
+//                 let requester: Option<String> = match credentials {
+//                     None => None,
+//                     Some(credentials) => credentials.did,
+//                 };
+//                 let headers = req.headers().clone().into_iter().fold(
+//                     BTreeMap::new(),
+//                     |mut acc: BTreeMap<String, String>, cur| {
+//                         let _ = acc.insert(cur.name().to_string(), cur.value().to_string());
+//                         acc
+//                     },
+//                 );
+//                 let proxy_req = ProxyRequest {
+//                     headers,
+//                     query: match req.uri().query() {
+//                         None => None,
+//                         Some(query) => Some(query.to_string()),
+//                     },
+//                     path: req.uri().path().to_string(),
+//                     method: req.method(),
+//                     id_resolver: req.guard::<&State<SharedIdResolver>>().await.unwrap(),
+//                     cfg: req.guard::<&State<ServerConfig>>().await.unwrap(),
+//                 };
+//                 match pipethrough(
+//                     &proxy_req,
+//                     requester,
+//                     OverrideOpts {
+//                         aud: None,
+//                         lxm: None,
+//                     },
+//                 )
+//                 .await
+//                 {
+//                     Ok(res) => Outcome::Success(res),
+//                     Err(error) => match error.downcast_ref() {
+//                         Some(InvalidRequestError::XRPCError(xrpc)) => {
+//                             if let XRPCError::FailedResponse {
+//                                 status,
+//                                 error,
+//                                 message,
+//                                 headers,
+//                             } = xrpc
+//                             {
+//                                 tracing::error!(
+//                                     "@LOG: XRPC ERROR Status:{status}; Message: {message:?}; Error: {error:?}; Headers: {headers:?}"
+//                                 );
+//                             }
+//                             req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
+//                             Outcome::Error((Status::BadRequest, error))
+//                         }
+//                         _ => {
+//                             req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string())));
+//                             Outcome::Error((Status::BadRequest, error))
+//                         }
+//                     },
+//                 }
+//             }
+//             Outcome::Error(err) => {
+//                 req.local_cache(|| Some(ApiError::RuntimeError));
+//                 Outcome::Error((
+//                     Status::BadRequest,
+//                     anyhow::Error::new(InvalidRequestError::AuthError(err.1)),
+//                 ))
+//             }
+//             _ => panic!("Unexpected outcome during Pipethrough"),
+//         }
+//     }
+// }
+
+// #[rocket::async_trait]
+// impl<'r> FromRequest<'r> for ProxyRequest<'r> {
+//     type Error = anyhow::Error;
+
+//     async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
+//         let headers = req.headers().clone().into_iter().fold(
+//             BTreeMap::new(),
+//             |mut acc: BTreeMap<String, String>, cur| {
+//                 let _ = acc.insert(cur.name().to_string(), cur.value().to_string());
+//                 acc
+//             },
+//         );
+//         Outcome::Success(Self {
+//             headers,
+//             query: match req.uri().query() {
+//                 None => None,
+//                 Some(query) => Some(query.to_string()),
+//             },
+//             path: req.uri().path().to_string(),
+//             method: req.method(),
+//             id_resolver: req.guard::<&State<SharedIdResolver>>().await.unwrap(),
+//             cfg: req.guard::<&State<ServerConfig>>().await.unwrap(),
+//         })
+//     }
+// }
+
+pub async fn pipethrough(
+    req: &ProxyRequest,
+    requester: Option<String>,
+    override_opts: OverrideOpts,
+) -> Result<HandlerPipeThrough> {
+    let UrlAndAud {
+        url,
+        aud,
+        lxm: nsid,
+    } = format_url_and_aud(req, override_opts.aud).await?;
+    let lxm = override_opts.lxm.unwrap_or(nsid);
+    let headers = format_headers(req, aud, lxm, requester).await?;
+    let req_init = format_req_init(req, url, headers, None)?;
+    let res = make_request(req_init).await?;
+    parse_proxy_res(res).await
+}
+
+pub async fn pipethrough_procedure<T: serde::Serialize>(
+    req: &ProxyRequest,
+    requester: Option<String>,
+    body: Option<T>,
+) -> Result<HandlerPipeThrough> {
+    let UrlAndAud {
+        url,
+        aud,
+        lxm: nsid,
+    } = format_url_and_aud(req, None).await?;
+    let headers = format_headers(req, aud, nsid, requester).await?;
+    let encoded_body: Option<Vec<u8>> = match body {
+        None => None,
+        Some(body) => Some(serde_json::to_string(&body)?.into_bytes()),
+    };
+    let req_init = format_req_init(req, url, headers, encoded_body)?;
+    let res = make_request(req_init).await?;
+    parse_proxy_res(res).await
+}
+
+#[tracing::instrument(skip_all)]
+pub async fn pipethrough_procedure_post(
+    req: &ProxyRequest,
+    requester: Option<String>,
+    body: Option<Bytes>,
+) -> Result<HandlerPipeThrough, ApiError> {
+    let UrlAndAud {
+        url,
+        aud,
+        lxm: nsid,
+    } = format_url_and_aud(req, None).await?;
+    let headers = format_headers(req, aud, nsid, requester).await?;
+    let encoded_body: Option<JsonValue>;
+    match body {
+        None => encoded_body = None,
+        Some(body) => {
+            // let res = match body.open(50.megabytes()).into_string().await {
+            //     Ok(res1) => {
+            //         tracing::info!(res1.value);
+            //         res1.value
+            //     }
+            //     Err(error) => {
+            //         tracing::error!("{error}");
+            //         return Err(ApiError::RuntimeError);
+            //     }
+            // };
+            let res = String::from_utf8(body.to_vec()).expect("Invalid UTF-8");
+
+            match serde_json::from_str(res.as_str()) {
+                Ok(res) => {
+                    encoded_body = Some(res);
+                }
+                Err(error) => {
+                    tracing::error!("{error}");
+                    return Err(ApiError::RuntimeError);
+                }
+            }
+        }
+    };
+    let req_init = format_req_init_with_value(req, url, headers, encoded_body)?;
+    let res = make_request(req_init).await?;
+    Ok(parse_proxy_res(res).await?)
+}
+
+// Request setup/formatting
+// -------------------
+
+const REQ_HEADERS_TO_FORWARD: [&str; 4] = [
+    "accept-language",
+    "content-type",
+    "atproto-accept-labelers",
+    "x-bsky-topics",
+];
+
+#[tracing::instrument(skip_all)]
+pub async fn format_url_and_aud(
+    req: &ProxyRequest,
+    aud_override: Option<String>,
+) -> Result<UrlAndAud> {
+    let proxy_to = parse_proxy_header(req).await?;
+    let nsid = parse_req_nsid(req);
+    let default_proxy = default_service(req, &nsid).await;
+    let service_url = match proxy_to {
+        Some(ref proxy_to) => {
+            tracing::info!(
+                "@LOG: format_url_and_aud() proxy_to: {:?}",
+                proxy_to.service_url
+            );
+            Some(proxy_to.service_url.clone())
+        }
+        None => match default_proxy {
+            Some(ref default_proxy) => Some(default_proxy.url.clone()),
+            None => None,
+        },
+    };
+    let aud = match aud_override {
+        Some(_) => aud_override,
+        None => match proxy_to {
+            Some(proxy_to) => Some(proxy_to.did),
+            None => match default_proxy {
+                Some(default_proxy) => Some(default_proxy.did),
+                None => None,
+            },
+        },
+    };
+    match (service_url, aud) {
+        (Some(service_url), Some(aud)) => {
+            let mut url = Url::parse(format!("{0}{1}", service_url, req.path).as_str())?;
+            if let Some(ref params) = req.query {
+                url.set_query(Some(params.as_str()));
+            }
+            if !req.cfg.service.dev_mode && !is_safe_url(url.clone()) {
+                bail!(InvalidRequestError::InvalidServiceUrl(url.to_string()));
+            }
+            Ok(UrlAndAud {
+                url,
+                aud,
+                lxm: nsid,
+            })
+        }
+        _ => bail!(InvalidRequestError::NoServiceConfigured(req.path.clone())),
+    }
+}
+
+pub async fn format_headers(
+    req: &ProxyRequest,
+    aud: String,
+    lxm: String,
+    requester: Option<String>,
+) -> Result<HeaderMap> {
+    let mut headers: HeaderMap = match requester {
+        Some(requester) => context::service_auth_headers(&requester, &aud, &lxm).await?,
+        None => HeaderMap::new(),
+    };
+    // forward select headers to upstream services
+    for header in REQ_HEADERS_TO_FORWARD {
+        let val = req.headers.get(header);
+        if let Some(val) = val {
+            headers.insert(header, HeaderValue::from_str(val)?);
+        }
+    }
+    Ok(headers)
+}
+
+pub fn format_req_init(
+    req: &ProxyRequest,
+    url: Url,
+    headers: HeaderMap,
+    body: Option<Vec<u8>>,
+) -> Result<RequestBuilder> {
+    match req.method {
+        Method::GET => {
+            let client = Client::builder()
+                .user_agent(APP_USER_AGENT)
+                .http2_keep_alive_while_idle(true)
+                .http2_keep_alive_timeout(Duration::from_secs(5))
+                .default_headers(headers)
+                .build()?;
+            Ok(client.get(url))
+        }
+        Method::HEAD => {
+            let client = Client::builder()
+                .user_agent(APP_USER_AGENT)
+                .http2_keep_alive_while_idle(true)
+                .http2_keep_alive_timeout(Duration::from_secs(5))
+                .default_headers(headers)
+                .build()?;
+            Ok(client.head(url))
+        }
+        Method::POST => {
+            let client = Client::builder()
+                .user_agent(APP_USER_AGENT)
+                .http2_keep_alive_while_idle(true)
+                .http2_keep_alive_timeout(Duration::from_secs(5))
+                .default_headers(headers)
+                .build()?;
+            Ok(client.post(url).body(body.unwrap()))
+        }
+        _ => bail!(InvalidRequestError::MethodNotFound),
+    }
+}
+
+pub fn format_req_init_with_value(
+    req: &ProxyRequest,
+    url: Url,
+    headers: HeaderMap,
+    body: Option<JsonValue>,
+) -> Result<RequestBuilder> {
+    match req.method {
+        Method::GET => {
+            let client = Client::builder()
+                .user_agent(APP_USER_AGENT)
+                .http2_keep_alive_while_idle(true)
+                .http2_keep_alive_timeout(Duration::from_secs(5))
+                .default_headers(headers)
+                .build()?;
+            Ok(client.get(url))
+        }
+        Method::HEAD => {
+            let client = Client::builder()
+                .user_agent(APP_USER_AGENT)
+                .http2_keep_alive_while_idle(true)
+                .http2_keep_alive_timeout(Duration::from_secs(5))
+                .default_headers(headers)
+                .build()?;
+            Ok(client.head(url))
+        }
+        Method::POST => {
+            let client = Client::builder()
+                .user_agent(APP_USER_AGENT)
+                .http2_keep_alive_while_idle(true)
+                .http2_keep_alive_timeout(Duration::from_secs(5))
+                .default_headers(headers)
+                .build()?;
+            Ok(client.post(url).json(&body.unwrap()))
+        }
+        _ => bail!(InvalidRequestError::MethodNotFound),
+    }
+}
+
+pub async fn parse_proxy_header(req: &ProxyRequest) -> Result<Option<ProxyHeader>> {
+    let headers = &req.headers;
+    let proxy_to: Option<&String> = headers.get("atproto-proxy");
+    match proxy_to {
+        None => Ok(None),
+        Some(proxy_to) => {
+            let parts: Vec<&str> = proxy_to.split("#").collect::<Vec<&str>>();
+            match (parts.get(0), parts.get(1), parts.get(2)) {
+                (Some(did), Some(service_id), None) => {
+                    let did = did.to_string();
+                    let mut lock = req.id_resolver.write().await;
+                    match lock.did.resolve(did.clone(), None).await? {
+                        None => bail!(InvalidRequestError::CannotResolveProxyDid),
+                        Some(did_doc) => {
+                            match get_service_endpoint(
+                                did_doc,
+                                GetServiceEndpointOpts {
+                                    id: format!("#{service_id}"),
+                                    r#type: None,
+                                },
+                            ) {
+                                None => bail!(InvalidRequestError::CannotResolveServiceUrl),
+                                Some(service_url) => Ok(Some(ProxyHeader { did, service_url })),
+                            }
+                        }
+                    }
+                }
+                (_, None, _) => bail!(InvalidRequestError::NoServiceId),
+                _ => bail!("error parsing atproto-proxy header"),
+            }
+        }
+    }
+}
+
+pub fn parse_req_nsid(req: &ProxyRequest) -> String {
+    let nsid = req.path.as_str().replace("/xrpc/", "");
+    match nsid.ends_with("/") {
+        false => nsid,
+        true => nsid
+            .trim_end_matches(|c| c == nsid.chars().last().unwrap())
+            .to_string(),
+    }
+}
+
+// Sending request
+// -------------------
+#[tracing::instrument(skip_all)]
+pub async fn make_request(req_init: RequestBuilder) -> Result<Response> {
+    let res = req_init.send().await;
+    match res {
+        Err(e) => {
+            tracing::error!("@LOG WARN: pipethrough network error {}", e.to_string());
+            bail!(InvalidRequestError::XRPCError(XRPCError::UpstreamFailure))
+        }
+        Ok(res) => match res.error_for_status_ref() {
+            Ok(_) => Ok(res),
+            Err(_) => {
+                let status = res.status().to_string();
+                let headers = res.headers().clone();
+                let error_body = res.json::<JsonValue>().await?;
+                bail!(InvalidRequestError::XRPCError(XRPCError::FailedResponse {
+                    status,
+                    headers,
+                    error: match error_body["error"].as_str() {
+                        None => None,
+                        Some(error_body_error) => Some(error_body_error.to_string()),
+                    },
+                    message: match error_body["message"].as_str() {
+                        None => None,
+                        Some(error_body_message) => Some(error_body_message.to_string()),
+                    }
+                }))
+            }
+        },
+    }
+}
+
+// Response parsing/forwarding
+// -------------------
+
+const RES_HEADERS_TO_FORWARD: [&str; 4] = [
+    "content-type",
+    "content-language",
+    "atproto-repo-rev",
+    "atproto-content-labelers",
+];
+
+pub async fn parse_proxy_res(res: Response) -> Result<HandlerPipeThrough> {
+    let encoding = match res.headers().get(CONTENT_TYPE) {
+        Some(content_type) => content_type.to_str()?,
+        None => "application/json",
+    };
+    // Release borrow
+    let encoding = encoding.to_string();
+    let res_headers = RES_HEADERS_TO_FORWARD.into_iter().fold(
+        BTreeMap::new(),
+        |mut acc: BTreeMap<String, String>, cur| {
+            let _ = match res.headers().get(cur) {
+                Some(res_header_val) => acc.insert(
+                    cur.to_string(),
+                    res_header_val.clone().to_str().unwrap().to_string(),
+                ),
+                None => None,
+            };
+            acc
+        },
+    );
+    let buffer = read_array_buffer_res(res).await?;
+    Ok(HandlerPipeThrough {
+        encoding,
+        buffer,
+        headers: Some(res_headers),
+    })
+}
+
+// Utils
+// -------------------
+
+pub async fn default_service(req: &ProxyRequest, nsid: &str) -> Option<ServiceConfig> {
+    let cfg = req.cfg.clone();
+    match Ids::from_str(nsid) {
+        Ok(Ids::ToolsOzoneTeamAddMember) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneTeamDeleteMember) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneTeamUpdateMember) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneTeamListMembers) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneCommunicationCreateTemplate) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneCommunicationDeleteTemplate) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneCommunicationUpdateTemplate) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneCommunicationListTemplates) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneModerationEmitEvent) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneModerationGetEvent) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneModerationGetRecord) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneModerationGetRepo) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneModerationQueryEvents) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneModerationQueryStatuses) => cfg.mod_service,
+        Ok(Ids::ToolsOzoneModerationSearchRepos) => cfg.mod_service,
+        Ok(Ids::ComAtprotoModerationCreateReport) => cfg.report_service,
+        _ => cfg.bsky_app_view,
+    }
+}
+
+pub fn parse_res<T: DeserializeOwned>(_nsid: String, res: HandlerPipeThrough) -> Result<T> {
+    let buffer = res.buffer;
+    let record = serde_json::from_slice::<T>(buffer.as_slice())?;
+    Ok(record)
+}
+
+#[tracing::instrument(skip_all)]
+pub async fn read_array_buffer_res(res: Response) -> Result<Vec<u8>> {
+    match res.bytes().await {
+        Ok(bytes) => Ok(bytes.to_vec()),
+        Err(err) => {
+            tracing::error!("@LOG WARN: pipethrough network error {}", err.to_string());
+            bail!("UpstreamFailure")
+        }
+    }
+}
+
+pub fn is_safe_url(url: Url) -> bool {
+    if url.scheme() != "https" {
+        return false;
+    }
+    match url.host_str() {
+        None => false,
+        Some(hostname) if hostname == "localhost" => false,
+        Some(hostname) => {
+            if std::net::IpAddr::from_str(hostname).is_ok() {
+                return false;
+            }
+            true
+        }
+    }
+}
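Because `ProxyRequest` implements `FromRequestParts<AppState>`, it can be used directly as an Axum extractor. A minimal sketch (not the crate's actual `service_proxy`) of a fallback-style handler that forwards an unmatched XRPC call upstream; `requester` would normally come from auth verification, so it is left as `None` here:

```rust
use axum::{http::StatusCode, response::IntoResponse};

async fn xrpc_fallback(req: ProxyRequest) -> axum::response::Response {
    match pipethrough(&req, None, OverrideOpts { aud: None, lxm: None }).await {
        // HandlerPipeThrough carries the upstream bytes plus their content type.
        Ok(res) => (
            [(axum::http::header::CONTENT_TYPE, res.encoding)],
            res.buffer,
        )
            .into_response(),
        Err(e) => {
            tracing::error!("pipethrough failed: {e}");
            StatusCode::BAD_GATEWAY.into_response()
        }
    }
}
```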
-114
src/plc.rs
-114
src/plc.rs
···
```diff
-//! PLC operations.
-use std::collections::HashMap;
-
-use anyhow::{Context as _, bail};
-use base64::Engine as _;
-use serde::{Deserialize, Serialize};
-use tracing::debug;
-
-use crate::{Client, RotationKey};
-
-/// The URL of the public PLC directory.
-const PLC_DIRECTORY: &str = "https://plc.directory/";
-
-#[derive(Debug, Deserialize, Serialize, Clone)]
-#[serde(rename_all = "camelCase", tag = "type")]
-/// A PLC service.
-pub(crate) enum PlcService {
-    #[serde(rename = "AtprotoPersonalDataServer")]
-    /// A personal data server.
-    Pds {
-        /// The URL of the PDS.
-        endpoint: String,
-    },
-}
-
-#[expect(
-    clippy::arbitrary_source_item_ordering,
-    reason = "serialized data might be structured"
-)]
-#[derive(Debug, Deserialize, Serialize, Clone)]
-#[serde(rename_all = "camelCase")]
-pub(crate) struct PlcOperation {
-    #[serde(rename = "type")]
-    pub typ: String,
-    pub rotation_keys: Vec<String>,
-    pub verification_methods: HashMap<String, String>,
-    pub also_known_as: Vec<String>,
-    pub services: HashMap<String, PlcService>,
-    pub prev: Option<String>,
-}
-
-impl PlcOperation {
-    /// Sign an operation with the provided signature.
-    pub(crate) fn sign(self, sig: Vec<u8>) -> SignedPlcOperation {
-        SignedPlcOperation {
-            typ: self.typ,
-            rotation_keys: self.rotation_keys,
-            verification_methods: self.verification_methods,
-            also_known_as: self.also_known_as,
-            services: self.services,
-            prev: self.prev,
-            sig: base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sig),
-        }
-    }
-}
-
-#[expect(
-    clippy::arbitrary_source_item_ordering,
-    reason = "serialized data might be structured"
-)]
-#[derive(Debug, Deserialize, Serialize, Clone)]
-#[serde(rename_all = "camelCase")]
-/// A signed PLC operation.
-pub(crate) struct SignedPlcOperation {
-    #[serde(rename = "type")]
-    pub typ: String,
-    pub rotation_keys: Vec<String>,
-    pub verification_methods: HashMap<String, String>,
-    pub also_known_as: Vec<String>,
-    pub services: HashMap<String, PlcService>,
-    pub prev: Option<String>,
-    pub sig: String,
-}
-
-pub(crate) fn sign_op(rkey: &RotationKey, op: PlcOperation) -> anyhow::Result<SignedPlcOperation> {
-    let bytes = serde_ipld_dagcbor::to_vec(&op).context("failed to encode op")?;
-    let bytes = rkey.sign(&bytes).context("failed to sign op")?;
-
-    Ok(op.sign(bytes))
-}
-
-/// Submit a PLC operation to the public directory.
-pub(crate) async fn submit(
-    client: &Client,
-    did: &str,
-    op: &SignedPlcOperation,
-) -> anyhow::Result<()> {
-    debug!(
-        "submitting {} {}",
-        did,
-        serde_json::to_string(&op).context("should serialize")?
-    );
-
-    let res = client
-        .post(format!("{PLC_DIRECTORY}{did}"))
-        .json(&op)
-        .send()
-        .await
-        .context("failed to send directory request")?;
-
-    if res.status().is_success() {
-        Ok(())
-    } else {
-        let e = res
-            .json::<serde_json::Value>()
-            .await
-            .context("failed to read error response")?;
-
-        bail!(
-            "error from PLC directory: {}",
-            serde_json::to_string(&e).context("should serialize")?
-        );
-    }
-}
```
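For reference, the removed module's flow matches the PLC spec: an unsigned operation is DAG-CBOR-encoded, signed with the rotation key, and the signature embedded as unpadded base64url (`sign_op` above). A hypothetical genesis operation as this API would have built it (placeholder values, not real key material):

```rust
use std::collections::HashMap;

// Sketch of how the removed plc.rs API was driven; all values illustrative.
fn example_genesis_op(rkey: &RotationKey) -> anyhow::Result<SignedPlcOperation> {
    let op = PlcOperation {
        typ: "plc_operation".to_owned(),
        rotation_keys: vec!["did:key:zExampleRotationKey".to_owned()],
        verification_methods: HashMap::from([(
            "atproto".to_owned(),
            "did:key:zExampleSigningKey".to_owned(),
        )]),
        also_known_as: vec!["at://alice.example.com".to_owned()],
        services: HashMap::from([(
            "atproto_pds".to_owned(),
            PlcService::Pds {
                endpoint: "https://pds.example.com".to_owned(),
            },
        )]),
        prev: None, // genesis operations have no predecessor
    };
    // DAG-CBOR encode, sign with the rotation key, base64url the signature.
    sign_op(rkey, op)
}
```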
+429 src/serve.rs (new file)
```diff
+use super::account_manager::AccountManager;
+use super::config::AppConfig;
+use super::db::establish_pool;
+pub use super::error::Error;
+use super::service_proxy::service_proxy;
+use anyhow::Context as _;
+use atrium_api::types::string::Did;
+use atrium_crypto::keypair::{Export as _, Secp256k1Keypair};
+use axum::{Router, extract::FromRef, routing::get};
+use clap::Parser;
+use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter};
+use deadpool_diesel::sqlite::Pool;
+use diesel::prelude::*;
+use diesel_migrations::{EmbeddedMigrations, embed_migrations};
+use figment::{Figment, providers::Format as _};
+use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager};
+use rsky_common::env::env_list;
+use rsky_identity::IdResolver;
+use rsky_identity::types::{DidCache, IdentityResolverOpts};
+use rsky_pds::{crawlers::Crawlers, sequencer::Sequencer};
+use serde::{Deserialize, Serialize};
+use std::env;
+use std::{
+    net::{IpAddr, Ipv4Addr, SocketAddr},
+    path::PathBuf,
+    str::FromStr as _,
+    sync::Arc,
+};
+use tokio::{net::TcpListener, sync::RwLock};
+use tower_http::{cors::CorsLayer, trace::TraceLayer};
+use tracing::{info, warn};
+use uuid::Uuid;
+
+/// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`.
+pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),);
+
+/// Embedded migrations
+pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations");
+pub const MIGRATIONS_ACTOR: EmbeddedMigrations = embed_migrations!("./migrations_actor");
+
+/// The application-wide result type.
+pub type Result<T> = std::result::Result<T, Error>;
+/// The reqwest client type with middleware.
+pub type Client = reqwest_middleware::ClientWithMiddleware;
+
+#[expect(
+    clippy::arbitrary_source_item_ordering,
+    reason = "serialized data might be structured"
+)]
+#[derive(Serialize, Deserialize, Debug, Clone)]
+/// The key data structure.
+struct KeyData {
+    /// Primary signing key for all repo operations.
+    skey: Vec<u8>,
+    /// Primary signing (rotation) key for all PLC operations.
+    rkey: Vec<u8>,
+}
+
+// FIXME: We should use P256Keypair instead. SecP256K1 is primarily used for cryptocurrencies,
+// and the implementations of this algorithm are much more limited as compared to P256.
+//
+// Reference: https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/
+#[derive(Clone)]
+/// The signing key for PLC/DID operations.
+pub struct SigningKey(Arc<Secp256k1Keypair>);
+#[derive(Clone)]
+/// The rotation key for PLC operations.
+pub struct RotationKey(Arc<Secp256k1Keypair>);
+
+impl std::ops::Deref for SigningKey {
+    type Target = Secp256k1Keypair;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+impl SigningKey {
+    /// Import from a private key.
+    pub fn import(key: &[u8]) -> Result<Self> {
+        let key = Secp256k1Keypair::import(key).context("failed to import signing key")?;
+        Ok(Self(Arc::new(key)))
+    }
+}
+
+impl std::ops::Deref for RotationKey {
+    type Target = Secp256k1Keypair;
+
+    fn deref(&self) -> &Self::Target {
+        &self.0
+    }
+}
+
+#[derive(Parser, Debug, Clone)]
+/// Command line arguments.
+pub struct Args {
+    /// Path to the configuration file
+    #[arg(short, long, default_value = "default.toml")]
+    pub config: PathBuf,
+    /// The verbosity level.
+    #[command(flatten)]
+    pub verbosity: Verbosity<InfoLevel>,
+}
+
+/// The actor pools for the database connections.
+pub struct ActorStorage {
+    /// The database connection pool for the actor's repository.
+    pub repo: Pool,
+    /// The file storage path for the actor's blobs.
+    pub blob: PathBuf,
+}
+
+impl Clone for ActorStorage {
+    fn clone(&self) -> Self {
+        Self {
+            repo: self.repo.clone(),
+            blob: self.blob.clone(),
+        }
+    }
+}
+
+#[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")]
+#[derive(Clone, FromRef)]
+/// The application state, shared across all routes.
+pub struct AppState {
+    /// The application configuration.
+    pub(crate) config: AppConfig,
+    /// The main database connection pool. Used for common PDS data, like invite codes.
+    pub db: Pool,
+    /// Actor-specific database connection pools. Hashed by DID.
+    pub db_actors: std::collections::HashMap<String, ActorStorage>,
+
+    /// The HTTP client with middleware.
+    pub client: Client,
+    /// The simple HTTP client.
+    pub simple_client: reqwest::Client,
+    /// The firehose producer.
+    pub sequencer: Arc<RwLock<Sequencer>>,
+    /// The account manager.
+    pub account_manager: Arc<RwLock<AccountManager>>,
+    /// The ID resolver.
+    pub id_resolver: Arc<RwLock<IdResolver>>,
+
+    /// The signing key.
+    pub signing_key: SigningKey,
+    /// The rotation key.
+    pub rotation_key: RotationKey,
+}
+
+/// The main application entry point.
+#[expect(
+    clippy::cognitive_complexity,
+    clippy::too_many_lines,
+    unused_qualifications,
+    reason = "main function has high complexity"
+)]
+pub async fn run() -> anyhow::Result<()> {
+    let args = Args::parse();
+
+    // Set up trace logging to console and account for the user-provided verbosity flag.
+    if args.verbosity.log_level_filter() != LevelFilter::Off {
+        let lvl = match args.verbosity.log_level_filter() {
+            LevelFilter::Error => tracing::Level::ERROR,
+            LevelFilter::Warn => tracing::Level::WARN,
+            LevelFilter::Info | LevelFilter::Off => tracing::Level::INFO,
+            LevelFilter::Debug => tracing::Level::DEBUG,
+            LevelFilter::Trace => tracing::Level::TRACE,
+        };
+        tracing_subscriber::fmt().with_max_level(lvl).init();
+    }
+
+    if !args.config.exists() {
+        // Throw up a warning if the config file does not exist.
+        //
+        // This is not fatal because users can specify all configuration settings via
+        // the environment, but the most likely scenario here is that a user accidentally
+        // omitted the config file for some reason (e.g. forgot to mount it into Docker).
+        warn!(
+            "configuration file {} does not exist",
+            args.config.display()
+        );
+    }
+
+    // Read and parse the user-provided configuration.
+    let config: AppConfig = Figment::new()
+        .admerge(figment::providers::Toml::file(args.config))
+        .admerge(figment::providers::Env::prefixed("BLUEPDS_"))
+        .extract()
+        .context("failed to load configuration")?;
```
```diff
+
+    if config.test {
+        warn!("BluePDS starting up in TEST mode.");
+        warn!("This means the application will not federate with the rest of the network.");
+        warn!(
+            "If you want to turn this off, either set `test` to false in the config or define `BLUEPDS_TEST = false`"
+        );
+    }
+
+    // Initialize metrics reporting.
+    super::metrics::setup(config.metrics.as_ref()).context("failed to set up metrics exporter")?;
+
+    // Create a reqwest client that will be used for all outbound requests.
+    let simple_client = reqwest::Client::builder()
+        .user_agent(APP_USER_AGENT)
+        .build()
+        .context("failed to build requester client")?;
+    let client = reqwest_middleware::ClientBuilder::new(simple_client.clone())
+        .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache {
+            mode: CacheMode::Default,
+            manager: MokaManager::default(),
+            options: HttpCacheOptions::default(),
+        }))
+        .build();
+
+    tokio::fs::create_dir_all(&config.key.parent().context("should have parent")?)
+        .await
+        .context("failed to create key directory")?;
+
+    // Check if crypto keys exist. If not, create new ones.
+    let (skey, rkey) = if let Ok(f) = std::fs::File::open(&config.key) {
+        let keys: KeyData = serde_ipld_dagcbor::from_reader(std::io::BufReader::new(f))
+            .context("failed to deserialize crypto keys")?;
+
+        let skey = Secp256k1Keypair::import(&keys.skey).context("failed to import signing key")?;
+        let rkey = Secp256k1Keypair::import(&keys.rkey).context("failed to import rotation key")?;
+
+        (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
+    } else {
+        info!("signing keys not found, generating new ones");
+
+        let skey = Secp256k1Keypair::create(&mut rand::thread_rng());
+        let rkey = Secp256k1Keypair::create(&mut rand::thread_rng());
+
+        let keys = KeyData {
+            skey: skey.export(),
+            rkey: rkey.export(),
+        };
+
+        let mut f = std::fs::File::create(&config.key).context("failed to create key file")?;
+        serde_ipld_dagcbor::to_writer(&mut f, &keys).context("failed to serialize crypto keys")?;
+
+        (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
+    };
```
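The key file written above is a DAG-CBOR blob of `KeyData`. An in-memory round-trip sketch of that format (fresh throwaway keys, no file I/O):

```rust
use atrium_crypto::keypair::{Export as _, Secp256k1Keypair};

// Encode to DAG-CBOR and decode again, as run() does against `config.key`.
fn keydata_roundtrip() -> anyhow::Result<()> {
    let skey = Secp256k1Keypair::create(&mut rand::thread_rng());
    let rkey = Secp256k1Keypair::create(&mut rand::thread_rng());
    let keys = KeyData {
        skey: skey.export(),
        rkey: rkey.export(),
    };

    let mut buf = Vec::new();
    serde_ipld_dagcbor::to_writer(&mut buf, &keys)?;
    let decoded: KeyData = serde_ipld_dagcbor::from_reader(buf.as_slice())?;
    assert_eq!(decoded.skey, keys.skey);
    Ok(())
}
```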
```diff
+
+    tokio::fs::create_dir_all(&config.repo.path).await?;
+    tokio::fs::create_dir_all(&config.plc.path).await?;
+    tokio::fs::create_dir_all(&config.blob.path).await?;
+
+    // Create a database connection manager and pool for the main database.
+    let pool =
+        establish_pool(&config.db).context("failed to establish database connection pool")?;
+
+    // Create a dictionary of database connection pools for each actor.
+    let mut actor_pools = std::collections::HashMap::new();
+    // We'll determine actors by looking in the data/repo dir for .db files.
+    let mut actor_dbs = tokio::fs::read_dir(&config.repo.path)
+        .await
+        .context("failed to read repo directory")?;
+    while let Some(entry) = actor_dbs
+        .next_entry()
+        .await
+        .context("failed to read repo dir")?
+    {
+        let path = entry.path();
+        if path.extension().and_then(|s| s.to_str()) == Some("db") {
+            let actor_repo_pool = establish_pool(&format!("sqlite://{}", path.display()))
+                .context("failed to create database connection pool")?;
+
+            let did = Did::from_str(&format!(
+                "did:plc:{}",
+                path.file_stem()
+                    .and_then(|s| s.to_str())
+                    .context("failed to get actor DID")?
+            ))
+            .expect("should be able to parse actor DID")
+            .to_string();
+            let blob_path = config.blob.path.to_path_buf();
+            let actor_storage = ActorStorage {
+                repo: actor_repo_pool,
+                blob: blob_path.clone(),
+            };
+            drop(actor_pools.insert(did, actor_storage));
+        }
+    }
```
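Actor discovery relies purely on a filename convention: a repo database at `<repo.path>/<suffix>.db` is assumed to belong to `did:plc:<suffix>`. A small sketch of that mapping (paths are illustrative):

```rust
use std::path::Path;

fn did_for_repo_db(path: &Path) -> Option<String> {
    // Only `.db` files count as actor repositories.
    if path.extension()?.to_str()? != "db" {
        return None;
    }
    // The file stem is the did:plc suffix.
    Some(format!("did:plc:{}", path.file_stem()?.to_str()?))
}

fn main() {
    assert_eq!(
        did_for_repo_db(Path::new("data/repo/abc123example.db")).as_deref(),
        Some("did:plc:abc123example")
    );
    assert_eq!(did_for_repo_db(Path::new("data/repo/notes.txt")), None);
}
```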
```diff
+    // Apply pending migrations
+    // let conn = pool.get().await?;
+    // conn.run_pending_migrations(MIGRATIONS)
+    //     .expect("should be able to run migrations");
+
```
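The migration step is currently commented out. If it were enabled, it would presumably run the embedded migrations through deadpool-diesel's `interact`, which moves the blocking diesel work onto a dedicated thread. A sketch under that assumption:

```rust
use anyhow::Context as _;
use deadpool_diesel::sqlite::Pool;
use diesel_migrations::MigrationHarness as _;

// Hypothetical helper; mirrors the commented-out block above.
async fn run_migrations(pool: &Pool) -> anyhow::Result<()> {
    let conn = pool.get().await.context("failed to get db connection")?;
    conn.interact(|conn| {
        conn.run_pending_migrations(MIGRATIONS)
            .map(drop) // discard the list of applied migration versions
            .map_err(|e| anyhow::anyhow!("migration failed: {e}"))
    })
    .await
    .map_err(|e| anyhow::anyhow!("interact failed: {e}"))?
}
```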
```diff
+    let hostname = config.host_name.clone();
+    let crawlers: Vec<String> = config
+        .firehose
+        .relays
+        .iter()
+        .map(|s| s.to_string())
+        .collect();
+    let sequencer = Arc::new(RwLock::new(Sequencer::new(
+        Crawlers::new(hostname, crawlers.clone()),
+        None,
+    )));
+    let account_manager = Arc::new(RwLock::new(AccountManager::new(pool.clone())));
+    let plc_url = if cfg!(debug_assertions) {
+        "http://localhost:8000".to_owned() // dummy for debug
+    } else {
+        env::var("PDS_DID_PLC_URL").unwrap_or("https://plc.directory".to_owned()) // TODO: toml config
+    };
+    let id_resolver = Arc::new(RwLock::new(IdResolver::new(IdentityResolverOpts {
+        timeout: None,
+        plc_url: Some(plc_url),
+        did_cache: Some(DidCache::new(None, None)),
+        backup_nameservers: Some(env_list("PDS_HANDLE_BACKUP_NAMESERVERS")),
+    })));
+
+    let addr = config
+        .listen_address
+        .unwrap_or(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000));
+
+    let app = Router::new()
+        .route("/", get(super::index))
+        .merge(super::oauth::routes())
+        .nest(
+            "/xrpc",
+            super::apis::routes()
+                .merge(super::actor_endpoints::routes())
+                .fallback(service_proxy),
+        )
+        // .layer(RateLimitLayer::new(30, Duration::from_secs(30)))
+        .layer(CorsLayer::permissive())
+        .layer(TraceLayer::new_for_http())
+        .with_state(AppState {
+            config: config.clone(),
+            db: pool.clone(),
+            db_actors: actor_pools.clone(),
+            client: client.clone(),
+            simple_client,
+            sequencer: sequencer.clone(),
+            account_manager,
+            id_resolver,
+            signing_key: skey,
+            rotation_key: rkey,
+        });
+
+    info!("listening on {addr}");
+    info!("connect to: http://127.0.0.1:{}", addr.port());
+
+    // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created).
+    // If so, create an invite code and share it via the console.
+    let conn = pool.get().await.context("failed to get db connection")?;
+
+    #[derive(QueryableByName)]
+    struct TotalCount {
+        #[diesel(sql_type = diesel::sql_types::Integer)]
+        total_count: i32,
+    }
+
+    let result = conn.interact(move |conn| {
+        diesel::sql_query(
+            "SELECT (SELECT COUNT(*) FROM account) + (SELECT COUNT(*) FROM invite_code) AS total_count",
+        )
+        .get_result::<TotalCount>(conn)
+    })
+    .await
+    .expect("should be able to query database")?;
+
+    let c = result.total_count;
+
+    #[expect(clippy::print_stdout)]
+    if c == 0 {
+        let uuid = Uuid::new_v4().to_string();
+
+        use crate::models::pds as models;
+        use crate::schema::pds::invite_code::dsl as InviteCode;
+        let uuid_clone = uuid.clone();
+        drop(
+            conn.interact(move |conn| {
+                diesel::insert_into(InviteCode::invite_code)
+                    .values(models::InviteCode {
+                        code: uuid_clone,
+                        available_uses: 1,
+                        disabled: 0,
+                        for_account: "None".to_owned(),
+                        created_by: "None".to_owned(),
+                        created_at: "None".to_owned(),
+                    })
+                    .execute(conn)
+                    .context("failed to create new invite code")
+            })
+            .await
+            .expect("should be able to create invite code"),
+        );
+
+        // N.B: This is a sensitive message, so we're bypassing `tracing` here and
+        // logging it directly to console.
+        println!("=====================================");
+        println!(" FIRST STARTUP ");
+        println!("=====================================");
+        println!("Use this code to create an account:");
+        println!("{uuid}");
+        println!("=====================================");
+    }
```
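The printed invite code can then be redeemed with `com.atproto.server.createAccount`. A sketch of redeeming it, mirroring how the now-removed test helper in `src/tests.rs` below drove the endpoint (handle, email, password, and endpoint are placeholders):

```rust
use atrium_api::com::atproto::server::create_account::InputData;
use atrium_api::types::string::Handle;

// Post the account-creation input with the invite code filled in.
async fn redeem_invite(client: &reqwest::Client, code: String) -> anyhow::Result<()> {
    let res = client
        .post("http://127.0.0.1:8000/xrpc/com.atproto.server.createAccount")
        .json(&InputData {
            did: None,
            verification_code: None,
            verification_phone: None,
            email: Some("alice@example.com".to_owned()),
            handle: Handle::new("alice.example.com".to_owned())
                .expect("should be able to create handle"),
            password: Some("hunter2".to_owned()),
            invite_code: Some(code),
            recovery_key: None,
            plc_op: None,
        })
        .send()
        .await?;
    anyhow::ensure!(res.status().is_success(), "createAccount failed");
    Ok(())
}
```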
```diff
+
+    let listener = TcpListener::bind(&addr)
+        .await
+        .context("failed to bind address")?;
+
+    // Serve the app, and request crawling from upstream relays.
+    let serve = tokio::spawn(async move {
+        axum::serve(listener, app.into_make_service())
+            .await
+            .context("failed to serve app")
+    });
+
+    // Now that the app is live, request a crawl from upstream relays.
+    if cfg!(debug_assertions) {
+        info!("debug mode: not requesting crawl");
+    } else {
+        info!("requesting crawl from upstream relays");
+        let mut background_sequencer = sequencer.write().await.clone();
+        drop(tokio::spawn(
+            async move { background_sequencer.start().await },
+        ));
+    }
+
+    serve
+        .await
+        .map_err(Into::into)
+        .and_then(|r| r)
+        .context("failed to serve app")
+}
```
+6 -26 src/service_proxy.rs
```diff
···
 /// Reference: <https://atproto.com/specs/xrpc#service-proxying>
 use anyhow::{Context as _, anyhow};
 use atrium_api::types::string::Did;
-use atrium_crypto::keypair::{Export as _, Secp256k1Keypair};
 use axum::{
-    Router,
     body::Body,
-    extract::{FromRef, Request, State},
+    extract::{Request, State},
     http::{self, HeaderMap, Response, StatusCode, Uri},
-    response::IntoResponse,
-    routing::get,
 };
-use azure_core::credentials::TokenCredential;
-use clap::Parser;
-use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter};
-use deadpool_diesel::sqlite::Pool;
-use diesel::prelude::*;
-use diesel_migrations::{EmbeddedMigrations, embed_migrations};
-use figment::{Figment, providers::Format as _};
-use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager};
 use rand::Rng as _;
-use serde::{Deserialize, Serialize};
-use std::{
-    net::{IpAddr, Ipv4Addr, SocketAddr},
-    path::PathBuf,
-    str::FromStr as _,
-    sync::Arc,
-};
-use tokio::net::TcpListener;
-use tower_http::{cors::CorsLayer, trace::TraceLayer};
-use tracing::{info, warn};
-use uuid::Uuid;
+use std::str::FromStr as _;
 
-use super::{Client, Error, Result};
-use crate::{AuthenticatedUser, SigningKey};
+use super::{
+    auth::AuthenticatedUser,
+    serve::{Client, Error, Result, SigningKey},
+};
 
 pub(super) async fn service_proxy(
     uri: Uri,
```
-459 src/tests.rs (removed)
```diff
-//! Testing utilities for the PDS.
-#![expect(clippy::arbitrary_source_item_ordering)]
-use std::{
-    net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener},
-    path::PathBuf,
-    time::{Duration, Instant},
-};
-
-use anyhow::Result;
-use atrium_api::{
-    com::atproto::server,
-    types::string::{AtIdentifier, Did, Handle, Nsid, RecordKey},
-};
-use figment::{Figment, providers::Format as _};
-use futures::future::join_all;
-use serde::{Deserialize, Serialize};
-use tokio::sync::OnceCell;
-use uuid::Uuid;
-
-use crate::config::AppConfig;
-
-/// Global test state, created once for all tests.
-pub(crate) static TEST_STATE: OnceCell<TestState> = OnceCell::const_new();
-
-/// A temporary test directory that will be cleaned up when the struct is dropped.
-struct TempDir {
-    /// The path to the directory.
-    path: PathBuf,
-}
-
-impl TempDir {
-    /// Create a new temporary directory.
-    fn new() -> Result<Self> {
-        let path = std::env::temp_dir().join(format!("bluepds-test-{}", Uuid::new_v4()));
-        std::fs::create_dir_all(&path)?;
-        Ok(Self { path })
-    }
-
-    /// Get the path to the directory.
-    fn path(&self) -> &PathBuf {
-        &self.path
-    }
-}
-
-impl Drop for TempDir {
-    fn drop(&mut self) {
-        drop(std::fs::remove_dir_all(&self.path));
-    }
-}
-
-/// Test state for the application.
-pub(crate) struct TestState {
-    /// The address the test server is listening on.
-    address: SocketAddr,
-    /// The HTTP client.
-    client: reqwest::Client,
-    /// The application configuration.
-    config: AppConfig,
-    /// The temporary directory for test data.
-    #[expect(dead_code)]
-    temp_dir: TempDir,
-}
-
-impl TestState {
-    /// Get a base URL for the test server.
-    pub(crate) fn base_url(&self) -> String {
-        format!("http://{}", self.address)
-    }
-
-    /// Create a test account.
-    pub(crate) async fn create_test_account(&self) -> Result<TestAccount> {
-        // Create the account
-        let handle = "test.handle";
-        let response = self
-            .client
-            .post(format!(
-                "http://{}/xrpc/com.atproto.server.createAccount",
-                self.address
-            ))
-            .json(&server::create_account::InputData {
-                did: None,
-                verification_code: None,
-                verification_phone: None,
-                email: Some(format!("{}@example.com", &handle)),
-                handle: Handle::new(handle.to_owned()).expect("should be able to create handle"),
-                password: Some("password123".to_owned()),
-                invite_code: None,
-                recovery_key: None,
-                plc_op: None,
-            })
-            .send()
-            .await?;
-
-        let account: server::create_account::Output = response.json().await?;
-
-        Ok(TestAccount {
-            handle: handle.to_owned(),
-            did: account.did.to_string(),
-            access_token: account.access_jwt.clone(),
-            refresh_token: account.refresh_jwt.clone(),
-        })
-    }
-
-    /// Create a new test state.
-    #[expect(clippy::unused_async)]
-    async fn new() -> Result<Self> {
-        // Configure the test app
-        #[derive(Serialize, Deserialize)]
-        struct TestConfigInput {
-            db: Option<String>,
-            host_name: Option<String>,
-            key: Option<PathBuf>,
-            listen_address: Option<SocketAddr>,
-            test: Option<bool>,
-        }
-        // Create a temporary directory for test data
-        let temp_dir = TempDir::new()?;
-
-        // Find a free port
-        let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?;
-        let address = listener.local_addr()?;
-        drop(listener);
-
-        let test_config = TestConfigInput {
-            db: Some(format!("sqlite://{}/test.db", temp_dir.path().display())),
-            host_name: Some(format!("localhost:{}", address.port())),
-            key: Some(temp_dir.path().join("test.key")),
-            listen_address: Some(address),
-            test: Some(true),
-        };
-
-        let config: AppConfig = Figment::new()
-            .admerge(figment::providers::Toml::file("default.toml"))
-            .admerge(figment::providers::Env::prefixed("BLUEPDS_"))
-            .merge(figment::providers::Serialized::defaults(test_config))
-            .merge(
-                figment::providers::Toml::string(
-                    r#"
-[firehose]
-relays = []
-
-[repo]
-path = "repo"
-
-[plc]
-path = "plc"
-
-[blob]
-path = "blob"
-limit = 10485760 # 10 MB
-"#,
-                )
-                .nested(),
-            )
-            .extract()?;
-
-        // Create directories
-        std::fs::create_dir_all(temp_dir.path().join("repo"))?;
-        std::fs::create_dir_all(temp_dir.path().join("plc"))?;
-        std::fs::create_dir_all(temp_dir.path().join("blob"))?;
-
-        // Create client
-        let client = reqwest::Client::builder()
-            .timeout(Duration::from_secs(30))
-            .build()?;
-
-        Ok(Self {
-            address,
-            client,
-            config,
-            temp_dir,
-        })
-    }
-
-    /// Start the application in a background task.
-    async fn start_app(&self) -> Result<()> {
-        // // Get a reference to the config that can be moved into the task
-        // let config = self.config.clone();
-        // let address = self.address;
-
-        // // Start the application in a background task
-        // let _handle = tokio::spawn(async move {
-        //     // Set up the application
-        //     use crate::*;
-
-        //     // Initialize metrics (noop in test mode)
-        //     drop(metrics::setup(None));
-
-        //     // Create client
-        //     let simple_client = reqwest::Client::builder()
-        //         .user_agent(APP_USER_AGENT)
-        //         .build()
-        //         .context("failed to build requester client")?;
-        //     let client = reqwest_middleware::ClientBuilder::new(simple_client.clone())
-        //         .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache {
-        //             mode: CacheMode::Default,
-        //             manager: MokaManager::default(),
-        //             options: HttpCacheOptions::default(),
-        //         }))
-        //         .build();
-
-        //     // Create a test keypair
-        //     std::fs::create_dir_all(config.key.parent().context("should have parent")?)?;
-        //     let (skey, rkey) = {
-        //         let skey = Secp256k1Keypair::create(&mut rand::thread_rng());
-        //         let rkey = Secp256k1Keypair::create(&mut rand::thread_rng());
-
-        //         let keys = KeyData {
-        //             skey: skey.export(),
-        //             rkey: rkey.export(),
-        //         };
-
-        //         let mut f =
-        //             std::fs::File::create(&config.key).context("failed to create key file")?;
-        //         serde_ipld_dagcbor::to_writer(&mut f, &keys)
-        //             .context("failed to serialize crypto keys")?;
-
-        //         (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey)))
-        //     };
-
-        //     // Set up database
-        //     let opts = SqliteConnectOptions::from_str(&config.db)
-        //         .context("failed to parse database options")?
-        //         .create_if_missing(true);
-        //     let db = SqliteDbConn::connect_with(opts).await?;
-
-        //     sqlx::migrate!()
-        //         .run(&db)
-        //         .await
-        //         .context("failed to apply migrations")?;
-
-        //     // Create firehose
-        //     let (_fh, fhp) = firehose::spawn(client.clone(), config.clone());
-
-        //     // Create the application state
-        //     let app_state = AppState {
-        //         cred: azure_identity::DefaultAzureCredential::new()?,
-        //         config: config.clone(),
-        //         db: db.clone(),
-        //         client: client.clone(),
-        //         simple_client,
-        //         firehose: fhp,
-        //         signing_key: skey,
-        //         rotation_key: rkey,
-        //     };
-
-        //     // Create the router
-        //     let app = Router::new()
-        //         .route("/", get(index))
-        //         .merge(oauth::routes())
-        //         .nest(
-        //             "/xrpc",
-        //             endpoints::routes()
-        //                 .merge(actor_endpoints::routes())
-        //                 .fallback(service_proxy),
-        //         )
-        //         .layer(CorsLayer::permissive())
-        //         .layer(TraceLayer::new_for_http())
-        //         .with_state(app_state);
-
-        //     // Listen for connections
-        //     let listener = TcpListener::bind(&address)
-        //         .await
-        //         .context("failed to bind address")?;
-
-        //     axum::serve(listener, app.into_make_service())
-        //         .await
-        //         .context("failed to serve app")
-        // });
-
-        // // Give the server a moment to start
-        // tokio::time::sleep(Duration::from_millis(500)).await;
-
-        Ok(())
-    }
-}
-
-/// A test account that can be used for testing.
-pub(crate) struct TestAccount {
-    /// The access token for the account.
-    pub(crate) access_token: String,
-    /// The account DID.
-    pub(crate) did: String,
-    /// The account handle.
-    pub(crate) handle: String,
-    /// The refresh token for the account.
-    #[expect(dead_code)]
-    pub(crate) refresh_token: String,
-}
-
-/// Initialize the test state.
-pub(crate) async fn init_test_state() -> Result<&'static TestState> {
-    async fn init_test_state() -> std::result::Result<TestState, anyhow::Error> {
-        let state = TestState::new().await?;
-        state.start_app().await?;
-        Ok(state)
-    }
-    TEST_STATE.get_or_try_init(init_test_state).await
-}
-
-/// Create a record benchmark that creates records and measures the time it takes.
-#[expect(
-    clippy::arithmetic_side_effects,
-    clippy::integer_division,
-    clippy::integer_division_remainder_used,
-    clippy::use_debug,
-    clippy::print_stdout
-)]
-pub(crate) async fn create_record_benchmark(count: usize, concurrent: usize) -> Result<Duration> {
-    // Initialize the test state
-    let state = init_test_state().await?;
-
-    // Create a test account
-    let account = state.create_test_account().await?;
-
-    // Create the client with authorization
-    let client = reqwest::Client::builder()
-        .timeout(Duration::from_secs(30))
-        .build()?;
-
-    let start = Instant::now();
-
-    // Split the work into batches
-    let mut handles = Vec::new();
-    for batch_idx in 0..concurrent {
-        let batch_size = count / concurrent;
-        let client = client.clone();
-        let base_url = state.base_url();
-        let account_did = account.did.clone();
-        let account_handle = account.handle.clone();
-        let access_token = account.access_token.clone();
-
-        let handle = tokio::spawn(async move {
-            let mut results = Vec::new();
-
-            for i in 0..batch_size {
-                let request_start = Instant::now();
-                let record_idx = batch_idx * batch_size + i;
-
-                let result = client
-                    .post(format!("{base_url}/xrpc/com.atproto.repo.createRecord"))
-                    .header("Authorization", format!("Bearer {access_token}"))
-                    .json(&atrium_api::com::atproto::repo::create_record::InputData {
-                        repo: AtIdentifier::Did(Did::new(account_did.clone()).expect("valid DID")),
-                        collection: Nsid::new("app.bsky.feed.post".to_owned()).expect("valid NSID"),
-                        rkey: Some(
-                            RecordKey::new(format!("test-{record_idx}")).expect("valid record key"),
-                        ),
-                        validate: None,
-                        record: serde_json::from_str(
-                            &serde_json::json!({
-                                "$type": "app.bsky.feed.post",
-                                "text": format!("Test post {record_idx} from {account_handle}"),
-                                "createdAt": chrono::Utc::now().to_rfc3339(),
-                            })
-                            .to_string(),
-                        )
-                        .expect("valid JSON record"),
-                        swap_commit: None,
-                    })
-                    .send()
-                    .await;
-
-                // Fetch the record we just created
-                let get_response = client
-                    .get(format!(
-                        "{base_url}/xrpc/com.atproto.sync.getRecord?did={account_did}&collection=app.bsky.feed.post&rkey={record_idx}"
-                    ))
-                    .header("Authorization", format!("Bearer {access_token}"))
-                    .send()
-                    .await;
-                if get_response.is_err() {
-                    println!("Failed to fetch record {record_idx}: {get_response:?}");
-                    results.push(get_response);
-                    continue;
-                }
-
-                let request_duration = request_start.elapsed();
-                if record_idx % 10 == 0 {
-                    println!("Created record {record_idx} in {request_duration:?}");
-                }
-                results.push(result);
-            }
-
-            results
-        });
-
-        handles.push(handle);
-    }
-
-    // Wait for all batches to complete
-    let results = join_all(handles).await;
-
-    // Check for errors
-    for batch_result in results {
-        let batch_responses = batch_result?;
-        for response_result in batch_responses {
-            match response_result {
-                Ok(response) => {
-                    if !response.status().is_success() {
-                        return Err(anyhow::anyhow!(
-                            "Failed to create record: {}",
-                            response.status()
-                        ));
-                    }
-                }
-                Err(err) => {
-                    return Err(anyhow::anyhow!("Failed to create record: {}", err));
-                }
-            }
-        }
-    }
-
-    let duration = start.elapsed();
-    Ok(duration)
-}
-
-#[cfg(test)]
-#[expect(clippy::module_inception, clippy::use_debug, clippy::print_stdout)]
-mod tests {
-    use super::*;
-    use anyhow::anyhow;
-
-    #[tokio::test]
-    async fn test_create_account() -> Result<()> {
-        return Ok(());
-        #[expect(unreachable_code, reason = "Disabled")]
-        let state = init_test_state().await?;
-        let account = state.create_test_account().await?;
-
-        println!("Created test account: {}", account.handle);
-        if account.handle.is_empty() {
-            return Err(anyhow::anyhow!("Account handle is empty"));
-        }
-        if account.did.is_empty() {
-            return Err(anyhow::anyhow!("Account DID is empty"));
-        }
-        if account.access_token.is_empty() {
-            return Err(anyhow::anyhow!("Account access token is empty"));
-        }
-
-        Ok(())
-    }
-
-    #[tokio::test]
-    async fn test_create_record_benchmark() -> Result<()> {
-        return Ok(());
-        #[expect(unreachable_code, reason = "Disabled")]
-        let duration = create_record_benchmark(100, 1).await?;
-
-        println!("Created 100 records in {duration:?}");
-
-        if duration.as_secs() >= 10 {
-            return Err(anyhow!("Benchmark took too long"));
-        }
-
-        Ok(())
-    }
-}
```