Alternative ATProto PDS implementation


+2
Cargo.lock
```diff
···
 "reqwest 0.12.15",
 "reqwest-middleware",
 "rsky-common",
+"rsky-identity",
 "rsky-lexicon",
 "rsky-pds",
 "rsky-repo",
···
 "tower-http",
 "tracing",
 "tracing-subscriber",
+"ubyte",
 "url",
 "urlencoding",
 "uuid 1.16.0",
```
+3 -1
Cargo.toml
```diff
···
 # unstable-features = "allow"
 # # Temporary Allows
 dead_code = "allow"
-unused_imports = "allow"
+# unused_imports = "allow"
 
 [lints.clippy]
 # Groups
···
 rsky-pds = { git = "https://github.com/blacksky-algorithms/rsky.git" }
 rsky-common = { git = "https://github.com/blacksky-algorithms/rsky.git" }
 rsky-lexicon = { git = "https://github.com/blacksky-algorithms/rsky.git" }
+rsky-identity = { git = "https://github.com/blacksky-algorithms/rsky.git" }
 
 # async in streams
 # async-stream = "0.3"
···
     "sqlite",
     "tracing",
 ] }
+ubyte = "0.10.4"
```
+31 -118
README.md
````diff
···
 \/_/
 ```
 
-This is an implementation of an ATProto PDS, built with [Axum](https://github.com/tokio-rs/axum) and [Atrium](https://github.com/sugyan/atrium).
-This PDS implementation uses a SQLite database to store private account information and file storage to store canonical user data.
+This is an implementation of an ATProto PDS, built with [Axum](https://github.com/tokio-rs/axum), [rsky](https://github.com/blacksky-algorithms/rsky/), and [Atrium](https://github.com/sugyan/atrium).
+This PDS implementation uses a SQLite database with the [diesel.rs](https://diesel.rs/) ORM to store canonical user data, and file-system storage for user blobs.
 
 Heavily inspired by David Buchanan's [millipds](https://github.com/DavidBuchanan314/millipds).
-This implementation forked from the [azure-rust-app](https://github.com/DrChat/azure-rust-app) starter template and the upstream [DrChat/bluepds](https://github.com/DrChat/bluepds).
-See TODO below for this fork's changes from upstream.
+This implementation is forked from [DrChat/bluepds](https://github.com/DrChat/bluepds) and now makes heavy use of the [rsky-repo](https://github.com/blacksky-algorithms/rsky/tree/main/rsky-repo) repository implementation.
+The `actor_store` and `account_manager` modules have been reimplemented from [rsky-pds](https://github.com/blacksky-algorithms/rsky/tree/main/rsky-pds) to use a SQLite backend and file storage; those modules are in turn adapted from the SQLite-backed TypeScript code in the [original Bluesky implementation](https://github.com/bluesky-social/atproto).
+
 
 If you want to see this fork in action, there is a live account hosted by this PDS at [@teq.shatteredsky.net](https://bsky.app/profile/teq.shatteredsky.net)!
 
 > [!WARNING]
-> This PDS is undergoing heavy development. Do _NOT_ use this to host your primary account or any important data!
+> This PDS is undergoing heavy development, and this branch is not yet an operable release. Do _NOT_ use this to host your primary account or any important data!
 
 ## Quick Start
 ```
···
 - Size: 47 GB
 - VPUs/GB: 10
 
-This is about half of the 3,000 OCPU hours and 18,000 GB hours available per month for free on the VM.Standard.A1.Flex shape. This is _without_ optimizing for costs. The PDS can likely be made much cheaper.
-
-## Code map
-```
-* migrations/ - SQLite database migrations
-* src/
-  * endpoints/ - ATProto API endpoints
-  * auth.rs - Authentication primitives
-  * config.rs - Application configuration
-  * did.rs - Decentralized Identifier helpers
-  * error.rs - Axum error helpers
-  * firehose.rs - ATProto firehose producer
-  * main.rs - Main entrypoint
-  * metrics.rs - Definitions for telemetry instruments
-  * oauth.rs - OAuth routes
-  * plc.rs - Functionality to access the Public Ledger of Credentials
-  * storage.rs - Helpers to access user repository storage
-```
+This is about half of the 3,000 OCPU hours and 18,000 GB hours available per month for free on the VM.Standard.A1.Flex shape. This is _without_ optimizing for costs. The PDS can likely be made to run on far fewer resources.
 
 ## To-do
-### Teq's fork
-- [ ] OAuth
-  - [X] `/.well-known/oauth-protected-resource` - Authorization Server Metadata
-  - [X] `/.well-known/oauth-authorization-server`
-  - [X] `/par` - Pushed Authorization Request
-  - [X] `/client-metadata.json` - Client metadata discovery
-  - [X] `/oauth/authorize`
-  - [X] `/oauth/authorize/sign-in`
-  - [X] `/oauth/token`
-  - [ ] Authorization flow - Backend client
-  - [X] Authorization flow - Serverless browser app
-  - [ ] DPoP-Nonce
-  - [ ] Verify JWT signature with JWK
-- [ ] Email verification
-- [ ] 2FA
-- [ ] Admin endpoints
-- [ ] App passwords
-- [X] `listRecords` fixes
-  - [X] Fix collection prefixing (terminate with `/`)
-  - [X] Fix cursor handling (return `cid` instead of `key`)
-- [X] Session management (JWT)
-  - [X] Match token fields to reference implementation
-  - [X] RefreshSession from Bluesky Client
-  - [X] Respond with JSON error message `ExpiredToken`
-- [X] Cursor handling
-  - [X] Implement time-based unix microsecond sequences
-  - [X] Startup with present cursor
-- [X] Respond `RecordNotFound`, required for:
-  - [X] app.bsky.feed.postgate
-  - [X] app.bsky.feed.threadgate
-  - [ ] app.bsky... (profile creation?)
-- [X] Linting
-  - [X] Rustfmt
-  - [X] warnings
-  - [X] deprecated-safe
-  - [X] future-incompatible
-  - [X] keyword-idents
-  - [X] let-underscore
-  - [X] nonstandard-style
-  - [X] refining-impl-trait
-  - [X] rust-2018-idioms
-  - [X] rust-2018/2021/2024-compatibility
-  - [X] ungrouped
-  - [X] Clippy
-    - [X] nursery
-    - [X] correctness
-    - [X] suspicious
-    - [X] complexity
-    - [X] perf
-    - [X] style
-    - [X] pedantic
-    - [X] cargo
-    - [X] ungrouped
-
-### High-level features
-- [ ] Storage backend abstractions
-  - [ ] Azure blob storage backend
-  - [ ] Backblaze b2(?)
-- [ ] Telemetry
-  - [X] [Metrics](https://github.com/metrics-rs/metrics) (counters/gauges/etc)
-  - [X] Exporters for common backends (Prometheus/etc)
-
 ### APIs
-- [X] [Service proxying](https://atproto.com/specs/xrpc#service-proxying)
-- [X] UG /xrpc/_health (undocumented, but impl by reference PDS)
+- [ ] [Service proxying](https://atproto.com/specs/xrpc#service-proxying)
+- [ ] UG /xrpc/_health (undocumented, but impl by reference PDS)
 <!-- - [ ] xx /xrpc/app.bsky.notification.registerPush
 - app.bsky.actor
-  - [X] AG /xrpc/app.bsky.actor.getPreferences
+  - [ ] AG /xrpc/app.bsky.actor.getPreferences
   - [ ] xx /xrpc/app.bsky.actor.getProfile
   - [ ] xx /xrpc/app.bsky.actor.getProfiles
-  - [X] AP /xrpc/app.bsky.actor.putPreferences
+  - [ ] AP /xrpc/app.bsky.actor.putPreferences
 - app.bsky.feed
   - [ ] xx /xrpc/app.bsky.feed.getActorLikes
   - [ ] xx /xrpc/app.bsky.feed.getAuthorFeed
···
 - com.atproto.identity
   - [ ] xx /xrpc/com.atproto.identity.getRecommendedDidCredentials
   - [ ] AP /xrpc/com.atproto.identity.requestPlcOperationSignature
-  - [X] UG /xrpc/com.atproto.identity.resolveHandle
+  - [ ] UG /xrpc/com.atproto.identity.resolveHandle
   - [ ] AP /xrpc/com.atproto.identity.signPlcOperation
   - [ ] xx /xrpc/com.atproto.identity.submitPlcOperation
-  - [X] AP /xrpc/com.atproto.identity.updateHandle
+  - [ ] AP /xrpc/com.atproto.identity.updateHandle
 <!-- - com.atproto.moderation
   - [ ] xx /xrpc/com.atproto.moderation.createReport -->
 - com.atproto.repo
···
   - [X] AP /xrpc/com.atproto.repo.deleteRecord
   - [X] UG /xrpc/com.atproto.repo.describeRepo
   - [X] UG /xrpc/com.atproto.repo.getRecord
-  - [ ] xx /xrpc/com.atproto.repo.importRepo
-  - [ ] xx /xrpc/com.atproto.repo.listMissingBlobs
+  - [X] xx /xrpc/com.atproto.repo.importRepo
+  - [X] xx /xrpc/com.atproto.repo.listMissingBlobs
   - [X] UG /xrpc/com.atproto.repo.listRecords
   - [X] AP /xrpc/com.atproto.repo.putRecord
   - [X] AP /xrpc/com.atproto.repo.uploadBlob
···
   - [ ] xx /xrpc/com.atproto.server.activateAccount
   - [ ] xx /xrpc/com.atproto.server.checkAccountStatus
   - [ ] xx /xrpc/com.atproto.server.confirmEmail
-  - [X] UP /xrpc/com.atproto.server.createAccount
+  - [ ] UP /xrpc/com.atproto.server.createAccount
   - [ ] xx /xrpc/com.atproto.server.createAppPassword
-  - [X] AP /xrpc/com.atproto.server.createInviteCode
+  - [ ] AP /xrpc/com.atproto.server.createInviteCode
   - [ ] xx /xrpc/com.atproto.server.createInviteCodes
-  - [X] UP /xrpc/com.atproto.server.createSession
+  - [ ] UP /xrpc/com.atproto.server.createSession
   - [ ] xx /xrpc/com.atproto.server.deactivateAccount
   - [ ] xx /xrpc/com.atproto.server.deleteAccount
   - [ ] xx /xrpc/com.atproto.server.deleteSession
-  - [X] UG /xrpc/com.atproto.server.describeServer
+  - [ ] UG /xrpc/com.atproto.server.describeServer
   - [ ] xx /xrpc/com.atproto.server.getAccountInviteCodes
-  - [X] AG /xrpc/com.atproto.server.getServiceAuth
-  - [X] AG /xrpc/com.atproto.server.getSession
+  - [ ] AG /xrpc/com.atproto.server.getServiceAuth
+  - [ ] AG /xrpc/com.atproto.server.getSession
   - [ ] xx /xrpc/com.atproto.server.listAppPasswords
   - [ ] xx /xrpc/com.atproto.server.refreshSession
   - [ ] xx /xrpc/com.atproto.server.requestAccountDelete
···
   - [ ] xx /xrpc/com.atproto.server.revokeAppPassword
   - [ ] xx /xrpc/com.atproto.server.updateEmail
 - com.atproto.sync
-  - [X] UG /xrpc/com.atproto.sync.getBlob
-  - [X] UG /xrpc/com.atproto.sync.getBlocks
-  - [X] UG /xrpc/com.atproto.sync.getLatestCommit
-  - [X] UG /xrpc/com.atproto.sync.getRecord
-  - [X] UG /xrpc/com.atproto.sync.getRepo
-  - [X] UG /xrpc/com.atproto.sync.getRepoStatus
-  - [X] UG /xrpc/com.atproto.sync.listBlobs
-  - [X] UG /xrpc/com.atproto.sync.listRepos
-  - [X] UG /xrpc/com.atproto.sync.subscribeRepos
+  - [ ] UG /xrpc/com.atproto.sync.getBlob
+  - [ ] UG /xrpc/com.atproto.sync.getBlocks
+  - [ ] UG /xrpc/com.atproto.sync.getLatestCommit
+  - [ ] UG /xrpc/com.atproto.sync.getRecord
+  - [ ] UG /xrpc/com.atproto.sync.getRepo
+  - [ ] UG /xrpc/com.atproto.sync.getRepoStatus
+  - [ ] UG /xrpc/com.atproto.sync.listBlobs
+  - [ ] UG /xrpc/com.atproto.sync.listRepos
+  - [ ] UG /xrpc/com.atproto.sync.subscribeRepos
 
-## Quick Deployment (Azure CLI)
-```
-az group create --name "webapp" --location southcentralus
-az deployment group create --resource-group "webapp" --template-file .\deployment.bicep --parameters webAppName=testapp
-
-az acr login --name <insert name of ACR resource here>
-docker build -t <ACR>.azurecr.io/testapp:latest .
-docker push <ACR>.azurecr.io/testapp:latest
-```
-## Quick Deployment (NixOS)
+## Deployment (NixOS)
 ```nix
 {
   inputs = {
````
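The per-actor layout the new README text describes (one SQLite repo database plus a blob directory per DID) is the convention the `account_manager` and `actor_store` changes below build on. A minimal sketch of that convention, with the `data/repo` and `data/blob` roots taken from this diff; the helper itself is hypothetical:

```rust
use std::path::PathBuf;

/// Locations of an actor's canonical data, following the layout used in this
/// diff: `data/repo/<plc-suffix>.db` for the SQLite repo database and
/// `data/blob/<plc-suffix>/` for file-system blob storage.
fn actor_paths(did: &str) -> Option<(String, PathBuf)> {
    // Only did:plc DIDs are handled here, matching `create_account` below.
    let suffix = did.strip_prefix("did:plc:")?;
    let repo_url = format!("sqlite://data/repo/{suffix}.db"); // diesel connection URL
    let blob_dir = PathBuf::from("data/blob").join(suffix);
    Some((repo_url, blob_dir))
}
```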
-12
migrations/2025-05-15-182818_init_diff/down.sql
```diff
-- This file should undo anything in `up.sql`
-DROP TABLE IF EXISTS `oauth_refresh_tokens`;
 DROP TABLE IF EXISTS `repo_seq`;
-DROP TABLE IF EXISTS `blob`;
-DROP TABLE IF EXISTS `oauth_used_jtis`;
 DROP TABLE IF EXISTS `app_password`;
-DROP TABLE IF EXISTS `repo_block`;
 DROP TABLE IF EXISTS `device_account`;
-DROP TABLE IF EXISTS `backlink`;
 DROP TABLE IF EXISTS `actor`;
 DROP TABLE IF EXISTS `device`;
 DROP TABLE IF EXISTS `did_doc`;
 DROP TABLE IF EXISTS `email_token`;
 DROP TABLE IF EXISTS `invite_code`;
-DROP TABLE IF EXISTS `oauth_par_requests`;
-DROP TABLE IF EXISTS `record`;
-DROP TABLE IF EXISTS `repo_root`;
 DROP TABLE IF EXISTS `used_refresh_token`;
 DROP TABLE IF EXISTS `invite_code_use`;
-DROP TABLE IF EXISTS `oauth_authorization_codes`;
 DROP TABLE IF EXISTS `authorization_request`;
 DROP TABLE IF EXISTS `token`;
 DROP TABLE IF EXISTS `refresh_token`;
-DROP TABLE IF EXISTS `account_pref`;
-DROP TABLE IF EXISTS `record_blob`;
 DROP TABLE IF EXISTS `account`;
```
-108
migrations/2025-05-15-182818_init_diff/up.sql
```diff
-CREATE TABLE `oauth_refresh_tokens`(
-    `token` VARCHAR NOT NULL PRIMARY KEY,
-    `client_id` VARCHAR NOT NULL,
-    `subject` VARCHAR NOT NULL,
-    `dpop_thumbprint` VARCHAR NOT NULL,
-    `scope` VARCHAR,
-    `created_at` INT8 NOT NULL,
-    `expires_at` INT8 NOT NULL,
-    `revoked` BOOL NOT NULL
-);
-
 CREATE TABLE `repo_seq`(
     `seq` INT8 NOT NULL PRIMARY KEY,
     `did` VARCHAR NOT NULL,
···
     `sequencedat` VARCHAR NOT NULL
 );
 
-CREATE TABLE `blob`(
-    `cid` VARCHAR NOT NULL,
-    `did` VARCHAR NOT NULL,
-    `mimetype` VARCHAR NOT NULL,
-    `size` INT4 NOT NULL,
-    `tempkey` VARCHAR,
-    `width` INT4,
-    `height` INT4,
-    `createdat` VARCHAR NOT NULL,
-    `takedownref` VARCHAR,
-    PRIMARY KEY(`cid`, `did`)
-);
-
-CREATE TABLE `oauth_used_jtis`(
-    `jti` VARCHAR NOT NULL PRIMARY KEY,
-    `issuer` VARCHAR NOT NULL,
-    `created_at` INT8 NOT NULL,
-    `expires_at` INT8 NOT NULL
-);
-
 CREATE TABLE `app_password`(
     `did` VARCHAR NOT NULL,
     `name` VARCHAR NOT NULL,
···
     PRIMARY KEY(`did`, `name`)
 );
 
-CREATE TABLE `repo_block`(
-    `cid` VARCHAR NOT NULL,
-    `did` VARCHAR NOT NULL,
-    `reporev` VARCHAR NOT NULL,
-    `size` INT4 NOT NULL,
-    `content` BYTEA NOT NULL,
-    PRIMARY KEY(`cid`, `did`)
-);
-
 CREATE TABLE `device_account`(
     `did` VARCHAR NOT NULL,
     `deviceid` VARCHAR NOT NULL,
···
     `remember` BOOL NOT NULL,
     `authorizedclients` VARCHAR NOT NULL,
     PRIMARY KEY(`deviceId`, `did`)
-);
-
-CREATE TABLE `backlink`(
-    `uri` VARCHAR NOT NULL,
-    `path` VARCHAR NOT NULL,
-    `linkto` VARCHAR NOT NULL,
-    PRIMARY KEY(`uri`, `path`)
 );
 
 CREATE TABLE `actor`(
···
     `createdat` VARCHAR NOT NULL
 );
 
-CREATE TABLE `oauth_par_requests`(
-    `request_uri` VARCHAR NOT NULL PRIMARY KEY,
-    `client_id` VARCHAR NOT NULL,
-    `response_type` VARCHAR NOT NULL,
-    `code_challenge` VARCHAR NOT NULL,
-    `code_challenge_method` VARCHAR NOT NULL,
-    `state` VARCHAR,
-    `login_hint` VARCHAR,
-    `scope` VARCHAR,
-    `redirect_uri` VARCHAR,
-    `response_mode` VARCHAR,
-    `display` VARCHAR,
-    `created_at` INT8 NOT NULL,
-    `expires_at` INT8 NOT NULL
-);
-
-CREATE TABLE `record`(
-    `uri` VARCHAR NOT NULL PRIMARY KEY,
-    `cid` VARCHAR NOT NULL,
-    `did` VARCHAR NOT NULL,
-    `collection` VARCHAR NOT NULL,
-    `rkey` VARCHAR NOT NULL,
-    `reporev` VARCHAR,
-    `indexedat` VARCHAR NOT NULL,
-    `takedownref` VARCHAR
-);
-
-CREATE TABLE `repo_root`(
-    `did` VARCHAR NOT NULL PRIMARY KEY,
-    `cid` VARCHAR NOT NULL,
-    `rev` VARCHAR NOT NULL,
-    `indexedat` VARCHAR NOT NULL
-);
-
 CREATE TABLE `used_refresh_token`(
     `refreshtoken` VARCHAR NOT NULL PRIMARY KEY,
     `tokenid` VARCHAR NOT NULL
···
     PRIMARY KEY(`code`, `usedBy`)
 );
 
-CREATE TABLE `oauth_authorization_codes`(
-    `code` VARCHAR NOT NULL PRIMARY KEY,
-    `client_id` VARCHAR NOT NULL,
-    `subject` VARCHAR NOT NULL,
-    `code_challenge` VARCHAR NOT NULL,
-    `code_challenge_method` VARCHAR NOT NULL,
-    `redirect_uri` VARCHAR NOT NULL,
-    `scope` VARCHAR,
-    `created_at` INT8 NOT NULL,
-    `expires_at` INT8 NOT NULL,
-    `used` BOOL NOT NULL
-);
-
 CREATE TABLE `authorization_request`(
     `id` VARCHAR NOT NULL PRIMARY KEY,
     `did` VARCHAR,
···
     `expiresat` VARCHAR NOT NULL,
     `nextid` VARCHAR,
     `apppasswordname` VARCHAR
-);
-
-CREATE TABLE `account_pref`(
-    `id` INT4 NOT NULL PRIMARY KEY,
-    `did` VARCHAR NOT NULL,
-    `name` VARCHAR NOT NULL,
-    `valuejson` TEXT
-);
-
-CREATE TABLE `record_blob`(
-    `blobcid` VARCHAR NOT NULL,
-    `recorduri` VARCHAR NOT NULL,
-    `did` VARCHAR NOT NULL,
-    PRIMARY KEY(`blobCid`, `recordUri`)
 );
 
 CREATE TABLE `account`(
```
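The reworked migrations can be embedded into the binary and applied at startup, the usual diesel pattern for a self-contained server. A sketch assuming the `diesel_migrations` crate; the repository's actual startup wiring is not part of this diff:

```rust
use diesel::sqlite::SqliteConnection;
use diesel_migrations::{EmbeddedMigrations, MigrationHarness, embed_migrations};

// Compile the `migrations/` directory into the binary at build time.
pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("migrations");

/// Apply any migrations not yet recorded in the database.
fn run_migrations(conn: &mut SqliteConnection) -> anyhow::Result<()> {
    conn.run_pending_migrations(MIGRATIONS)
        .map_err(|e| anyhow::anyhow!("migration failed: {e}"))?;
    Ok(())
}
```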
+4
migrations/2025-05-17-094600_oauth_temp/down.sql
```sql
DROP TABLE IF EXISTS `oauth_refresh_tokens`;
DROP TABLE IF EXISTS `oauth_used_jtis`;
DROP TABLE IF EXISTS `oauth_par_requests`;
DROP TABLE IF EXISTS `oauth_authorization_codes`;
```
+46
migrations/2025-05-17-094600_oauth_temp/up.sql
```sql
CREATE TABLE `oauth_refresh_tokens`(
    `token` VARCHAR NOT NULL PRIMARY KEY,
    `client_id` VARCHAR NOT NULL,
    `subject` VARCHAR NOT NULL,
    `dpop_thumbprint` VARCHAR NOT NULL,
    `scope` VARCHAR,
    `created_at` INT8 NOT NULL,
    `expires_at` INT8 NOT NULL,
    `revoked` BOOL NOT NULL
);

CREATE TABLE `oauth_used_jtis`(
    `jti` VARCHAR NOT NULL PRIMARY KEY,
    `issuer` VARCHAR NOT NULL,
    `created_at` INT8 NOT NULL,
    `expires_at` INT8 NOT NULL
);

CREATE TABLE `oauth_par_requests`(
    `request_uri` VARCHAR NOT NULL PRIMARY KEY,
    `client_id` VARCHAR NOT NULL,
    `response_type` VARCHAR NOT NULL,
    `code_challenge` VARCHAR NOT NULL,
    `code_challenge_method` VARCHAR NOT NULL,
    `state` VARCHAR,
    `login_hint` VARCHAR,
    `scope` VARCHAR,
    `redirect_uri` VARCHAR,
    `response_mode` VARCHAR,
    `display` VARCHAR,
    `created_at` INT8 NOT NULL,
    `expires_at` INT8 NOT NULL
);

CREATE TABLE `oauth_authorization_codes`(
    `code` VARCHAR NOT NULL PRIMARY KEY,
    `client_id` VARCHAR NOT NULL,
    `subject` VARCHAR NOT NULL,
    `code_challenge` VARCHAR NOT NULL,
    `code_challenge_method` VARCHAR NOT NULL,
    `redirect_uri` VARCHAR NOT NULL,
    `scope` VARCHAR,
    `created_at` INT8 NOT NULL,
    `expires_at` INT8 NOT NULL,
    `used` BOOL NOT NULL
);
```
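For illustration, a sketch of writing one of these rows with diesel. The `table!` mapping below is hand-written here, since the crate's generated schema module is not shown in this diff:

```rust
use diesel::prelude::*;
use diesel::sqlite::SqliteConnection;

diesel::table! {
    oauth_par_requests (request_uri) {
        request_uri -> Text,
        client_id -> Text,
        response_type -> Text,
        code_challenge -> Text,
        code_challenge_method -> Text,
        state -> Nullable<Text>,
        login_hint -> Nullable<Text>,
        scope -> Nullable<Text>,
        redirect_uri -> Nullable<Text>,
        response_mode -> Nullable<Text>,
        display -> Nullable<Text>,
        created_at -> BigInt,
        expires_at -> BigInt,
    }
}

/// Store a pushed authorization request (sketch; challenge omitted).
fn store_par(conn: &mut SqliteConnection, uri: &str, client: &str, now: i64) -> QueryResult<usize> {
    diesel::insert_into(oauth_par_requests::table)
        .values((
            oauth_par_requests::request_uri.eq(uri),
            oauth_par_requests::client_id.eq(client),
            oauth_par_requests::response_type.eq("code"),
            oauth_par_requests::code_challenge.eq(""),
            oauth_par_requests::code_challenge_method.eq("S256"),
            oauth_par_requests::created_at.eq(now),
            // PAR request URIs are short-lived by design.
            oauth_par_requests::expires_at.eq(now + 60),
        ))
        .execute(conn)
}
```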
+17 -6
src/account_manager/helpers/account.rs
```diff
···
 use thiserror::Error;
 
 use diesel::dsl::{LeftJoinOn, exists, not};
-use diesel::helper_types::{Eq, IntoBoxed};
+use diesel::helper_types::Eq;
 
 #[derive(Error, Debug)]
 pub enum AccountHelperError {
···
         })
         .await
         .expect("Failed to delete actor")?;
-    let did = did.to_owned();
+    let did_clone = did.to_owned();
     _ = db
         .get()
         .await?
         .interact(move |conn| {
             _ = delete(EmailTokenSchema::email_token)
-                .filter(EmailTokenSchema::did.eq(&did))
+                .filter(EmailTokenSchema::did.eq(&did_clone))
                 .execute(conn)?;
             _ = delete(RefreshTokenSchema::refresh_token)
-                .filter(RefreshTokenSchema::did.eq(&did))
+                .filter(RefreshTokenSchema::did.eq(&did_clone))
                 .execute(conn)?;
             _ = delete(AccountSchema::account)
-                .filter(AccountSchema::did.eq(&did))
+                .filter(AccountSchema::did.eq(&did_clone))
                 .execute(conn)?;
             delete(ActorSchema::actor)
-                .filter(ActorSchema::did.eq(&did))
+                .filter(ActorSchema::did.eq(&did_clone))
                 .execute(conn)
         })
         .await
         .expect("Failed to delete account")?;
+
+    let data_repo_file = format!("data/repo/{}.db", did.to_owned());
+    let data_blob_path = format!("data/blob/{}", did);
+    let data_blob_path = std::path::Path::new(&data_blob_path);
+    let data_repo_file = std::path::Path::new(&data_repo_file);
+    if data_repo_file.exists() {
+        std::fs::remove_file(data_repo_file)?;
+    };
+    if data_blob_path.exists() {
+        std::fs::remove_dir_all(data_blob_path)?;
+    };
     Ok(())
 }
```
+18 -15
src/account_manager/mod.rs
```diff
···
 //! blacksky-algorithms/rsky is licensed under the Apache License 2.0
 //!
 //! Modified for SQLite backend
-use crate::ActorPools;
 use crate::account_manager::helpers::account::{
     AccountStatus, ActorAccount, AvailabilityFlags, GetAccountAdminStatusOutput,
 };
···
 use crate::account_manager::helpers::invite::CodeDetail;
 use crate::account_manager::helpers::password::UpdateUserPasswordOpts;
 use crate::models::pds::EmailTokenPurpose;
+use crate::serve::ActorStorage;
 use anyhow::Result;
-use axum::extract::FromRef;
 use chrono::DateTime;
 use chrono::offset::Utc as UtcOffset;
 use cidv10::Cid;
-use deadpool_diesel::sqlite::Pool;
 use diesel::*;
 use futures::try_join;
 use helpers::{account, auth, email_token, invite, password, repo};
···
     pub async fn create_account(
         &self,
         opts: CreateAccountOpts,
-        actor_pools: &mut std::collections::HashMap<String, ActorPools>,
+        actor_pools: &mut std::collections::HashMap<String, ActorStorage>,
     ) -> Result<(String, String)> {
         let CreateAccountOpts {
             did,
···
         let did_path = did
             .strip_prefix("did:plc:")
             .ok_or_else(|| anyhow::anyhow!("Invalid DID"))?;
-        let path_repo = format!("sqlite://{}", did_path);
+        let repo_path = format!("sqlite://data/repo/{}.db", did_path);
         let actor_repo_pool =
-            crate::establish_pool(path_repo.as_str()).expect("Failed to establish pool");
-        let path_blob = path_repo.replace("repo", "blob");
-        let actor_blob_pool = crate::establish_pool(&path_blob).expect("Failed to establish pool");
-        let actor_pool = ActorPools {
+            crate::db::establish_pool(repo_path.as_str()).expect("Failed to establish pool");
+        let blob_path = std::path::Path::new("data/blob").to_path_buf();
+        let actor_pool = ActorStorage {
             repo: actor_repo_pool,
-            blob: actor_blob_pool,
+            blob: blob_path.clone(),
         };
-        actor_pools
-            .insert(did.clone(), actor_pool)
-            .expect("Failed to insert actor pools");
+        let blob_path = blob_path.join(did_path);
+        tokio::fs::create_dir_all(&blob_path)
+            .await
+            .map_err(|_| anyhow::anyhow!("Failed to create blob path"))?;
+        // `HashMap::insert` returns the previous value (`None` for a new DID),
+        // so it must not be `expect`ed; discard the old value instead.
+        drop(actor_pools.insert(did.clone(), actor_pool));
         let db = actor_pools
             .get(&did)
             .ok_or_else(|| anyhow::anyhow!("Actor not found"))?
···
         did: String,
         cid: Cid,
         rev: String,
-        actor_pools: &std::collections::HashMap<String, ActorPools>,
+        actor_pools: &std::collections::HashMap<String, ActorStorage>,
     ) -> Result<()> {
         let db = actor_pools
             .get(&did)
···
     pub async fn delete_account(
         &self,
         did: &str,
-        actor_pools: &std::collections::HashMap<String, ActorPools>,
+        actor_pools: &std::collections::HashMap<String, ActorStorage>,
     ) -> Result<()> {
         let db = actor_pools
             .get(did)
```
+20 -8
src/actor_endpoints.rs
```diff
···
 /// We shouldn't have to know about any bsky endpoints to store private user data.
 /// This will _very likely_ be changed in the future.
 use atrium_api::app::bsky::actor;
-use axum::{Json, routing::post};
+use axum::{
+    Json, Router,
+    extract::State,
+    routing::{get, post},
+};
 use constcat::concat;
 use diesel::prelude::*;
 
-use crate::actor_store::ActorStore;
+use crate::auth::AuthenticatedUser;
 
-use super::*;
+use super::serve::*;
 
 async fn put_preferences(
     user: AuthenticatedUser,
-    State(actor_pools): State<std::collections::HashMap<String, ActorPools>>,
+    State(actor_pools): State<std::collections::HashMap<String, ActorStorage>>,
     Json(input): Json<actor::put_preferences::Input>,
 ) -> Result<()> {
     let did = user.did();
-    let json_string =
-        serde_json::to_string(&input.preferences).context("failed to serialize preferences")?;
+    // let json_string =
+    //     serde_json::to_string(&input.preferences).context("failed to serialize preferences")?;
 
     // let conn = &mut actor_pools
     //     .get(&did)
···
     //     .context("failed to update user preferences")
     // });
     todo!("Use actor_store's preferences writer instead");
+    // let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await;
+    // let values = actor::defs::Preferences {
+    //     private_prefs: Some(json_string),
+    //     ..Default::default()
+    // };
+    // let namespace = actor::defs::PreferencesNamespace::Private;
+    // let scope = actor::defs::PreferencesScope::User;
+    // actor_store.pref.put_preferences(values, namespace, scope);
+
     Ok(())
 }
 
 async fn get_preferences(
     user: AuthenticatedUser,
-    State(actor_pools): State<std::collections::HashMap<String, ActorPools>>,
+    State(actor_pools): State<std::collections::HashMap<String, ActorStorage>>,
 ) -> Result<Json<actor::get_preferences::Output>> {
     let did = user.did();
     // let conn = &mut actor_pools
```
+9 -7
src/actor_store/blob.rs
```diff
···
 
 use crate::models::actor_store as models;
 use anyhow::{Result, bail};
+use axum::body::Bytes;
 use cidv10::Cid;
 use diesel::dsl::{count_distinct, exists, not};
 use diesel::sql_types::{Integer, Nullable, Text};
···
 use rsky_repo::types::{PreparedBlobRef, PreparedWrite};
 use std::str::FromStr as _;
 
-use super::sql_blob::{BlobStoreSql, ByteStream};
+use super::blob_fs::{BlobStoreFs, ByteStream};
 
 pub struct GetBlobOutput {
     pub size: i32,
···
 /// Handles blob operations for an actor store
 pub struct BlobReader {
-    /// SQL-based blob storage
-    pub blobstore: BlobStoreSql,
+    /// File-system blob storage
+    pub blobstore: BlobStoreFs,
     /// DID of the actor
     pub did: String,
     /// Database connection
···
 impl BlobReader {
     /// Create a new blob reader
     pub fn new(
-        blobstore: BlobStoreSql,
+        blobstore: BlobStoreFs,
         db: deadpool_diesel::Pool<
             deadpool_diesel::Manager<SqliteConnection>,
             deadpool_diesel::sqlite::Object,
···
     pub async fn upload_blob_and_get_metadata(
         &self,
         user_suggested_mime: String,
-        blob: Vec<u8>,
+        blob: Bytes,
     ) -> Result<BlobMetadata> {
         let bytes = blob;
         let size = bytes.len() as i64;
 
         let (temp_key, sha256, img_info, sniffed_mime) = try_join!(
             self.blobstore.put_temp(bytes.clone()),
-            sha256_stream(bytes.clone()),
-            image::maybe_get_info(bytes.clone()),
-            image::mime_type_from_bytes(bytes.clone())
+            // TODO: reimplement these helpers to take `Bytes` instead of `Vec<u8>`
+            sha256_stream(bytes.to_vec()),
+            image::maybe_get_info(bytes.to_vec()),
+            image::mime_type_from_bytes(bytes.to_vec())
         )?;
 
         let cid = sha256_raw_to_cid(sha256);
```
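The switch from `Vec<u8>` to axum's `Bytes` matters because cloning `Bytes` only bumps a reference count, while the `.to_vec()` calls flagged by the TODO above each copy the whole buffer:

```rust
use axum::body::Bytes; // re-export of bytes::Bytes

fn demo() {
    let blob = Bytes::from(vec![0u8; 1 << 20]); // 1 MiB buffer
    let cheap = blob.clone();   // O(1): bumps a refcount, no copy
    let costly = blob.to_vec(); // O(n): copies the full megabyte
    assert_eq!(cheap.len(), costly.len());
}
```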
+287
src/actor_store/blob_fs.rs
```rust
//! File system implementation of blob storage
//! Based on the S3 implementation but using local file system instead
use anyhow::Result;
use axum::body::Bytes;
use cidv10::Cid;
use rsky_common::get_random_str;
use rsky_repo::error::BlobError;
use std::path::PathBuf;
use std::str::FromStr;
use tokio::fs as async_fs;
use tokio::io::AsyncWriteExt;
use tracing::{debug, error, warn};

/// ByteStream implementation for blob data
pub struct ByteStream {
    pub bytes: Bytes,
}

impl ByteStream {
    /// Create a new ByteStream with the given bytes
    pub const fn new(bytes: Bytes) -> Self {
        Self { bytes }
    }

    /// Collect the bytes from the stream
    pub async fn collect(self) -> Result<Bytes> {
        Ok(self.bytes)
    }
}

/// Path information for moving a blob
struct MoveObject {
    from: PathBuf,
    to: PathBuf,
}

/// File system implementation of blob storage
pub struct BlobStoreFs {
    /// Base directory for storing blobs
    pub base_dir: PathBuf,
    /// DID of the actor
    pub did: String,
}

impl BlobStoreFs {
    /// Create a new file system blob store for the given DID and base directory
    pub const fn new(did: String, base_dir: PathBuf) -> Self {
        Self { base_dir, did }
    }

    /// Create a factory function for blob stores
    pub fn creator(base_dir: PathBuf) -> Box<dyn Fn(String) -> Self> {
        let base_dir_clone = base_dir;
        Box::new(move |did: String| Self::new(did, base_dir_clone.clone()))
    }

    /// Generate a random key for temporary storage
    fn gen_key(&self) -> String {
        get_random_str()
    }

    /// Get path to the temporary blob storage
    fn get_tmp_path(&self, key: &str) -> PathBuf {
        self.base_dir.join("tmp").join(&self.did).join(key)
    }

    /// Get path to the stored blob with appropriate sharding
    fn get_stored_path(&self, cid: Cid) -> PathBuf {
        let cid_str = cid.to_string();

        // Create two-level sharded structure based on CID
        // First 10 chars for level 1, next 10 chars for level 2
        let first_level = if cid_str.len() >= 10 {
            &cid_str[0..10]
        } else {
            "short"
        };

        let second_level = if cid_str.len() >= 20 {
            &cid_str[10..20]
        } else {
            "short"
        };

        self.base_dir
            .join("blocks")
            .join(&self.did)
            .join(first_level)
            .join(second_level)
            .join(&cid_str)
    }

    /// Get path to the quarantined blob
    fn get_quarantined_path(&self, cid: Cid) -> PathBuf {
        let cid_str = cid.to_string();
        self.base_dir
            .join("quarantine")
            .join(&self.did)
            .join(&cid_str)
    }

    /// Store a blob temporarily
    pub async fn put_temp(&self, bytes: Bytes) -> Result<String> {
        let key = self.gen_key();
        let temp_path = self.get_tmp_path(&key);

        // Ensure the directory exists
        if let Some(parent) = temp_path.parent() {
            async_fs::create_dir_all(parent).await?;
        }

        // Write the temporary blob
        let mut file = async_fs::File::create(&temp_path).await?;
        file.write_all(&bytes).await?;
        file.flush().await?;

        debug!("Stored temp blob at: {:?}", temp_path);
        Ok(key)
    }

    /// Make a temporary blob permanent by moving it to the blob store
    pub async fn make_permanent(&self, key: String, cid: Cid) -> Result<()> {
        let already_has = self.has_stored(cid).await?;

        if !already_has {
            // Move the temporary blob to permanent storage
            self.move_object(MoveObject {
                from: self.get_tmp_path(&key),
                to: self.get_stored_path(cid),
            })
            .await?;
            debug!("Moved temp blob to permanent: {} -> {}", key, cid);
        } else {
            // Already saved, so just delete the temp
            let temp_path = self.get_tmp_path(&key);
            if temp_path.exists() {
                async_fs::remove_file(temp_path).await?;
                debug!("Deleted temp blob as permanent already exists: {}", key);
            }
        }

        Ok(())
    }

    /// Store a blob directly as permanent
    pub async fn put_permanent(&self, cid: Cid, bytes: Bytes) -> Result<()> {
        let target_path = self.get_stored_path(cid);

        // Ensure the directory exists
        if let Some(parent) = target_path.parent() {
            async_fs::create_dir_all(parent).await?;
        }

        // Write the blob
        let mut file = async_fs::File::create(&target_path).await?;
        file.write_all(&bytes).await?;
        file.flush().await?;

        debug!("Stored permanent blob: {}", cid);
        Ok(())
    }

    /// Quarantine a blob by moving it to the quarantine area
    pub async fn quarantine(&self, cid: Cid) -> Result<()> {
        self.move_object(MoveObject {
            from: self.get_stored_path(cid),
            to: self.get_quarantined_path(cid),
        })
        .await?;

        debug!("Quarantined blob: {}", cid);
        Ok(())
    }

    /// Unquarantine a blob by moving it back to regular storage
    pub async fn unquarantine(&self, cid: Cid) -> Result<()> {
        self.move_object(MoveObject {
            from: self.get_quarantined_path(cid),
            to: self.get_stored_path(cid),
        })
        .await?;

        debug!("Unquarantined blob: {}", cid);
        Ok(())
    }

    /// Get a blob as a stream
    async fn get_object(&self, cid: Cid) -> Result<ByteStream> {
        let blob_path = self.get_stored_path(cid);

        match async_fs::read(&blob_path).await {
            Ok(bytes) => Ok(ByteStream::new(Bytes::from(bytes))),
            Err(e) => {
                error!("Failed to read blob at path {:?}: {}", blob_path, e);
                Err(anyhow::Error::new(BlobError::BlobNotFoundError))
            }
        }
    }

    /// Get blob bytes
    pub async fn get_bytes(&self, cid: Cid) -> Result<Bytes> {
        let stream = self.get_object(cid).await?;
        stream.collect().await
    }

    /// Get a blob as a stream
    pub async fn get_stream(&self, cid: Cid) -> Result<ByteStream> {
        self.get_object(cid).await
    }

    /// Delete a blob by CID string
    pub async fn delete(&self, cid_str: String) -> Result<()> {
        match Cid::from_str(&cid_str) {
            Ok(cid) => self.delete_path(self.get_stored_path(cid)).await,
            Err(e) => {
                warn!("Invalid CID: {} - {}", cid_str, e);
                Err(anyhow::anyhow!("Invalid CID: {}", e))
            }
        }
    }

    /// Delete multiple blobs by CID
    pub async fn delete_many(&self, cids: Vec<Cid>) -> Result<()> {
        let mut futures = Vec::with_capacity(cids.len());

        for cid in cids {
            futures.push(self.delete_path(self.get_stored_path(cid)));
        }

        // Execute all delete operations concurrently
        let results = futures::future::join_all(futures).await;

        // Count errors but don't fail the operation
        let error_count = results.iter().filter(|r| r.is_err()).count();
        if error_count > 0 {
            warn!(
                "{} errors occurred while deleting {} blobs",
                error_count,
                results.len()
            );
        }

        Ok(())
    }

    /// Check if a blob is stored in the regular storage
    pub async fn has_stored(&self, cid: Cid) -> Result<bool> {
        let blob_path = self.get_stored_path(cid);
        Ok(blob_path.exists())
    }

    /// Check if a temporary blob exists
    pub async fn has_temp(&self, key: String) -> Result<bool> {
        let temp_path = self.get_tmp_path(&key);
        Ok(temp_path.exists())
    }

    /// Helper function to delete a file at the given path
    async fn delete_path(&self, path: PathBuf) -> Result<()> {
        if path.exists() {
            async_fs::remove_file(&path).await?;
            debug!("Deleted file at: {:?}", path);
            Ok(())
        } else {
            Err(anyhow::Error::new(BlobError::BlobNotFoundError))
        }
    }

    /// Move a blob from one path to another
    async fn move_object(&self, mov: MoveObject) -> Result<()> {
        // Ensure the source exists
        if !mov.from.exists() {
            return Err(anyhow::Error::new(BlobError::BlobNotFoundError));
        }

        // Ensure the target directory exists
        if let Some(parent) = mov.to.parent() {
            async_fs::create_dir_all(parent).await?;
        }

        // Move the file
        async_fs::rename(&mov.from, &mov.to).await?;

        debug!("Moved blob: {:?} -> {:?}", mov.from, mov.to);
        Ok(())
    }
}
```
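The two-level sharding in `get_stored_path` keeps any one directory from accumulating an unbounded number of blob files. A standalone restatement of the path rule, for illustration only (hypothetical helper):

```rust
use std::path::PathBuf;

/// Mirror of `get_stored_path`'s sharding rule.
fn sharded_path(base: &str, did: &str, cid_str: &str) -> PathBuf {
    let first = if cid_str.len() >= 10 { &cid_str[0..10] } else { "short" };
    let second = if cid_str.len() >= 20 { &cid_str[10..20] } else { "short" };
    PathBuf::from(base)
        .join("blocks")
        .join(did)
        .join(first)
        .join(second)
        .join(cid_str)
}

// e.g. a CID starting "bafkreibm6..." lands under
// <base>/blocks/<did>/<first 10 chars>/<next 10 chars>/<full cid>
```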
+7 -6
src/actor_store/mod.rs
```diff
···
 //! Modified for SQLite backend
 
 mod blob;
+pub(crate) mod blob_fs;
 mod preference;
 mod record;
 pub(crate) mod sql_blob;
···
 use tokio::sync::RwLock;
 
 use blob::BlobReader;
+use blob_fs::BlobStoreFs;
 use preference::PreferenceReader;
 use record::RecordReader;
-use sql_blob::BlobStoreSql;
 use sql_repo::SqlRepoReader;
 
-use crate::ActorPools;
+use crate::serve::ActorStorage;
 
 #[derive(Debug)]
 enum FormatCommitError {
···
 // Combination of RepoReader/Transactor, BlobReader/Transactor, SqlRepoReader/Transactor
 impl ActorStore {
-    /// Concrete reader of an individual repo (hence BlobStoreSql which takes `did` param)
+    /// Concrete reader of an individual repo (hence BlobStoreFs which takes `did` param)
     pub fn new(
         did: String,
-        blobstore: BlobStoreSql,
+        blobstore: BlobStoreFs,
         db: deadpool_diesel::Pool<
             deadpool_diesel::Manager<SqliteConnection>,
             deadpool_diesel::sqlite::Object,
···
-    /// Create a new ActorStore taking ActorPools HashMap as input
+    /// Create a new ActorStore taking the ActorStorage HashMap as input
     pub async fn from_actor_pools(
         did: &String,
-        hashmap_actor_pools: &std::collections::HashMap<String, ActorPools>,
+        hashmap_actor_pools: &std::collections::HashMap<String, ActorStorage>,
     ) -> Self {
         let actor_pool = hashmap_actor_pools
             .get(did)
             .expect("Actor pool not found")
             .clone();
-        let blobstore = BlobStoreSql::new(did.clone(), actor_pool.blob);
+        let blobstore = BlobStoreFs::new(did.clone(), actor_pool.blob);
         let conn = actor_pool
             .repo
             .clone()
```
+13 -44
src/apis/com/atproto/repo/apply_writes.rs
```diff
 //! Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS.
-use crate::SharedSequencer;
-use crate::account_manager::helpers::account::AvailabilityFlags;
-use crate::account_manager::{AccountManager, AccountManagerCreator, SharedAccountManager};
-use crate::{
-    ActorPools, AppState, SigningKey,
-    actor_store::{ActorStore, sql_blob::BlobStoreSql},
-    auth::AuthenticatedUser,
-    config::AppConfig,
-    error::{ApiError, ErrorMessage},
-};
-use anyhow::{Result, bail};
-use axum::{
-    Json, Router,
-    body::Body,
-    extract::{Query, Request, State},
-    http::{self, StatusCode},
-    routing::{get, post},
-};
-use cidv10::Cid;
-use deadpool_diesel::sqlite::Pool;
-use futures::stream::{self, StreamExt};
-use rsky_lexicon::com::atproto::repo::{ApplyWritesInput, ApplyWritesInputRefWrite};
-use rsky_pds::auth_verifier::AccessStandardIncludeChecks;
-use rsky_pds::repo::prepare::{
-    PrepareCreateOpts, PrepareDeleteOpts, PrepareUpdateOpts, prepare_create, prepare_delete,
-    prepare_update,
-};
-use rsky_pds::sequencer::Sequencer;
-use rsky_repo::types::PreparedWrite;
-use std::str::FromStr;
-use std::sync::Arc;
-use tokio::sync::RwLock;
+
+use super::*;
 
 async fn inner_apply_writes(
     body: ApplyWritesInput,
-    user: AuthenticatedUser,
-    sequencer: &RwLock<Sequencer>,
-    actor_pools: std::collections::HashMap<String, ActorPools>,
-    account_manager: &RwLock<AccountManager>,
+    auth: AuthenticatedUser,
+    sequencer: Arc<RwLock<Sequencer>>,
+    actor_pools: HashMap<String, ActorStorage>,
+    account_manager: Arc<RwLock<AccountManager>>,
 ) -> Result<()> {
     let tx: ApplyWritesInput = body;
     let ApplyWritesInput {
···
         bail!("Account is deactivated")
     }
     let did = account.did;
-    if did != user.did() {
+    if did != auth.did() {
         bail!("AuthRequiredError")
     }
     let did: &String = &did;
···
     }
 
     let writes: Vec<PreparedWrite> = stream::iter(tx.writes)
-        .then(|write| async move {
+        .then(async |write| {
             Ok::<PreparedWrite, anyhow::Error>(match write {
                 ApplyWritesInputRefWrite::Create(write) => PreparedWrite::Create(
                     prepare_create(PrepareCreateOpts {
···
 /// - `swap_commit`: `cid` // If provided, the entire operation will fail if the current repo commit CID does not match this value. Used to prevent conflicting repo mutations.
 #[axum::debug_handler(state = AppState)]
 pub(crate) async fn apply_writes(
-    user: AuthenticatedUser,
-    State(state): State<AppState>,
+    auth: AuthenticatedUser,
+    State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>,
+    State(account_manager): State<Arc<RwLock<AccountManager>>>,
+    State(sequencer): State<Arc<RwLock<Sequencer>>>,
     Json(body): Json<ApplyWritesInput>,
 ) -> Result<(), ApiError> {
     tracing::debug!("@LOG: debug apply_writes {body:#?}");
-    let db_actors = state.db_actors;
-    let sequencer = &state.sequencer.sequencer;
-    let account_manager = &state.account_manager.account_manager;
-    match inner_apply_writes(body, user, sequencer, db_actors, account_manager).await {
+    match inner_apply_writes(body, auth, sequencer, actor_pools, account_manager).await {
         Ok(()) => Ok(()),
         Err(error) => {
             tracing::error!("@LOG: ERROR: {error}");
```
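Pulling each sub-state out with its own `State` extractor requires `AppState` to hand out those types. A minimal sketch assuming axum's `#[derive(FromRef)]`; the real `AppState` definition is not part of this diff, and the types below are stand-ins:

```rust
use axum::extract::FromRef;
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;

// Stand-ins for the crate's real types, just so the sketch is self-contained.
#[derive(Clone)]
struct ActorStorage;
struct AccountManager;
struct Sequencer;

// Hypothetical AppState: `#[derive(FromRef)]` lets a handler take
// `State<HashMap<String, ActorStorage>>`, `State<Arc<RwLock<AccountManager>>>`,
// or `State<Arc<RwLock<Sequencer>>>` directly, as `apply_writes` does above.
#[derive(Clone, FromRef)]
struct AppState {
    actor_pools: HashMap<String, ActorStorage>,
    account_manager: Arc<RwLock<AccountManager>>,
    sequencer: Arc<RwLock<Sequencer>>,
}
```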
+140
src/apis/com/atproto/repo/create_record.rs
```rust
//! Create a single new repository record. Requires auth, implemented by PDS.

use super::*;

async fn inner_create_record(
    body: CreateRecordInput,
    user: AuthenticatedUser,
    sequencer: Arc<RwLock<Sequencer>>,
    actor_pools: HashMap<String, ActorStorage>,
    account_manager: Arc<RwLock<AccountManager>>,
) -> Result<CreateRecordOutput> {
    let CreateRecordInput {
        repo,
        collection,
        record,
        rkey,
        validate,
        swap_commit,
    } = body;
    let account = account_manager
        .read()
        .await
        .get_account(
            &repo,
            Some(AvailabilityFlags {
                include_deactivated: Some(true),
                include_taken_down: None,
            }),
        )
        .await?;
    if let Some(account) = account {
        if account.deactivated_at.is_some() {
            bail!("Account is deactivated")
        }
        let did = account.did;
        // if did != auth.access.credentials.unwrap().did.unwrap() {
        if did != user.did() {
            bail!("AuthRequiredError")
        }
        let swap_commit_cid = match swap_commit {
            Some(swap_commit) => Some(Cid::from_str(&swap_commit)?),
            None => None,
        };
        let write = prepare_create(PrepareCreateOpts {
            did: did.clone(),
            collection: collection.clone(),
            record: serde_json::from_value(record)?,
            rkey,
            validate,
            swap_cid: None,
        })
        .await?;

        let did: &String = &did;
        let mut actor_store = ActorStore::from_actor_pools(did, &actor_pools).await;
        let backlink_conflicts: Vec<AtUri> = match validate {
            Some(true) => {
                let write_at_uri: AtUri = write.uri.clone().try_into()?;
                actor_store
                    .record
                    .get_backlink_conflicts(&write_at_uri, &write.record)
                    .await?
            }
            _ => Vec::new(),
        };

        let backlink_deletions: Vec<PreparedDelete> = backlink_conflicts
            .iter()
            .map(|at_uri| {
                prepare_delete(PrepareDeleteOpts {
                    did: at_uri.get_hostname().to_string(),
                    collection: at_uri.get_collection(),
                    rkey: at_uri.get_rkey(),
                    swap_cid: None,
                })
            })
            .collect::<Result<Vec<PreparedDelete>>>()?;
        let mut writes: Vec<PreparedWrite> = vec![PreparedWrite::Create(write.clone())];
        for delete in backlink_deletions {
            writes.push(PreparedWrite::Delete(delete));
        }
        let commit = actor_store
            .process_writes(writes.clone(), swap_commit_cid)
            .await?;

        _ = sequencer
            .write()
            .await
            .sequence_commit(did.clone(), commit.clone())
            .await?;
        account_manager
            .write()
            .await
            .update_repo_root(
                did.to_string(),
                commit.commit_data.cid,
                commit.commit_data.rev,
                &actor_pools,
            )
            .await?;

        Ok(CreateRecordOutput {
            uri: write.uri.clone(),
            cid: write.cid.to_string(),
        })
    } else {
        bail!("Could not find repo: `{repo}`")
    }
}

/// Create a single new repository record. Requires auth, implemented by PDS.
/// - POST /xrpc/com.atproto.repo.createRecord
/// ### Request Body
/// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account).
/// - `collection`: `nsid` // The NSID of the record collection.
/// - `rkey`: `string` // The record key. <= 512 characters.
/// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data, 'true' to require it, or leave unset to validate only for known Lexicons.
/// - `record`
/// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID.
/// ### Responses
/// - 200 OK: {`cid`: `cid`, `uri`: `at-uri`, `commit`: {`cid`: `cid`, `rev`: `tid`}, `validation_status`: [`valid`, `unknown`]}
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidSwap`]}
/// - 401 Unauthorized
#[axum::debug_handler(state = AppState)]
pub async fn create_record(
    user: AuthenticatedUser,
    State(db_actors): State<HashMap<String, ActorStorage, RandomState>>,
    State(account_manager): State<Arc<RwLock<AccountManager>>>,
    State(sequencer): State<Arc<RwLock<Sequencer>>>,
    Json(body): Json<CreateRecordInput>,
) -> Result<Json<CreateRecordOutput>, ApiError> {
    tracing::debug!("@LOG: debug create_record {body:#?}");
    match inner_create_record(body, user, sequencer, db_actors, account_manager).await {
        Ok(res) => Ok(Json(res)),
        Err(error) => {
            tracing::error!("@LOG: ERROR: {error}");
            Err(ApiError::RuntimeError)
        }
    }
}
```
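For reference, a hypothetical client call against this endpoint, using reqwest with its `json` feature; the host, token, and record values are placeholders:

```rust
use serde_json::json;

async fn demo(client: &reqwest::Client) -> reqwest::Result<()> {
    let res = client
        .post("https://PDS.example/xrpc/com.atproto.repo.createRecord")
        .bearer_auth("TOKEN") // access JWT from createSession
        .json(&json!({
            "repo": "did:plc:example",
            "collection": "app.bsky.feed.post",
            "record": {
                "$type": "app.bsky.feed.post",
                "text": "hello from bluepds",
                "createdAt": "2025-05-17T00:00:00Z"
            }
        }))
        .send()
        .await?;
    println!("{}", res.text().await?); // {"uri": "...", "cid": "..."}
    Ok(())
}
```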
+117
src/apis/com/atproto/repo/delete_record.rs
```rust
//! Delete a repository record, or ensure it doesn't exist. Requires auth, implemented by PDS.
use super::*;

async fn inner_delete_record(
    body: DeleteRecordInput,
    user: AuthenticatedUser,
    sequencer: Arc<RwLock<Sequencer>>,
    actor_pools: HashMap<String, ActorStorage>,
    account_manager: Arc<RwLock<AccountManager>>,
) -> Result<()> {
    let DeleteRecordInput {
        repo,
        collection,
        rkey,
        swap_record,
        swap_commit,
    } = body;
    let account = account_manager
        .read()
        .await
        .get_account(
            &repo,
            Some(AvailabilityFlags {
                include_deactivated: Some(true),
                include_taken_down: None,
            }),
        )
        .await?;
    match account {
        None => bail!("Could not find repo: `{repo}`"),
        Some(account) if account.deactivated_at.is_some() => bail!("Account is deactivated"),
        Some(account) => {
            let did = account.did;
            // if did != auth.access.credentials.unwrap().did.unwrap() {
            if did != user.did() {
                bail!("AuthRequiredError")
            }

            let swap_commit_cid = match swap_commit {
                Some(swap_commit) => Some(Cid::from_str(&swap_commit)?),
                None => None,
            };
            let swap_record_cid = match swap_record {
                Some(swap_record) => Some(Cid::from_str(&swap_record)?),
                None => None,
            };

            let write = prepare_delete(PrepareDeleteOpts {
                did: did.clone(),
                collection,
                rkey,
                swap_cid: swap_record_cid,
            })?;
            let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await;
            let write_at_uri: AtUri = write.uri.clone().try_into()?;
            let record = actor_store
                .record
                .get_record(&write_at_uri, None, Some(true))
                .await?;
            let commit = match record {
                None => return Ok(()), // No-op if record already doesn't exist
                Some(_) => {
                    actor_store
                        .process_writes(vec![PreparedWrite::Delete(write.clone())], swap_commit_cid)
                        .await?
                }
            };

            _ = sequencer
                .write()
                .await
                .sequence_commit(did.clone(), commit.clone())
                .await?;
            account_manager
                .write()
                .await
                .update_repo_root(
                    did,
                    commit.commit_data.cid,
                    commit.commit_data.rev,
                    &actor_pools,
                )
                .await?;

            Ok(())
        }
    }
}

/// Delete a repository record, or ensure it doesn't exist. Requires auth, implemented by PDS.
/// - POST /xrpc/com.atproto.repo.deleteRecord
/// ### Request Body
/// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account).
/// - `collection`: `nsid` // The NSID of the record collection.
/// - `rkey`: `string` // The record key. <= 512 characters.
/// - `swap_record`: `boolean` // Compare and swap with the previous record by CID.
/// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID.
/// ### Responses
/// - 200 OK: {"commit": {"cid": "string","rev": "string"}}
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidSwap`]}
/// - 401 Unauthorized
#[axum::debug_handler(state = AppState)]
pub async fn delete_record(
    user: AuthenticatedUser,
    State(db_actors): State<HashMap<String, ActorStorage, RandomState>>,
    State(account_manager): State<Arc<RwLock<AccountManager>>>,
    State(sequencer): State<Arc<RwLock<Sequencer>>>,
    Json(body): Json<DeleteRecordInput>,
) -> Result<(), ApiError> {
    match inner_delete_record(body, user, sequencer, db_actors, account_manager).await {
        Ok(()) => Ok(()),
        Err(error) => {
            tracing::error!("@LOG: ERROR: {error}");
            Err(ApiError::RuntimeError)
        }
    }
}
```
+70
src/apis/com/atproto/repo/describe_repo.rs
```rust
//! Get information about an account and repository, including the list of collections. Does not require auth.
use super::*;

async fn inner_describe_repo(
    repo: String,
    id_resolver: Arc<RwLock<IdResolver>>,
    actor_pools: HashMap<String, ActorStorage>,
    account_manager: Arc<RwLock<AccountManager>>,
) -> Result<DescribeRepoOutput> {
    let account = account_manager
        .read()
        .await
        .get_account(&repo, None)
        .await?;
    match account {
        None => bail!("Could not find user: `{repo}`"),
        Some(account) => {
            let did_doc: DidDocument = match id_resolver
                .write()
                .await
                .did
                .ensure_resolve(&account.did, None)
                .await
            {
                Err(err) => bail!("Could not resolve DID: `{err}`"),
                Ok(res) => res,
            };
            let handle = rsky_common::get_handle(&did_doc);
            let handle_is_correct = handle == account.handle;

            let actor_store =
                ActorStore::from_actor_pools(&account.did.clone(), &actor_pools).await;
            let collections = actor_store.record.list_collections().await?;

            Ok(DescribeRepoOutput {
                handle: account.handle.unwrap_or_else(|| INVALID_HANDLE.to_owned()),
                did: account.did,
                did_doc: serde_json::to_value(did_doc)?,
                collections,
                handle_is_correct,
            })
        }
    }
}

/// Get information about an account and repository, including the list of collections. Does not require auth.
/// - GET /xrpc/com.atproto.repo.describeRepo
/// ### Query Parameters
/// - `repo`: `at-identifier` // The handle or DID of the repo.
/// ### Responses
/// - 200 OK: {"handle": "string","did": "string","didDoc": {},"collections": [string],"handleIsCorrect": true} \
///   handleIsCorrect - boolean - Indicates if the handle is currently valid (resolves bi-directionally)
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]}
/// - 401 Unauthorized
#[tracing::instrument(skip_all)]
#[axum::debug_handler(state = AppState)]
pub async fn describe_repo(
    Query(input): Query<atrium_repo::describe_repo::ParametersData>,
    State(db_actors): State<HashMap<String, ActorStorage, RandomState>>,
    State(account_manager): State<Arc<RwLock<AccountManager>>>,
    State(id_resolver): State<Arc<RwLock<IdResolver>>>,
) -> Result<Json<DescribeRepoOutput>, ApiError> {
    match inner_describe_repo(input.repo.into(), id_resolver, db_actors, account_manager).await {
        Ok(res) => Ok(Json(res)),
        Err(error) => {
            tracing::error!("{error:?}");
            Err(ApiError::RuntimeError)
        }
    }
}
```
+37
src/apis/com/atproto/repo/ex.rs
```rust
//! Endpoint template.
use crate::account_manager::AccountManager;
use crate::auth::AuthenticatedUser;
use crate::serve::ActorStorage;
use crate::{actor_store::ActorStore, error::ApiError, serve::AppState};
use anyhow::{Result, bail};
use axum::extract::Query;
use axum::{Json, extract::State};
use rsky_identity::IdResolver;
use rsky_lexicon::com::atproto::repo::ApplyWritesInput;
use rsky_pds::sequencer::Sequencer;
use std::collections::HashMap;
use std::hash::RandomState;
use std::sync::Arc;
use tokio::sync::RwLock;

async fn inner_fun(
    actor_pools: HashMap<String, ActorStorage>,
    account_manager: Arc<RwLock<AccountManager>>,
    id_resolver: Arc<RwLock<IdResolver>>,
    sequencer: Arc<RwLock<Sequencer>>,
) -> Result<()> {
    todo!();
}

/// Handler skeleton; extractors mirror the other repo endpoints.
#[tracing::instrument(skip_all)]
#[axum::debug_handler(state = AppState)]
pub async fn fun(
    auth: AuthenticatedUser,
    Query(input): Query<atrium_api::com::atproto::repo::describe_repo::ParametersData>,
    State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>,
    State(account_manager): State<Arc<RwLock<AccountManager>>>,
    State(id_resolver): State<Arc<RwLock<IdResolver>>>,
    State(sequencer): State<Arc<RwLock<Sequencer>>>,
    Json(body): Json<ApplyWritesInput>,
) -> Result<(), ApiError> {
    todo!();
}
```
+102
src/apis/com/atproto/repo/get_record.rs
```rust
//! Get a single record from a repository. Does not require auth.

use crate::pipethrough::{ProxyRequest, pipethrough};

use super::*;

use rsky_pds::pipethrough::OverrideOpts;

async fn inner_get_record(
    repo: String,
    collection: String,
    rkey: String,
    cid: Option<String>,
    req: ProxyRequest,
    actor_pools: HashMap<String, ActorStorage>,
    account_manager: Arc<RwLock<AccountManager>>,
) -> Result<GetRecordOutput> {
    let did = account_manager
        .read()
        .await
        .get_did_for_actor(&repo, None)
        .await?;

    // fetch from pds if available, if not then fetch from appview
    if let Some(did) = did {
        let uri = AtUri::make(did.clone(), Some(collection), Some(rkey))?;

        let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await;

        match actor_store.record.get_record(&uri, cid, None).await {
            Ok(Some(record)) if record.takedown_ref.is_none() => Ok(GetRecordOutput {
                uri: uri.to_string(),
                cid: Some(record.cid),
                value: serde_json::to_value(record.value)?,
            }),
            _ => bail!("Could not locate record: `{uri}`"),
        }
    } else {
        match req.cfg.bsky_app_view {
            None => bail!("Could not locate record"),
            Some(_) => match pipethrough(
                &req,
                None,
                OverrideOpts {
                    aud: None,
                    lxm: None,
                },
            )
            .await
            {
                Err(error) => {
                    tracing::error!("@LOG: ERROR: {error}");
                    bail!("Could not locate record")
                }
                Ok(res) => {
                    let output: GetRecordOutput = serde_json::from_slice(res.buffer.as_slice())?;
                    Ok(output)
                }
            },
        }
    }
}

/// Get a single record from a repository. Does not require auth.
/// - GET /xrpc/com.atproto.repo.getRecord
/// ### Query Parameters
/// - `repo`: `at-identifier` // The handle or DID of the repo.
/// - `collection`: `nsid` // The NSID of the record collection.
/// - `rkey`: `string` // The record key. <= 512 characters.
/// - `cid`: `cid` // The CID of the version of the record. If not specified, then return the most recent version.
/// ### Responses
/// - 200 OK: {"uri": "string","cid": "string","value": {}}
/// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RecordNotFound`]}
/// - 401 Unauthorized
#[tracing::instrument(skip_all)]
#[axum::debug_handler(state = AppState)]
pub async fn get_record(
    Query(input): Query<ParametersData>,
    State(db_actors): State<HashMap<String, ActorStorage, RandomState>>,
    State(account_manager): State<Arc<RwLock<AccountManager>>>,
    req: ProxyRequest,
) -> Result<Json<GetRecordOutput>, ApiError> {
    let repo = input.repo;
    let collection = input.collection;
    let rkey = input.rkey;
    let cid = input.cid;
    match inner_get_record(repo, collection, rkey, cid, req, db_actors, account_manager).await {
        Ok(res) => Ok(Json(res)),
        Err(error) => {
            tracing::error!("@LOG: ERROR: {error}");
            Err(ApiError::RecordNotFound)
        }
    }
}

#[derive(serde::Deserialize, Debug)]
pub struct ParametersData {
    pub cid: Option<String>,
    pub collection: String,
    pub repo: String,
    pub rkey: String,
}
```
+183
src/apis/com/atproto/repo/import_repo.rs
··· 1 + use axum::{body::Bytes, http::HeaderMap}; 2 + use reqwest::header; 3 + use rsky_common::env::env_int; 4 + use rsky_repo::block_map::BlockMap; 5 + use rsky_repo::car::{CarWithRoot, read_stream_car_with_root}; 6 + use rsky_repo::parse::get_and_parse_record; 7 + use rsky_repo::repo::Repo; 8 + use rsky_repo::sync::consumer::{VerifyRepoInput, verify_diff}; 9 + use rsky_repo::types::{RecordWriteDescript, VerifiedDiff}; 10 + use ubyte::ToByteUnit; 11 + 12 + use super::*; 13 + 14 + async fn from_data(bytes: Bytes) -> Result<CarWithRoot, ApiError> { 15 + let max_import_size = env_int("IMPORT_REPO_LIMIT").unwrap_or(100).megabytes(); 16 + if bytes.len() > max_import_size { 17 + return Err(ApiError::InvalidRequest(format!( 18 + "Content-Length is greater than maximum of {max_import_size}" 19 + ))); 20 + } 21 + 22 + let mut cursor = std::io::Cursor::new(bytes); 23 + match read_stream_car_with_root(&mut cursor).await { 24 + Ok(car_with_root) => Ok(car_with_root), 25 + Err(error) => { 26 + tracing::error!("Error reading stream car with root\n{error}"); 27 + Err(ApiError::InvalidRequest("Invalid CAR file".to_owned())) 28 + } 29 + } 30 + } 31 + 32 + /// Import a repo in the form of a CAR file. Requires Content-Length HTTP header to be set. 33 + /// - POST /xrpc/com.atproto.repo.importRepo 34 + /// ### Request Body 35 + /// - mime: `application/vnd.ipld.car` 36 + #[tracing::instrument(skip_all)] 37 + #[axum::debug_handler(state = AppState)] 38 + pub async fn import_repo( 39 + // auth: AccessFullImport, 40 + auth: AuthenticatedUser, 41 + headers: HeaderMap, 42 + State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>, 43 + body: Bytes, 44 + ) -> Result<(), ApiError> { 45 + // let requester = auth.access.credentials.unwrap().did.unwrap(); 46 + let requester = auth.did(); 47 + let mut actor_store = ActorStore::from_actor_pools(&requester, &actor_pools).await; 48 + 49 + // Check headers; reject rather than panic when Content-Length is missing or malformed 50 + let content_length = headers 51 + .get(header::CONTENT_LENGTH) 52 + .and_then(|value| value.to_str().ok()) 53 + .and_then(|value| value.parse::<u64>().ok()) 54 + .ok_or_else(|| { 55 + ApiError::InvalidRequest("Missing or invalid Content-Length header".to_owned()) 56 + })?; 57 + if content_length > env_int("IMPORT_REPO_LIMIT").unwrap_or(100).megabytes() { 58 + return Err(ApiError::InvalidRequest(format!( 59 + "Content-Length is greater than maximum of {}", 60 + env_int("IMPORT_REPO_LIMIT").unwrap_or(100).megabytes() 61 + ))); 62 + }; 63 + 64 + // Get current repo if it exists 65 + let curr_root: Option<Cid> = actor_store.get_repo_root().await; 66 + let curr_repo: Option<Repo> = match curr_root { 67 + None => None, 68 + Some(_root) => Some(Repo::load(actor_store.storage.clone(), curr_root).await?), 69 + }; 70 + 71 + // Process imported car 72 + // let car_with_root = import_repo_input.car_with_root; 73 + let car_with_root: CarWithRoot = match from_data(body).await { 74 + Ok(car) => car, 75 + Err(error) => { 76 + tracing::error!("Error importing repo\n{error:?}"); 77 + return Err(ApiError::InvalidRequest("Invalid CAR file".to_owned())); 78 + } 79 + }; 80 + 81 + // Get verified difference from current repo and imported repo 82 + let mut imported_blocks: BlockMap = car_with_root.blocks; 83 + let imported_root: Cid = car_with_root.root; 84 + let opts = VerifyRepoInput { 85 + ensure_leaves: Some(false), 86 + }; 87 + 88 + let diff: VerifiedDiff = match verify_diff( 89 + curr_repo, 90 + &mut imported_blocks, 91 + imported_root, 92 + None, 93 + None, 94 + Some(opts), 95 + ) 96 + .await 97 + { 98 + Ok(res) => res, 99 + Err(error)
=> { 100 + tracing::error!("{:?}", error); 101 + return Err(ApiError::RuntimeError); 102 + } 103 + }; 104 + 105 + let commit_data = diff.commit; 106 + let prepared_writes: Vec<PreparedWrite> = 107 + prepare_import_repo_writes(requester, diff.writes, &imported_blocks).await?; 108 + match actor_store 109 + .process_import_repo(commit_data, prepared_writes) 110 + .await 111 + { 112 + Ok(_res) => {} 113 + Err(error) => { 114 + tracing::error!("Error importing repo\n{error}"); 115 + return Err(ApiError::RuntimeError); 116 + } 117 + } 118 + 119 + Ok(()) 120 + } 121 + 122 + /// Converts list of RecordWriteDescripts into a list of PreparedWrites 123 + async fn prepare_import_repo_writes( 124 + did: String, 125 + writes: Vec<RecordWriteDescript>, 126 + blocks: &BlockMap, 127 + ) -> Result<Vec<PreparedWrite>, ApiError> { 128 + match stream::iter(writes) 129 + .then(|write| { 130 + let did = did.clone(); 131 + async move { 132 + Ok::<PreparedWrite, anyhow::Error>(match write { 133 + RecordWriteDescript::Create(write) => { 134 + let parsed_record = get_and_parse_record(blocks, write.cid)?; 135 + PreparedWrite::Create( 136 + prepare_create(PrepareCreateOpts { 137 + did: did.clone(), 138 + collection: write.collection, 139 + rkey: Some(write.rkey), 140 + swap_cid: None, 141 + record: parsed_record.record, 142 + validate: Some(true), 143 + }) 144 + .await?, 145 + ) 146 + } 147 + RecordWriteDescript::Update(write) => { 148 + let parsed_record = get_and_parse_record(blocks, write.cid)?; 149 + PreparedWrite::Update( 150 + prepare_update(PrepareUpdateOpts { 151 + did: did.clone(), 152 + collection: write.collection, 153 + rkey: write.rkey, 154 + swap_cid: None, 155 + record: parsed_record.record, 156 + validate: Some(true), 157 + }) 158 + .await?, 159 + ) 160 + } 161 + RecordWriteDescript::Delete(write) => { 162 + PreparedWrite::Delete(prepare_delete(PrepareDeleteOpts { 163 + did: did.clone(), 164 + collection: write.collection, 165 + rkey: write.rkey, 166 + swap_cid: None, 167 + })?) 168 + } 169 + }) 170 + } 171 + }) 172 + .collect::<Vec<_>>() 173 + .await 174 + .into_iter() 175 + .collect::<Result<Vec<PreparedWrite>, _>>() 176 + { 177 + Ok(res) => Ok(res), 178 + Err(error) => { 179 + tracing::error!("Error preparing import repo writes\n{error}"); 180 + Err(ApiError::RuntimeError) 181 + } 182 + } 183 + }
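A note on the size checks above: `ToByteUnit::megabytes` turns an integer into a `ubyte::ByteUnit`, and plain integer byte counts compare directly against it, which is what both `from_data` and the handler rely on. A standalone sketch:

```rust
// Standalone sketch of the ubyte comparison used by the import limit checks.
use ubyte::ToByteUnit;

fn main() {
    let limit = 100.megabytes(); // a ubyte::ByteUnit
    // Raw byte counts (e.g. a parsed Content-Length) compare against it directly.
    assert!(50_000_000_u64 < limit);
    assert!(200_000_000_u64 > limit);
    println!("import limit: {limit}"); // ByteUnit implements Display
}
```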
+48
src/apis/com/atproto/repo/list_missing_blobs.rs
··· 1 + //! Returns a list of missing blobs for the requesting account. Intended to be used in the account migration flow. 2 + use rsky_lexicon::com::atproto::repo::ListMissingBlobsOutput; 3 + use rsky_pds::actor_store::blob::ListMissingBlobsOpts; 4 + 5 + use super::*; 6 + 7 + /// Returns a list of missing blobs for the requesting account. 8 + /// Intended to be used in the account migration flow. 9 + /// - GET /xrpc/com.atproto.repo.listMissingBlobs 10 + /// ### Query Parameters 11 + /// - `limit`: `integer` // Possible values: >= 1 and <= 1000. Default value: 500. 12 + /// - `cursor`: `string` 13 + /// ### Responses 14 + /// - 200 OK: {"cursor": "string","blobs": [{"cid": "string","recordUri": "string"}]} 15 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 16 + /// - 401 Unauthorized 17 + #[tracing::instrument(skip_all)] 18 + #[axum::debug_handler(state = AppState)] 19 + pub async fn list_missing_blobs( 20 + user: AuthenticatedUser, 21 + Query(input): Query<atrium_repo::list_missing_blobs::ParametersData>, 22 + State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>, 23 + ) -> Result<Json<ListMissingBlobsOutput>, ApiError> { 24 + let cursor = input.cursor; 25 + let limit = input.limit; 26 + let default_limit: atrium_api::types::LimitedNonZeroU16<1000> = 27 + atrium_api::types::LimitedNonZeroU16::try_from(500).expect("default limit"); 28 + let limit: u16 = limit.unwrap_or(default_limit).into(); 29 + // let did = auth.access.credentials.unwrap().did.unwrap(); 30 + let did = user.did(); 31 + 32 + let actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await; 33 + 34 + match actor_store 35 + .blob 36 + .list_missing_blobs(ListMissingBlobsOpts { cursor, limit }) 37 + .await 38 + { 39 + Ok(blobs) => { 40 + let cursor = blobs.last().map(|last_blob| last_blob.cid.clone()); 41 + Ok(Json(ListMissingBlobsOutput { cursor, blobs })) 42 + } 43 + Err(error) => { 44 + tracing::error!("{error:?}"); 45 + Err(ApiError::RuntimeError) 46 + } 47 + } 48 + }
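The default-limit handling above leans on atrium's bounded integer types, which reject 0 and out-of-range values at parse time, so the handler only ever sees an in-range limit. A standalone sketch of the same resolution:

```rust
// Sketch of the limit defaulting used by list_missing_blobs.
use atrium_api::types::LimitedNonZeroU16;

fn effective_limit(input: Option<LimitedNonZeroU16<1000>>) -> u16 {
    let default: LimitedNonZeroU16<1000> =
        LimitedNonZeroU16::try_from(500).expect("500 is within 1..=1000");
    input.unwrap_or(default).into()
}
```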
+146
src/apis/com/atproto/repo/list_records.rs
··· 1 + //! List a range of records in a repository, matching a specific collection. Does not require auth. 2 + use super::*; 3 + 4 + // #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)] 5 + // #[serde(rename_all = "camelCase")] 6 + // /// Parameters for [`list_records`]. 7 + // pub(super) struct ListRecordsParameters { 8 + // ///The NSID of the record type. 9 + // pub collection: Nsid, 10 + // /// The cursor to start from. 11 + // #[serde(skip_serializing_if = "core::option::Option::is_none")] 12 + // pub cursor: Option<String>, 13 + // ///The number of records to return. 14 + // #[serde(skip_serializing_if = "core::option::Option::is_none")] 15 + // pub limit: Option<String>, 16 + // ///The handle or DID of the repo. 17 + // pub repo: AtIdentifier, 18 + // ///Flag to reverse the order of the returned records. 19 + // #[serde(skip_serializing_if = "core::option::Option::is_none")] 20 + // pub reverse: Option<bool>, 21 + // ///DEPRECATED: The highest sort-ordered rkey to stop at (exclusive) 22 + // #[serde(skip_serializing_if = "core::option::Option::is_none")] 23 + // pub rkey_end: Option<String>, 24 + // ///DEPRECATED: The lowest sort-ordered rkey to start from (exclusive) 25 + // #[serde(skip_serializing_if = "core::option::Option::is_none")] 26 + // pub rkey_start: Option<String>, 27 + // } 28 + 29 + #[expect(non_snake_case, clippy::too_many_arguments)] 30 + async fn inner_list_records( 31 + // The handle or DID of the repo. 32 + repo: String, 33 + // The NSID of the record type. 34 + collection: String, 35 + // The number of records to return. 36 + limit: u16, 37 + cursor: Option<String>, 38 + // DEPRECATED: The lowest sort-ordered rkey to start from (exclusive) 39 + rkeyStart: Option<String>, 40 + // DEPRECATED: The highest sort-ordered rkey to stop at (exclusive) 41 + rkeyEnd: Option<String>, 42 + // Flag to reverse the order of the returned records. 43 + reverse: bool, 44 + // The actor pools 45 + actor_pools: HashMap<String, ActorStorage>, 46 + account_manager: Arc<RwLock<AccountManager>>, 47 + ) -> Result<ListRecordsOutput> { 48 + if limit > 100 { 49 + bail!("Error: limit can not be greater than 100") 50 + } 51 + let did = account_manager 52 + .read() 53 + .await 54 + .get_did_for_actor(&repo, None) 55 + .await?; 56 + if let Some(did) = did { 57 + let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await; 58 + 59 + let records: Vec<Record> = actor_store 60 + .record 61 + .list_records_for_collection( 62 + collection, 63 + limit as i64, 64 + reverse, 65 + cursor, 66 + rkeyStart, 67 + rkeyEnd, 68 + None, 69 + ) 70 + .await? 71 + .into_iter() 72 + .map(|record| { 73 + Ok(Record { 74 + uri: record.uri.clone(), 75 + cid: record.cid.clone(), 76 + value: serde_json::to_value(record)?, 77 + }) 78 + }) 79 + .collect::<Result<Vec<Record>>>()?; 80 + 81 + let last_record = records.last(); 82 + let cursor: Option<String>; 83 + if let Some(last_record) = last_record { 84 + let last_at_uri: AtUri = last_record.uri.clone().try_into()?; 85 + cursor = Some(last_at_uri.get_rkey()); 86 + } else { 87 + cursor = None; 88 + } 89 + Ok(ListRecordsOutput { records, cursor }) 90 + } else { 91 + bail!("Could not find repo: {repo}") 92 + } 93 + } 94 + 95 + /// List a range of records in a repository, matching a specific collection. Does not require auth. 96 + /// - GET /xrpc/com.atproto.repo.listRecords 97 + /// ### Query Parameters 98 + /// - `repo`: `at-identifier` // The handle or DID of the repo. 99 + /// - `collection`: `nsid` // The NSID of the record type. 
100 + /// - `limit`: `integer` // The maximum number of records to return. Default 50, >=1 and <=100. 101 + /// - `cursor`: `string` 102 + /// - `reverse`: `boolean` // Flag to reverse the order of the returned records. 103 + /// ### Responses 104 + /// - 200 OK: {"cursor": "string","records": [{"uri": "string","cid": "string","value": {}}]} 105 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 106 + /// - 401 Unauthorized 107 + #[tracing::instrument(skip_all)] 108 + #[allow(non_snake_case)] 109 + #[axum::debug_handler(state = AppState)] 110 + pub async fn list_records( 111 + Query(input): Query<atrium_repo::list_records::ParametersData>, 112 + State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>, 113 + State(account_manager): State<Arc<RwLock<AccountManager>>>, 114 + ) -> Result<Json<ListRecordsOutput>, ApiError> { 115 + let repo = input.repo; 116 + let collection = input.collection; 117 + let limit: Option<u8> = input.limit.map(u8::from); 118 + let limit: Option<u16> = limit.map(|x| x.into()); 119 + let cursor = input.cursor; 120 + let reverse = input.reverse; 121 + let rkeyStart = None; 122 + let rkeyEnd = None; 123 + 124 + let limit = limit.unwrap_or(50); 125 + let reverse = reverse.unwrap_or(false); 126 + 127 + match inner_list_records( 128 + repo.into(), 129 + collection.into(), 130 + limit, 131 + cursor, 132 + rkeyStart, 133 + rkeyEnd, 134 + reverse, 135 + actor_pools, 136 + account_manager, 137 + ) 138 + .await 139 + { 140 + Ok(res) => Ok(Json(res)), 141 + Err(error) => { 142 + tracing::error!("@LOG: ERROR: {error}"); 143 + Err(ApiError::RuntimeError) 144 + } 145 + } 146 + }
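Because the returned cursor is the rkey of the last record in the page, clients paginate by echoing it back. A hypothetical consumer loop (host and identifiers are placeholders):

```rust
// Hypothetical pagination against com.atproto.repo.listRecords.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let client = reqwest::Client::new();
    let mut cursor: Option<String> = None;
    loop {
        let mut params = vec![
            ("repo", "alice.example.com".to_owned()),
            ("collection", "app.bsky.feed.post".to_owned()),
            ("limit", "100".to_owned()),
        ];
        if let Some(c) = cursor.take() {
            params.push(("cursor", c));
        }
        let page: serde_json::Value = client
            .get("https://pds.example.com/xrpc/com.atproto.repo.listRecords")
            .query(&params)
            .send()
            .await?
            .error_for_status()?
            .json()
            .await?;
        let records = page["records"].as_array().cloned().unwrap_or_default();
        println!("fetched {} records", records.len());
        match page["cursor"].as_str() {
            Some(next) if !records.is_empty() => cursor = Some(next.to_owned()),
            _ => break, // no cursor or empty page: done
        }
    }
    Ok(())
}
```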
+92 -26
src/apis/com/atproto/repo/mod.rs
··· 1 - use atrium_api::com::atproto::repo; 2 - use axum::{Router, routing::post}; 1 + use atrium_api::com::atproto::repo as atrium_repo; 2 + use axum::{ 3 + Router, 4 + routing::{get, post}, 5 + }; 3 6 use constcat::concat; 4 7 5 - use crate::AppState; 6 - 7 8 pub mod apply_writes; 8 - // pub mod create_record; 9 - // pub mod delete_record; 10 - // pub mod describe_repo; 11 - // pub mod get_record; 12 - // pub mod import_repo; 13 - // pub mod list_missing_blobs; 14 - // pub mod list_records; 15 - // pub mod put_record; 16 - // pub mod upload_blob; 9 + pub mod create_record; 10 + pub mod delete_record; 11 + pub mod describe_repo; 12 + pub mod get_record; 13 + pub mod import_repo; 14 + pub mod list_missing_blobs; 15 + pub mod list_records; 16 + pub mod put_record; 17 + pub mod upload_blob; 18 + 19 + use crate::account_manager::AccountManager; 20 + use crate::account_manager::helpers::account::AvailabilityFlags; 21 + use crate::{ 22 + actor_store::ActorStore, 23 + auth::AuthenticatedUser, 24 + error::ApiError, 25 + serve::{ActorStorage, AppState}, 26 + }; 27 + use anyhow::{Result, bail}; 28 + use axum::extract::Query; 29 + use axum::{Json, extract::State}; 30 + use cidv10::Cid; 31 + use futures::stream::{self, StreamExt}; 32 + use rsky_identity::IdResolver; 33 + use rsky_identity::types::DidDocument; 34 + use rsky_lexicon::com::atproto::repo::DeleteRecordInput; 35 + use rsky_lexicon::com::atproto::repo::DescribeRepoOutput; 36 + use rsky_lexicon::com::atproto::repo::GetRecordOutput; 37 + use rsky_lexicon::com::atproto::repo::{ApplyWritesInput, ApplyWritesInputRefWrite}; 38 + use rsky_lexicon::com::atproto::repo::{CreateRecordInput, CreateRecordOutput}; 39 + use rsky_lexicon::com::atproto::repo::{ListRecordsOutput, Record}; 40 + // use rsky_pds::pipethrough::{OverrideOpts, ProxyRequest, pipethrough}; 41 + use rsky_pds::repo::prepare::{ 42 + PrepareCreateOpts, PrepareDeleteOpts, PrepareUpdateOpts, prepare_create, prepare_delete, 43 + prepare_update, 44 + }; 45 + use rsky_pds::sequencer::Sequencer; 46 + use rsky_repo::types::PreparedDelete; 47 + use rsky_repo::types::PreparedWrite; 48 + use rsky_syntax::aturi::AtUri; 49 + use rsky_syntax::handle::INVALID_HANDLE; 50 + use std::collections::HashMap; 51 + use std::hash::RandomState; 52 + use std::str::FromStr; 53 + use std::sync::Arc; 54 + use tokio::sync::RwLock; 17 55 18 56 /// These endpoints are part of the atproto PDS repository management APIs. \ 19 57 /// Requests usually require authentication (unlike the com.atproto.sync.* endpoints), and are made directly to the user's own PDS instance. 
··· 29 67 /// - [ ] xx /xrpc/com.atproto.repo.importRepo 30 68 // - [ ] xx /xrpc/com.atproto.repo.listMissingBlobs 31 69 pub(crate) fn routes() -> Router<AppState> { 32 - Router::new().route( 33 - concat!("/", repo::apply_writes::NSID), 34 - post(apply_writes::apply_writes), 35 - ) 36 - // .route(concat!("/", repo::create_record::NSID), post(create_record)) 37 - // .route(concat!("/", repo::put_record::NSID), post(put_record)) 38 - // .route(concat!("/", repo::delete_record::NSID), post(delete_record)) 39 - // .route(concat!("/", repo::upload_blob::NSID), post(upload_blob)) 40 - // .route(concat!("/", repo::describe_repo::NSID), get(describe_repo)) 41 - // .route(concat!("/", repo::get_record::NSID), get(get_record)) 42 - // .route(concat!("/", repo::import_repo::NSID), post(todo)) 43 - // .route(concat!("/", repo::list_missing_blobs::NSID), get(todo)) 44 - // .route(concat!("/", repo::list_records::NSID), get(list_records)) 70 + Router::new() 71 + .route( 72 + concat!("/", atrium_repo::apply_writes::NSID), 73 + post(apply_writes::apply_writes), 74 + ) 75 + .route( 76 + concat!("/", atrium_repo::create_record::NSID), 77 + post(create_record::create_record), 78 + ) 79 + .route( 80 + concat!("/", atrium_repo::put_record::NSID), 81 + post(put_record::put_record), 82 + ) 83 + .route( 84 + concat!("/", atrium_repo::delete_record::NSID), 85 + post(delete_record::delete_record), 86 + ) 87 + .route( 88 + concat!("/", atrium_repo::upload_blob::NSID), 89 + post(upload_blob::upload_blob), 90 + ) 91 + .route( 92 + concat!("/", atrium_repo::describe_repo::NSID), 93 + get(describe_repo::describe_repo), 94 + ) 95 + .route( 96 + concat!("/", atrium_repo::get_record::NSID), 97 + get(get_record::get_record), 98 + ) 99 + .route( 100 + concat!("/", atrium_repo::import_repo::NSID), 101 + post(import_repo::import_repo), 102 + ) 103 + .route( 104 + concat!("/", atrium_repo::list_missing_blobs::NSID), 105 + get(list_missing_blobs::list_missing_blobs), 106 + ) 107 + .route( 108 + concat!("/", atrium_repo::list_records::NSID), 109 + get(list_records::list_records), 110 + ) 45 111 }
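The route paths here are bare NSIDs (`/com.atproto.repo.applyWrites`, etc.), so presumably the `serve` module nests them under the XRPC prefix. A sketch of that wiring; this is an assumption, since `serve.rs` is not part of this diff:

```rust
// Assumed wiring; the actual serve module is not shown in this PR.
use axum::Router;

fn app(state: AppState) -> Router {
    Router::new()
        .nest("/xrpc", routes()) // mounts e.g. /xrpc/com.atproto.repo.getRecord
        .with_state(state)
}
```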
+157
src/apis/com/atproto/repo/put_record.rs
··· 1 + //! Write a repository record, creating or updating it as needed. Requires auth, implemented by PDS. 2 + use anyhow::bail; 3 + use rsky_lexicon::com::atproto::repo::{PutRecordInput, PutRecordOutput}; 4 + use rsky_repo::types::CommitDataWithOps; 5 + 6 + use super::*; 7 + 8 + #[tracing::instrument(skip_all)] 9 + async fn inner_put_record( 10 + body: PutRecordInput, 11 + auth: AuthenticatedUser, 12 + sequencer: Arc<RwLock<Sequencer>>, 13 + actor_pools: HashMap<String, ActorStorage>, 14 + account_manager: Arc<RwLock<AccountManager>>, 15 + ) -> Result<PutRecordOutput> { 16 + let PutRecordInput { 17 + repo, 18 + collection, 19 + rkey, 20 + validate, 21 + record, 22 + swap_record, 23 + swap_commit, 24 + } = body; 25 + let account = account_manager 26 + .read() 27 + .await 28 + .get_account( 29 + &repo, 30 + Some(AvailabilityFlags { 31 + include_deactivated: Some(true), 32 + include_taken_down: None, 33 + }), 34 + ) 35 + .await?; 36 + if let Some(account) = account { 37 + if account.deactivated_at.is_some() { 38 + bail!("Account is deactivated") 39 + } 40 + let did = account.did; 41 + // if did != auth.access.credentials.unwrap().did.unwrap() { 42 + if did != auth.did() { 43 + bail!("AuthRequiredError") 44 + } 45 + let uri = AtUri::make(did.clone(), Some(collection.clone()), Some(rkey.clone()))?; 46 + let swap_commit_cid = match swap_commit { 47 + Some(swap_commit) => Some(Cid::from_str(&swap_commit)?), 48 + None => None, 49 + }; 50 + let swap_record_cid = match swap_record { 51 + Some(swap_record) => Some(Cid::from_str(&swap_record)?), 52 + None => None, 53 + }; 54 + let (commit, write): (Option<CommitDataWithOps>, PreparedWrite) = { 55 + let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await; 56 + 57 + let current = actor_store 58 + .record 59 + .get_record(&uri, None, Some(true)) 60 + .await?; 61 + tracing::debug!("@LOG: debug inner_put_record, current: {current:?}"); 62 + let write: PreparedWrite = if current.is_some() { 63 + PreparedWrite::Update( 64 + prepare_update(PrepareUpdateOpts { 65 + did: did.clone(), 66 + collection, 67 + rkey, 68 + swap_cid: swap_record_cid, 69 + record: serde_json::from_value(record)?, 70 + validate, 71 + }) 72 + .await?, 73 + ) 74 + } else { 75 + PreparedWrite::Create( 76 + prepare_create(PrepareCreateOpts { 77 + did: did.clone(), 78 + collection, 79 + rkey: Some(rkey), 80 + swap_cid: swap_record_cid, 81 + record: serde_json::from_value(record)?, 82 + validate, 83 + }) 84 + .await?, 85 + ) 86 + }; 87 + 88 + match current { 89 + Some(current) if current.cid == write.cid().expect("write cid").to_string() => { 90 + (None, write) 91 + } 92 + _ => { 93 + let commit = actor_store 94 + .process_writes(vec![write.clone()], swap_commit_cid) 95 + .await?; 96 + (Some(commit), write) 97 + } 98 + } 99 + }; 100 + 101 + if let Some(commit) = commit { 102 + _ = sequencer 103 + .write() 104 + .await 105 + .sequence_commit(did.clone(), commit.clone()) 106 + .await?; 107 + account_manager 108 + .write() 109 + .await 110 + .update_repo_root( 111 + did, 112 + commit.commit_data.cid, 113 + commit.commit_data.rev, 114 + &actor_pools, 115 + ) 116 + .await?; 117 + } 118 + Ok(PutRecordOutput { 119 + uri: write.uri().to_string(), 120 + cid: write.cid().expect("write cid").to_string(), 121 + }) 122 + } else { 123 + bail!("Could not find repo: `{repo}`") 124 + } 125 + } 126 + 127 + /// Write a repository record, creating or updating it as needed. Requires auth, implemented by PDS. 
128 + /// - POST /xrpc/com.atproto.repo.putRecord 129 + /// ### Request Body 130 + /// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account). 131 + /// - `collection`: `nsid` // The NSID of the record collection. 132 + /// - `rkey`: `string` // The record key. <= 512 characters. 133 + /// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data, 'true' to require it, or leave unset to validate only for known Lexicons. 134 + /// - `record` 135 + /// - `swap_record`: `cid` // Compare and swap with the previous record by CID. WARNING: nullable and optional field; may cause problems with golang implementation 136 + /// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID. 137 + /// ### Responses 138 + /// - 200 OK: {"uri": "string","cid": "string","commit": {"cid": "string","rev": "string"},"validationStatus": "valid | unknown"} 139 + /// - 400 Bad Request: {error:"`InvalidRequest` | `ExpiredToken` | `InvalidToken` | `InvalidSwap`"} 140 + /// - 401 Unauthorized 141 + #[tracing::instrument(skip_all)] 142 + pub async fn put_record( 143 + auth: AuthenticatedUser, 144 + State(sequencer): State<Arc<RwLock<Sequencer>>>, 145 + State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>, 146 + State(account_manager): State<Arc<RwLock<AccountManager>>>, 147 + Json(body): Json<PutRecordInput>, 148 + ) -> Result<Json<PutRecordOutput>, ApiError> { 149 + tracing::debug!("@LOG: debug put_record {body:#?}"); 150 + match inner_put_record(body, auth, sequencer, actor_pools, account_manager).await { 151 + Ok(res) => Ok(Json(res)), 152 + Err(error) => { 153 + tracing::error!("@LOG: ERROR: {error}"); 154 + Err(ApiError::RuntimeError) 155 + } 156 + } 157 + }
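The handler deserializes a standard putRecord JSON body. A hypothetical example follows; all values are placeholders, and the optional swap fields are omitted to sidestep their serialization casing:

```rust
// Hypothetical putRecord body; the single-word field names shown here match
// PutRecordInput regardless of rename conventions.
fn example_put_record_body() -> serde_json::Value {
    serde_json::json!({
        "repo": "alice.example.com",
        "collection": "app.bsky.feed.post",
        "rkey": "3jt6walwmos2y",
        "validate": true,
        "record": {
            "$type": "app.bsky.feed.post",
            "text": "hello from a sketch",
            "createdAt": "2025-01-01T00:00:00.000Z"
        }
    })
}
```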
-514
src/apis/com/atproto/repo/repo.rs
··· 1 - //! PDS repository endpoints /xrpc/com.atproto.repo.*) 2 - mod apply_writes; 3 - pub(crate) use apply_writes::apply_writes; 4 - 5 - use std::{collections::HashSet, str::FromStr}; 6 - 7 - use anyhow::{Context as _, anyhow}; 8 - use atrium_api::com::atproto::repo::apply_writes::{ 9 - self as atrium_apply_writes, InputWritesItem, OutputResultsItem, 10 - }; 11 - use atrium_api::{ 12 - com::atproto::repo::{self, defs::CommitMetaData}, 13 - types::{ 14 - LimitedU32, Object, TryFromUnknown as _, TryIntoUnknown as _, Unknown, 15 - string::{AtIdentifier, Nsid, Tid}, 16 - }, 17 - }; 18 - use atrium_repo::{Cid, blockstore::CarStore}; 19 - use axum::{ 20 - Json, Router, 21 - body::Body, 22 - extract::{Query, Request, State}, 23 - http::{self, StatusCode}, 24 - routing::{get, post}, 25 - }; 26 - use constcat::concat; 27 - use futures::TryStreamExt as _; 28 - use metrics::counter; 29 - use rsky_syntax::aturi::AtUri; 30 - use serde::Deserialize; 31 - use tokio::io::AsyncWriteExt as _; 32 - 33 - use crate::repo::block_map::cid_for_cbor; 34 - use crate::repo::types::PreparedCreateOrUpdate; 35 - use crate::{ 36 - AppState, Db, Error, Result, SigningKey, 37 - actor_store::{ActorStoreTransactor, ActorStoreWriter}, 38 - auth::AuthenticatedUser, 39 - config::AppConfig, 40 - error::ErrorMessage, 41 - firehose::{self, FirehoseProducer, RepoOp}, 42 - metrics::{REPO_COMMITS, REPO_OP_CREATE, REPO_OP_DELETE, REPO_OP_UPDATE}, 43 - repo::types::{PreparedWrite, WriteOpAction}, 44 - storage, 45 - }; 46 - 47 - #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)] 48 - #[serde(rename_all = "camelCase")] 49 - /// Parameters for [`list_records`]. 50 - pub(super) struct ListRecordsParameters { 51 - ///The NSID of the record type. 52 - pub collection: Nsid, 53 - /// The cursor to start from. 54 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 55 - pub cursor: Option<String>, 56 - ///The number of records to return. 57 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 58 - pub limit: Option<String>, 59 - ///The handle or DID of the repo. 60 - pub repo: AtIdentifier, 61 - ///Flag to reverse the order of the returned records. 62 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 63 - pub reverse: Option<bool>, 64 - ///DEPRECATED: The highest sort-ordered rkey to stop at (exclusive) 65 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 66 - pub rkey_end: Option<String>, 67 - ///DEPRECATED: The lowest sort-ordered rkey to start from (exclusive) 68 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 69 - pub rkey_start: Option<String>, 70 - } 71 - 72 - /// Resolve DID to DID document. Does not bi-directionally verify handle. 73 - /// - GET /xrpc/com.atproto.repo.resolveDid 74 - /// ### Query Parameters 75 - /// - `did`: DID to resolve. 
76 - /// ### Responses 77 - /// - 200 OK: {`did_doc`: `did_doc`} 78 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `DidNotFound`, `DidDeactivated`]} 79 - async fn resolve_did( 80 - db: &Db, 81 - identifier: &AtIdentifier, 82 - ) -> anyhow::Result<( 83 - atrium_api::types::string::Did, 84 - atrium_api::types::string::Handle, 85 - )> { 86 - let (handle, did) = match *identifier { 87 - AtIdentifier::Handle(ref handle) => { 88 - let handle_as_str = &handle.as_str(); 89 - ( 90 - &handle.to_owned(), 91 - &atrium_api::types::string::Did::new( 92 - sqlx::query_scalar!( 93 - r#"SELECT did FROM handles WHERE handle = ?"#, 94 - handle_as_str 95 - ) 96 - .fetch_one(db) 97 - .await 98 - .context("failed to query did")?, 99 - ) 100 - .expect("should be valid DID"), 101 - ) 102 - } 103 - AtIdentifier::Did(ref did) => { 104 - let did_as_str = &did.as_str(); 105 - ( 106 - &atrium_api::types::string::Handle::new( 107 - sqlx::query_scalar!(r#"SELECT handle FROM handles WHERE did = ?"#, did_as_str) 108 - .fetch_one(db) 109 - .await 110 - .context("failed to query did")?, 111 - ) 112 - .expect("should be valid handle"), 113 - &did.to_owned(), 114 - ) 115 - } 116 - }; 117 - 118 - Ok((did.to_owned(), handle.to_owned())) 119 - } 120 - 121 - /// Create a single new repository record. Requires auth, implemented by PDS. 122 - /// - POST /xrpc/com.atproto.repo.createRecord 123 - /// ### Request Body 124 - /// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account). 125 - /// - `collection`: `nsid` // The NSID of the record collection. 126 - /// - `rkey`: `string` // The record key. <= 512 characters. 127 - /// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data, 'true' to require it, or leave unset to validate only for known Lexicons. 128 - /// - `record` 129 - /// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID. 
130 - /// ### Responses 131 - /// - 200 OK: {`cid`: `cid`, `uri`: `at-uri`, `commit`: {`cid`: `cid`, `rev`: `tid`}, `validation_status`: [`valid`, `unknown`]} 132 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidSwap`]} 133 - /// - 401 Unauthorized 134 - async fn create_record( 135 - user: AuthenticatedUser, 136 - State(actor_store): State<ActorStore>, 137 - State(skey): State<SigningKey>, 138 - State(config): State<AppConfig>, 139 - State(db): State<Db>, 140 - State(fhp): State<FirehoseProducer>, 141 - Json(input): Json<repo::create_record::Input>, 142 - ) -> Result<Json<repo::create_record::Output>> { 143 - todo!(); 144 - // let write_result = apply_writes::apply_writes( 145 - // user, 146 - // State(actor_store), 147 - // State(skey), 148 - // State(config), 149 - // State(db), 150 - // State(fhp), 151 - // Json( 152 - // repo::apply_writes::InputData { 153 - // repo: input.repo.clone(), 154 - // validate: input.validate, 155 - // swap_commit: input.swap_commit.clone(), 156 - // writes: vec![repo::apply_writes::InputWritesItem::Create(Box::new( 157 - // repo::apply_writes::CreateData { 158 - // collection: input.collection.clone(), 159 - // rkey: input.rkey.clone(), 160 - // value: input.record.clone(), 161 - // } 162 - // .into(), 163 - // ))], 164 - // } 165 - // .into(), 166 - // ), 167 - // ) 168 - // .await 169 - // .context("failed to apply writes")?; 170 - 171 - // let create_result = if let repo::apply_writes::OutputResultsItem::CreateResult(create_result) = 172 - // write_result 173 - // .results 174 - // .clone() 175 - // .and_then(|result| result.first().cloned()) 176 - // .context("unexpected output from apply_writes")? 177 - // { 178 - // Some(create_result) 179 - // } else { 180 - // None 181 - // } 182 - // .context("unexpected result from apply_writes")?; 183 - 184 - // Ok(Json( 185 - // repo::create_record::OutputData { 186 - // cid: create_result.cid.clone(), 187 - // commit: write_result.commit.clone(), 188 - // uri: create_result.uri.clone(), 189 - // validation_status: Some("unknown".to_owned()), 190 - // } 191 - // .into(), 192 - // )) 193 - } 194 - 195 - /// Write a repository record, creating or updating it as needed. Requires auth, implemented by PDS. 196 - /// - POST /xrpc/com.atproto.repo.putRecord 197 - /// ### Request Body 198 - /// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account). 199 - /// - `collection`: `nsid` // The NSID of the record collection. 200 - /// - `rkey`: `string` // The record key. <= 512 characters. 201 - /// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data, 'true' to require it, or leave unset to validate only for known Lexicons. 202 - /// - `record` 203 - /// - `swap_record`: `boolean` // Compare and swap with the previous record by CID. WARNING: nullable and optional field; may cause problems with golang implementation 204 - /// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID. 
205 - /// ### Responses 206 - /// - 200 OK: {"uri": "string","cid": "string","commit": {"cid": "string","rev": "string"},"validationStatus": "valid | unknown"} 207 - /// - 400 Bad Request: {error:"`InvalidRequest` | `ExpiredToken` | `InvalidToken` | `InvalidSwap`"} 208 - /// - 401 Unauthorized 209 - async fn put_record( 210 - user: AuthenticatedUser, 211 - State(actor_store): State<ActorStore>, 212 - State(skey): State<SigningKey>, 213 - State(config): State<AppConfig>, 214 - State(db): State<Db>, 215 - State(fhp): State<FirehoseProducer>, 216 - Json(input): Json<repo::put_record::Input>, 217 - ) -> Result<Json<repo::put_record::Output>> { 218 - todo!(); 219 - // // TODO: `input.swap_record` 220 - // // FIXME: "put" implies that we will create the record if it does not exist. 221 - // // We currently only update existing records and/or throw an error if one doesn't exist. 222 - // let input = (*input).clone(); 223 - // let input = repo::apply_writes::InputData { 224 - // repo: input.repo, 225 - // validate: input.validate, 226 - // swap_commit: input.swap_commit, 227 - // writes: vec![repo::apply_writes::InputWritesItem::Update(Box::new( 228 - // repo::apply_writes::UpdateData { 229 - // collection: input.collection, 230 - // rkey: input.rkey, 231 - // value: input.record, 232 - // } 233 - // .into(), 234 - // ))], 235 - // } 236 - // .into(); 237 - 238 - // let write_result = apply_writes::apply_writes( 239 - // user, 240 - // State(actor_store), 241 - // State(skey), 242 - // State(config), 243 - // State(db), 244 - // State(fhp), 245 - // Json(input), 246 - // ) 247 - // .await 248 - // .context("failed to apply writes")?; 249 - 250 - // let update_result = write_result 251 - // .results 252 - // .clone() 253 - // .and_then(|result| result.first().cloned()) 254 - // .context("unexpected output from apply_writes")?; 255 - // let (cid, uri) = match update_result { 256 - // repo::apply_writes::OutputResultsItem::CreateResult(create_result) => ( 257 - // Some(create_result.cid.clone()), 258 - // Some(create_result.uri.clone()), 259 - // ), 260 - // repo::apply_writes::OutputResultsItem::UpdateResult(update_result) => ( 261 - // Some(update_result.cid.clone()), 262 - // Some(update_result.uri.clone()), 263 - // ), 264 - // repo::apply_writes::OutputResultsItem::DeleteResult(_) => (None, None), 265 - // }; 266 - // Ok(Json( 267 - // repo::put_record::OutputData { 268 - // cid: cid.context("missing cid")?, 269 - // commit: write_result.commit.clone(), 270 - // uri: uri.context("missing uri")?, 271 - // validation_status: Some("unknown".to_owned()), 272 - // } 273 - // .into(), 274 - // )) 275 - } 276 - 277 - /// Delete a repository record, or ensure it doesn't exist. Requires auth, implemented by PDS. 278 - /// - POST /xrpc/com.atproto.repo.deleteRecord 279 - /// ### Request Body 280 - /// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account). 281 - /// - `collection`: `nsid` // The NSID of the record collection. 282 - /// - `rkey`: `string` // The record key. <= 512 characters. 283 - /// - `swap_record`: `boolean` // Compare and swap with the previous record by CID. 284 - /// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID. 
285 - /// ### Responses 286 - /// - 200 OK: {"commit": {"cid": "string","rev": "string"}} 287 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidSwap`]} 288 - /// - 401 Unauthorized 289 - async fn delete_record( 290 - user: AuthenticatedUser, 291 - State(actor_store): State<ActorStore>, 292 - State(skey): State<SigningKey>, 293 - State(config): State<AppConfig>, 294 - State(db): State<Db>, 295 - State(fhp): State<FirehoseProducer>, 296 - Json(input): Json<repo::delete_record::Input>, 297 - ) -> Result<Json<repo::delete_record::Output>> { 298 - todo!(); 299 - // // TODO: `input.swap_record` 300 - 301 - // Ok(Json( 302 - // repo::delete_record::OutputData { 303 - // commit: apply_writes::apply_writes( 304 - // user, 305 - // State(actor_store), 306 - // State(skey), 307 - // State(config), 308 - // State(db), 309 - // State(fhp), 310 - // Json( 311 - // repo::apply_writes::InputData { 312 - // repo: input.repo.clone(), 313 - // swap_commit: input.swap_commit.clone(), 314 - // validate: None, 315 - // writes: vec![repo::apply_writes::InputWritesItem::Delete(Box::new( 316 - // repo::apply_writes::DeleteData { 317 - // collection: input.collection.clone(), 318 - // rkey: input.rkey.clone(), 319 - // } 320 - // .into(), 321 - // ))], 322 - // } 323 - // .into(), 324 - // ), 325 - // ) 326 - // .await 327 - // .context("failed to apply writes")? 328 - // .commit 329 - // .clone(), 330 - // } 331 - // .into(), 332 - // )) 333 - } 334 - 335 - /// Get information about an account and repository, including the list of collections. Does not require auth. 336 - /// - GET /xrpc/com.atproto.repo.describeRepo 337 - /// ### Query Parameters 338 - /// - `repo`: `at-identifier` // The handle or DID of the repo. 339 - /// ### Responses 340 - /// - 200 OK: {"handle": "string","did": "string","didDoc": {},"collections": [string],"handleIsCorrect": true} \ 341 - /// handeIsCorrect - boolean - Indicates if handle is currently valid (resolves bi-directionally) 342 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 343 - /// - 401 Unauthorized 344 - async fn describe_repo( 345 - State(actor_store): State<ActorStore>, 346 - State(config): State<AppConfig>, 347 - State(db): State<Db>, 348 - Query(input): Query<repo::describe_repo::ParametersData>, 349 - ) -> Result<Json<repo::describe_repo::Output>> { 350 - // Lookup the DID by the provided handle. 351 - let (did, handle) = resolve_did(&db, &input.repo) 352 - .await 353 - .context("failed to resolve handle")?; 354 - 355 - // Use Actor Store to get the collections 356 - todo!(); 357 - } 358 - 359 - /// Get a single record from a repository. Does not require auth. 360 - /// - GET /xrpc/com.atproto.repo.getRecord 361 - /// ### Query Parameters 362 - /// - `repo`: `at-identifier` // The handle or DID of the repo. 363 - /// - `collection`: `nsid` // The NSID of the record collection. 364 - /// - `rkey`: `string` // The record key. <= 512 characters. 365 - /// - `cid`: `cid` // The CID of the version of the record. If not specified, then return the most recent version. 
366 - /// ### Responses 367 - /// - 200 OK: {"uri": "string","cid": "string","value": {}} 368 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RecordNotFound`]} 369 - /// - 401 Unauthorized 370 - async fn get_record( 371 - State(actor_store): State<ActorStore>, 372 - State(config): State<AppConfig>, 373 - State(db): State<Db>, 374 - Query(input): Query<repo::get_record::ParametersData>, 375 - ) -> Result<Json<repo::get_record::Output>> { 376 - if input.cid.is_some() { 377 - return Err(Error::unimplemented(anyhow!( 378 - "looking up old records is unsupported" 379 - ))); 380 - } 381 - 382 - // Lookup the DID by the provided handle. 383 - let (did, _handle) = resolve_did(&db, &input.repo) 384 - .await 385 - .context("failed to resolve handle")?; 386 - 387 - // Create a URI from the parameters 388 - let uri = format!( 389 - "at://{}/{}/{}", 390 - did.as_str(), 391 - input.collection.as_str(), 392 - input.rkey.as_str() 393 - ); 394 - 395 - // Use Actor Store to get the record 396 - todo!(); 397 - } 398 - 399 - /// List a range of records in a repository, matching a specific collection. Does not require auth. 400 - /// - GET /xrpc/com.atproto.repo.listRecords 401 - /// ### Query Parameters 402 - /// - `repo`: `at-identifier` // The handle or DID of the repo. 403 - /// - `collection`: `nsid` // The NSID of the record type. 404 - /// - `limit`: `integer` // The maximum number of records to return. Default 50, >=1 and <=100. 405 - /// - `cursor`: `string` 406 - /// - `reverse`: `boolean` // Flag to reverse the order of the returned records. 407 - /// ### Responses 408 - /// - 200 OK: {"cursor": "string","records": [{"uri": "string","cid": "string","value": {}}]} 409 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 410 - /// - 401 Unauthorized 411 - async fn list_records( 412 - State(actor_store): State<ActorStore>, 413 - State(config): State<AppConfig>, 414 - State(db): State<Db>, 415 - Query(input): Query<Object<ListRecordsParameters>>, 416 - ) -> Result<Json<repo::list_records::Output>> { 417 - // Lookup the DID by the provided handle. 418 - let (did, _handle) = resolve_did(&db, &input.repo) 419 - .await 420 - .context("failed to resolve handle")?; 421 - 422 - // Use Actor Store to list records for the collection 423 - todo!(); 424 - } 425 - 426 - /// Upload a new blob, to be referenced from a repository record. \ 427 - /// The blob will be deleted if it is not referenced within a time window (eg, minutes). \ 428 - /// Blob restrictions (mimetype, size, etc) are enforced when the reference is created. \ 429 - /// Requires auth, implemented by PDS. 430 - /// - POST /xrpc/com.atproto.repo.uploadBlob 431 - /// ### Request Body 432 - /// ### Responses 433 - /// - 200 OK: {"blob": "binary"} 434 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 435 - /// - 401 Unauthorized 436 - async fn upload_blob( 437 - user: AuthenticatedUser, 438 - State(actor_store): State<ActorStore>, 439 - State(config): State<AppConfig>, 440 - State(db): State<Db>, 441 - request: Request<Body>, 442 - ) -> Result<Json<repo::upload_blob::Output>> { 443 - let length = request 444 - .headers() 445 - .get(http::header::CONTENT_LENGTH) 446 - .context("no content length provided")? 
447 - .to_str() 448 - .map_err(anyhow::Error::from) 449 - .and_then(|content_length| content_length.parse::<u64>().map_err(anyhow::Error::from)) 450 - .context("invalid content-length header")?; 451 - let mime = request 452 - .headers() 453 - .get(http::header::CONTENT_TYPE) 454 - .context("no content-type provided")? 455 - .to_str() 456 - .context("invalid content-type provided")? 457 - .to_owned(); 458 - 459 - if length > config.blob.limit { 460 - return Err(Error::with_status( 461 - StatusCode::PAYLOAD_TOO_LARGE, 462 - anyhow!("size {} above limit {}", length, config.blob.limit), 463 - )); 464 - } 465 - 466 - // Read the blob data 467 - let mut body_data = Vec::new(); 468 - let mut stream = request.into_body().into_data_stream(); 469 - while let Some(bytes) = stream.try_next().await.context("failed to receive file")? { 470 - body_data.extend_from_slice(&bytes); 471 - 472 - // Check size limit incrementally 473 - if body_data.len() as u64 > config.blob.limit { 474 - return Err(Error::with_status( 475 - StatusCode::PAYLOAD_TOO_LARGE, 476 - anyhow!("size above limit and content-length header was wrong"), 477 - )); 478 - } 479 - } 480 - 481 - // Use Actor Store to upload the blob 482 - todo!(); 483 - } 484 - 485 - async fn todo() -> Result<()> { 486 - Err(Error::unimplemented(anyhow!("not implemented"))) 487 - } 488 - 489 - /// These endpoints are part of the atproto PDS repository management APIs. \ 490 - /// Requests usually require authentication (unlike the com.atproto.sync.* endpoints), and are made directly to the user's own PDS instance. 491 - /// ### Routes 492 - /// - AP /xrpc/com.atproto.repo.applyWrites -> [`apply_writes`] 493 - /// - AP /xrpc/com.atproto.repo.createRecord -> [`create_record`] 494 - /// - AP /xrpc/com.atproto.repo.putRecord -> [`put_record`] 495 - /// - AP /xrpc/com.atproto.repo.deleteRecord -> [`delete_record`] 496 - /// - AP /xrpc/com.atproto.repo.uploadBlob -> [`upload_blob`] 497 - /// - UG /xrpc/com.atproto.repo.describeRepo -> [`describe_repo`] 498 - /// - UG /xrpc/com.atproto.repo.getRecord -> [`get_record`] 499 - /// - UG /xrpc/com.atproto.repo.listRecords -> [`list_records`] 500 - /// - [ ] xx /xrpc/com.atproto.repo.importRepo 501 - // - [ ] xx /xrpc/com.atproto.repo.listMissingBlobs 502 - pub(super) fn routes() -> Router<AppState> { 503 - Router::new() 504 - .route(concat!("/", repo::apply_writes::NSID), post(apply_writes)) 505 - // .route(concat!("/", repo::create_record::NSID), post(create_record)) 506 - // .route(concat!("/", repo::put_record::NSID), post(put_record)) 507 - // .route(concat!("/", repo::delete_record::NSID), post(delete_record)) 508 - // .route(concat!("/", repo::upload_blob::NSID), post(upload_blob)) 509 - // .route(concat!("/", repo::describe_repo::NSID), get(describe_repo)) 510 - // .route(concat!("/", repo::get_record::NSID), get(get_record)) 511 - .route(concat!("/", repo::import_repo::NSID), post(todo)) 512 - .route(concat!("/", repo::list_missing_blobs::NSID), get(todo)) 513 - // .route(concat!("/", repo::list_records::NSID), get(list_records)) 514 - }
+117
src/apis/com/atproto/repo/upload_blob.rs
··· 1 + //! Upload a new blob, to be referenced from a repository record. 2 + use crate::config::AppConfig; 3 + use anyhow::Context as _; 4 + use axum::{ 5 + body::Bytes, 6 + http::{self, HeaderMap}, 7 + }; 8 + use rsky_lexicon::com::atproto::repo::{Blob, BlobOutput}; 9 + use rsky_repo::types::{BlobConstraint, PreparedBlobRef}; 10 + // use rsky_common::BadContentTypeError; 11 + 12 + use super::*; 13 + 14 + async fn inner_upload_blob( 15 + auth: AuthenticatedUser, 16 + blob: Bytes, 17 + content_type: String, 18 + actor_pools: HashMap<String, ActorStorage>, 19 + ) -> Result<BlobOutput> { 20 + // let requester = auth.access.credentials.unwrap().did.unwrap(); 21 + let requester = auth.did(); 22 + 23 + let actor_store = ActorStore::from_actor_pools(&requester, &actor_pools).await; 24 + 25 + let metadata = actor_store 26 + .blob 27 + .upload_blob_and_get_metadata(content_type, blob) 28 + .await?; 29 + let blobref = actor_store.blob.track_untethered_blob(metadata).await?; 30 + 31 + // make the blob permanent if an associated record is already indexed 32 + let records_for_blob = actor_store 33 + .blob 34 + .get_records_for_blob(blobref.get_cid()?) 35 + .await?; 36 + 37 + if !records_for_blob.is_empty() { 38 + actor_store 39 + .blob 40 + .verify_blob_and_make_permanent(PreparedBlobRef { 41 + cid: blobref.get_cid()?, 42 + mime_type: blobref.get_mime_type().to_string(), 43 + constraints: BlobConstraint { 44 + max_size: None, 45 + accept: None, 46 + }, 47 + }) 48 + .await?; 49 + } 50 + 51 + Ok(BlobOutput { 52 + blob: Blob { 53 + r#type: Some("blob".to_owned()), 54 + r#ref: Some(blobref.get_cid()?), 55 + cid: None, 56 + mime_type: blobref.get_mime_type().to_string(), 57 + size: blobref.get_size(), 58 + original: None, 59 + }, 60 + }) 61 + } 62 + 63 + /// Upload a new blob, to be referenced from a repository record. \ 64 + /// The blob will be deleted if it is not referenced within a time window (eg, minutes). \ 65 + /// Blob restrictions (mimetype, size, etc) are enforced when the reference is created. \ 66 + /// Requires auth, implemented by PDS. 67 + /// - POST /xrpc/com.atproto.repo.uploadBlob 68 + /// ### Request Body 69 + /// ### Responses 70 + /// - 200 OK: {"blob": "binary"} 71 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 72 + /// - 401 Unauthorized 73 + #[tracing::instrument(skip_all)] 74 + #[axum::debug_handler(state = AppState)] 75 + pub async fn upload_blob( 76 + auth: AuthenticatedUser, 77 + headers: HeaderMap, 78 + State(config): State<AppConfig>, 79 + State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>, 80 + blob: Bytes, 81 + ) -> Result<Json<BlobOutput>, ApiError> { 82 + let content_length = headers 83 + .get(http::header::CONTENT_LENGTH) 84 + .context("no content length provided")? 85 + .to_str() 86 + .map_err(anyhow::Error::from) 87 + .and_then(|content_length| content_length.parse::<u64>().map_err(anyhow::Error::from)) 88 + .context("invalid content-length header")?; 89 + let content_type = headers 90 + .get(http::header::CONTENT_TYPE) 91 + .context("no content-type provided")? 92 + .to_str() 93 + // .map_err(BadContentTypeError::MissingType) 94 + .context("invalid content-type provided")? 
95 + .to_owned(); 96 + 97 + if content_length > config.blob.limit { 98 + return Err(ApiError::InvalidRequest(format!( 99 + "Content-Length is greater than maximum of {}", 100 + config.blob.limit 101 + ))); 102 + }; 103 + if blob.len() as u64 > config.blob.limit { 104 + return Err(ApiError::InvalidRequest(format!( 105 + "Blob size is greater than maximum of {} despite content-length header", 106 + config.blob.limit 107 + ))); 108 + }; 109 + 110 + match inner_upload_blob(auth, blob, content_type, actor_pools).await { 111 + Ok(res) => Ok(Json(res)), 112 + Err(error) => { 113 + tracing::error!("{error:?}"); 114 + Err(ApiError::RuntimeError) 115 + } 116 + } 117 + }
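Client-side, the two headers checked above (Content-Length and Content-Type) are the whole contract; the blob then stays "untethered" until a record references it. A hypothetical upload, where the host, file, and token are placeholders:

```rust
// Hypothetical client call to com.atproto.repo.uploadBlob.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    let bytes = std::fs::read("avatar.png")?;
    let res = reqwest::Client::new()
        .post("https://pds.example.com/xrpc/com.atproto.repo.uploadBlob")
        .bearer_auth("<access-jwt>")
        .header(reqwest::header::CONTENT_TYPE, "image/png")
        .body(bytes) // reqwest sets Content-Length from the body
        .send()
        .await?
        .error_for_status()?;
    println!("{}", res.text().await?);
    Ok(())
}
```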
+1 -1
src/apis/mod.rs
··· 7 7 use axum::{Json, Router, routing::get}; 8 8 use serde_json::json; 9 9 10 - use crate::{AppState, Result}; 10 + use crate::serve::{AppState, Result}; 11 11 12 12 /// Health check endpoint. Returns name and version of the service. 13 13 pub(crate) async fn health() -> Result<Json<serde_json::Value>> {
+4 -1
src/auth.rs
··· 8 8 use diesel::prelude::*; 9 9 use sha2::{Digest as _, Sha256}; 10 10 11 - use crate::{AppState, Error, error::ErrorMessage}; 11 + use crate::{ 12 + error::{Error, ErrorMessage}, 13 + serve::AppState, 14 + }; 12 15 13 16 /// Request extractor for authenticated users. 14 17 /// If specified in an API endpoint, this guarantees the API can only be called
+1 -1
src/did.rs
··· 5 5 use serde::{Deserialize, Serialize}; 6 6 use url::Url; 7 7 8 - use crate::Client; 8 + use crate::serve::Client; 9 9 10 10 /// URL whitelist for DID document resolution. 11 11 const ALLOWED_URLS: &[&str] = &["bsky.app", "bsky.chat"];
+18 -11
src/error.rs
··· 148 148 149 149 impl ApiError { 150 150 /// Get the appropriate HTTP status code for this error 151 - fn status_code(&self) -> StatusCode { 151 + const fn status_code(&self) -> StatusCode { 152 152 match self { 153 153 Self::RuntimeError => StatusCode::INTERNAL_SERVER_ERROR, 154 154 Self::InvalidLogin ··· 190 190 Self::BadRequest(error, _) => error, 191 191 Self::AuthRequiredError(_) => "AuthRequiredError", 192 192 } 193 - .to_string() 193 + .to_owned() 194 194 } 195 195 196 196 /// Get the user-facing error message ··· 218 218 Self::BadRequest(_, msg) => msg, 219 219 Self::AuthRequiredError(msg) => msg, 220 220 } 221 - .to_string() 221 + .to_owned() 222 222 } 223 223 } 224 224 225 225 impl From<Error> for ApiError { 226 226 fn from(_value: Error) -> Self { 227 - ApiError::RuntimeError 227 + Self::RuntimeError 228 + } 229 + } 230 + 231 + impl From<anyhow::Error> for ApiError { 232 + fn from(_value: anyhow::Error) -> Self { 233 + Self::RuntimeError 228 234 } 229 235 } 230 236 231 237 impl From<handle::errors::Error> for ApiError { 232 238 fn from(value: handle::errors::Error) -> Self { 233 239 match value.kind { 234 - ErrorKind::InvalidHandle => ApiError::InvalidHandle, 235 - ErrorKind::HandleNotAvailable => ApiError::HandleNotAvailable, 236 - ErrorKind::UnsupportedDomain => ApiError::UnsupportedDomain, 237 - ErrorKind::InternalError => ApiError::RuntimeError, 240 + ErrorKind::InvalidHandle => Self::InvalidHandle, 241 + ErrorKind::HandleNotAvailable => Self::HandleNotAvailable, 242 + ErrorKind::UnsupportedDomain => Self::UnsupportedDomain, 243 + ErrorKind::InternalError => Self::RuntimeError, 238 244 } 239 245 } 240 246 } ··· 245 251 let error_type = self.error_type(); 246 252 let message = self.message(); 247 253 248 - // Log the error for debugging 249 - error!("API Error: {}: {}", error_type, message); 254 + if cfg!(debug_assertions) { 255 + error!("API Error: {}: {}", error_type, message); 256 + } 250 257 251 258 // Create the error message and serialize to JSON 252 259 let error_message = ErrorMessage::new(error_type, message); 253 260 let body = serde_json::to_string(&error_message).unwrap_or_else(|_| { 254 - r#"{"error":"InternalServerError","message":"Error serializing response"}"#.to_string() 261 + r#"{"error":"InternalServerError","message":"Error serializing response"}"#.to_owned() 255 262 }); 256 263 257 264 // Build the response
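The new `From<anyhow::Error>` impl means handlers can use `?` directly on anyhow-returning calls and surface a generic 500. A minimal sketch; `load_profile` is hypothetical:

```rust
// Sketch of error propagation through the new From<anyhow::Error> impl.
async fn load_profile() -> anyhow::Result<String> {
    Ok("profile".to_owned())
}

async fn handler() -> Result<axum::Json<String>, ApiError> {
    // An anyhow::Error here converts into ApiError::RuntimeError via `?`.
    let profile = load_profile().await?;
    Ok(axum::Json(profile))
}
```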
-426
src/firehose.rs
··· 1 - //! The firehose module. 2 - use std::{collections::VecDeque, time::Duration}; 3 - 4 - use anyhow::{Result, bail}; 5 - use atrium_api::{ 6 - com::atproto::sync::{self}, 7 - types::string::{Datetime, Did, Tid}, 8 - }; 9 - use atrium_repo::Cid; 10 - use axum::extract::ws::{Message, WebSocket}; 11 - use metrics::{counter, gauge}; 12 - use rand::Rng as _; 13 - use serde::{Serialize, ser::SerializeMap as _}; 14 - use tracing::{debug, error, info, warn}; 15 - 16 - use crate::{ 17 - Client, 18 - config::AppConfig, 19 - metrics::{FIREHOSE_HISTORY, FIREHOSE_LISTENERS, FIREHOSE_MESSAGES, FIREHOSE_SEQUENCE}, 20 - }; 21 - 22 - enum FirehoseMessage { 23 - Broadcast(sync::subscribe_repos::Message), 24 - Connect(Box<(WebSocket, Option<i64>)>), 25 - } 26 - 27 - enum FrameHeader { 28 - Error, 29 - Message(String), 30 - } 31 - 32 - impl Serialize for FrameHeader { 33 - #[expect(clippy::question_mark_used, reason = "returns a Result")] 34 - fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> 35 - where 36 - S: serde::Serializer, 37 - { 38 - let mut map = serializer.serialize_map(None)?; 39 - 40 - match *self { 41 - Self::Message(ref s) => { 42 - map.serialize_key("op")?; 43 - map.serialize_value(&1_i32)?; 44 - map.serialize_key("t")?; 45 - map.serialize_value(s.as_str())?; 46 - } 47 - Self::Error => { 48 - map.serialize_key("op")?; 49 - map.serialize_value(&-1_i32)?; 50 - } 51 - } 52 - 53 - map.end() 54 - } 55 - } 56 - 57 - /// A repository operation. 58 - pub(crate) enum RepoOp { 59 - /// Create a new record. 60 - Create { 61 - /// The CID of the record. 62 - cid: Cid, 63 - /// The path of the record. 64 - path: String, 65 - }, 66 - /// Delete an existing record. 67 - Delete { 68 - /// The path of the record. 69 - path: String, 70 - /// The previous CID of the record. 71 - prev: Cid, 72 - }, 73 - /// Update an existing record. 74 - Update { 75 - /// The CID of the record. 76 - cid: Cid, 77 - /// The path of the record. 78 - path: String, 79 - /// The previous CID of the record. 80 - prev: Cid, 81 - }, 82 - } 83 - 84 - impl From<RepoOp> for sync::subscribe_repos::RepoOp { 85 - fn from(val: RepoOp) -> Self { 86 - let (action, cid, prev, path) = match val { 87 - RepoOp::Create { cid, path } => ("create", Some(cid), None, path), 88 - RepoOp::Update { cid, path, prev } => ("update", Some(cid), Some(prev), path), 89 - RepoOp::Delete { path, prev } => ("delete", None, Some(prev), path), 90 - }; 91 - 92 - sync::subscribe_repos::RepoOpData { 93 - action: action.to_owned(), 94 - cid: cid.map(atrium_api::types::CidLink), 95 - prev: prev.map(atrium_api::types::CidLink), 96 - path, 97 - } 98 - .into() 99 - } 100 - } 101 - 102 - /// A commit to the repository. 103 - pub(crate) struct Commit { 104 - /// Blobs that were created in this commit. 105 - pub blobs: Vec<Cid>, 106 - /// The car file containing the commit blocks. 107 - pub car: Vec<u8>, 108 - /// The CID of the commit. 109 - pub cid: Cid, 110 - /// The DID of the repository changed. 111 - pub did: Did, 112 - /// The operations performed in this commit. 113 - pub ops: Vec<RepoOp>, 114 - /// The previous commit's CID (if applicable). 115 - pub pcid: Option<Cid>, 116 - /// The revision of the commit. 
117 - pub rev: String, 118 - } 119 - 120 - impl From<Commit> for sync::subscribe_repos::Commit { 121 - fn from(val: Commit) -> Self { 122 - sync::subscribe_repos::CommitData { 123 - blobs: val 124 - .blobs 125 - .into_iter() 126 - .map(atrium_api::types::CidLink) 127 - .collect::<Vec<_>>(), 128 - blocks: val.car, 129 - commit: atrium_api::types::CidLink(val.cid), 130 - ops: val.ops.into_iter().map(Into::into).collect::<Vec<_>>(), 131 - prev_data: val.pcid.map(atrium_api::types::CidLink), 132 - rebase: false, 133 - repo: val.did, 134 - rev: Tid::new(val.rev).expect("should be valid revision"), 135 - seq: 0, 136 - since: None, 137 - time: Datetime::now(), 138 - too_big: false, 139 - } 140 - .into() 141 - } 142 - } 143 - 144 - /// A firehose producer. This is used to transmit messages to the firehose for broadcast. 145 - #[derive(Clone, Debug)] 146 - pub(crate) struct FirehoseProducer { 147 - /// The channel to send messages to the firehose. 148 - tx: tokio::sync::mpsc::Sender<FirehoseMessage>, 149 - } 150 - 151 - impl FirehoseProducer { 152 - /// Broadcast an `#account` event. 153 - pub(crate) async fn account(&self, account: impl Into<sync::subscribe_repos::Account>) { 154 - drop( 155 - self.tx 156 - .send(FirehoseMessage::Broadcast( 157 - sync::subscribe_repos::Message::Account(Box::new(account.into())), 158 - )) 159 - .await, 160 - ); 161 - } 162 - /// Handle client connection. 163 - pub(crate) async fn client_connection(&self, ws: WebSocket, cursor: Option<i64>) { 164 - drop( 165 - self.tx 166 - .send(FirehoseMessage::Connect(Box::new((ws, cursor)))) 167 - .await, 168 - ); 169 - } 170 - /// Broadcast a `#commit` event. 171 - pub(crate) async fn commit(&self, commit: impl Into<sync::subscribe_repos::Commit>) { 172 - drop( 173 - self.tx 174 - .send(FirehoseMessage::Broadcast( 175 - sync::subscribe_repos::Message::Commit(Box::new(commit.into())), 176 - )) 177 - .await, 178 - ); 179 - } 180 - /// Broadcast an `#identity` event. 181 - pub(crate) async fn identity(&self, identity: impl Into<sync::subscribe_repos::Identity>) { 182 - drop( 183 - self.tx 184 - .send(FirehoseMessage::Broadcast( 185 - sync::subscribe_repos::Message::Identity(Box::new(identity.into())), 186 - )) 187 - .await, 188 - ); 189 - } 190 - } 191 - 192 - #[expect( 193 - clippy::as_conversions, 194 - clippy::cast_possible_truncation, 195 - clippy::cast_sign_loss, 196 - clippy::cast_precision_loss, 197 - clippy::arithmetic_side_effects 198 - )] 199 - /// Convert a `usize` to a `f64`. 200 - const fn convert_usize_f64(x: usize) -> Result<f64, &'static str> { 201 - let result = x as f64; 202 - if result as usize - x > 0 { 203 - return Err("cannot convert"); 204 - } 205 - Ok(result) 206 - } 207 - 208 - /// Serialize a message. 209 - fn serialize_message(seq: u64, mut msg: sync::subscribe_repos::Message) -> (&'static str, Vec<u8>) { 210 - let mut dummy_seq = 0_i64; 211 - #[expect(clippy::pattern_type_mismatch)] 212 - let (ty, nseq) = match &mut msg { 213 - sync::subscribe_repos::Message::Account(m) => ("#account", &mut m.seq), 214 - sync::subscribe_repos::Message::Commit(m) => ("#commit", &mut m.seq), 215 - sync::subscribe_repos::Message::Identity(m) => ("#identity", &mut m.seq), 216 - sync::subscribe_repos::Message::Sync(m) => ("#sync", &mut m.seq), 217 - sync::subscribe_repos::Message::Info(_m) => ("#info", &mut dummy_seq), 218 - }; 219 - // Set the sequence number. 
220 - *nseq = i64::try_from(seq).expect("should find seq"); 221 - 222 - let hdr = FrameHeader::Message(ty.to_owned()); 223 - 224 - let mut frame = Vec::new(); 225 - serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header"); 226 - serde_ipld_dagcbor::to_writer(&mut frame, &msg).expect("should serialize message"); 227 - 228 - (ty, frame) 229 - } 230 - 231 - /// Broadcast a message out to all clients. 232 - async fn broadcast_message(clients: &mut Vec<WebSocket>, msg: Message) -> Result<()> { 233 - counter!(FIREHOSE_MESSAGES).increment(1); 234 - 235 - for i in (0..clients.len()).rev() { 236 - let client = clients.get_mut(i).expect("should find client"); 237 - if let Err(e) = client.send(msg.clone()).await { 238 - debug!("Firehose client disconnected: {e}"); 239 - drop(clients.remove(i)); 240 - } 241 - } 242 - 243 - gauge!(FIREHOSE_LISTENERS) 244 - .set(convert_usize_f64(clients.len()).expect("should find clients length")); 245 - Ok(()) 246 - } 247 - 248 - /// Handle a new connection from a websocket client created by subscribeRepos. 249 - async fn handle_connect( 250 - mut ws: WebSocket, 251 - seq: u64, 252 - history: &VecDeque<(u64, &str, sync::subscribe_repos::Message)>, 253 - cursor: Option<i64>, 254 - ) -> Result<WebSocket> { 255 - if let Some(cursor) = cursor { 256 - let mut frame = Vec::new(); 257 - let cursor = u64::try_from(cursor); 258 - if cursor.is_err() { 259 - tracing::warn!("cursor is not a valid u64"); 260 - return Ok(ws); 261 - } 262 - let cursor = cursor.expect("should be valid u64"); 263 - // Cursor specified; attempt to backfill the consumer. 264 - if cursor > seq { 265 - let hdr = FrameHeader::Error; 266 - let msg = sync::subscribe_repos::Error::FutureCursor(Some(format!( 267 - "cursor {cursor} is greater than the current sequence number {seq}" 268 - ))); 269 - serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header"); 270 - serde_ipld_dagcbor::to_writer(&mut frame, &msg).expect("should serialize message"); 271 - // Drop the connection. 272 - drop(ws.send(Message::binary(frame)).await); 273 - bail!( 274 - "connection dropped: cursor {cursor} is greater than the current sequence number {seq}" 275 - ); 276 - } 277 - 278 - for &(historical_seq, ty, ref msg) in history { 279 - if cursor > historical_seq { 280 - continue; 281 - } 282 - let hdr = FrameHeader::Message(ty.to_owned()); 283 - serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header"); 284 - serde_ipld_dagcbor::to_writer(&mut frame, msg).expect("should serialize message"); 285 - if let Err(e) = ws.send(Message::binary(frame.clone())).await { 286 - debug!("Firehose client disconnected during backfill: {e}"); 287 - break; 288 - } 289 - // Clear out the frame to begin a new one. 290 - frame.clear(); 291 - } 292 - } 293 - 294 - Ok(ws) 295 - } 296 - 297 - /// Reconnect to upstream relays. 298 - pub(crate) async fn reconnect_relays(client: &Client, config: &AppConfig) { 299 - // Avoid connecting to upstream relays in test mode. 
300 - if config.test { 301 - return; 302 - } 303 - 304 - info!("attempting to reconnect to upstream relays"); 305 - for relay in &config.firehose.relays { 306 - let Some(host) = relay.host_str() else { 307 - warn!("relay {} has no host specified", relay); 308 - continue; 309 - }; 310 - 311 - let r = client 312 - .post(format!("https://{host}/xrpc/com.atproto.sync.requestCrawl")) 313 - .json(&serde_json::json!({ 314 - "hostname": format!("https://{}", config.host_name) 315 - })) 316 - .send() 317 - .await; 318 - 319 - let r = match r { 320 - Ok(r) => r, 321 - Err(e) => { 322 - error!("failed to hit upstream relay {host}: {e}"); 323 - continue; 324 - } 325 - }; 326 - 327 - let s = r.status(); 328 - if let Err(e) = r.error_for_status_ref() { 329 - error!("failed to hit upstream relay {host}: {e}"); 330 - } 331 - 332 - let b = r.json::<serde_json::Value>().await; 333 - if let Ok(b) = b { 334 - info!("relay {host}: {} {}", s, b); 335 - } else { 336 - info!("relay {host}: {}", s); 337 - } 338 - } 339 - } 340 - 341 - /// The main entrypoint for the firehose. 342 - /// 343 - /// This will broadcast all updates in this PDS out to anyone who is listening. 344 - /// 345 - /// Reference: <https://atproto.com/specs/sync> 346 - pub(crate) fn spawn( 347 - client: Client, 348 - config: AppConfig, 349 - ) -> (tokio::task::JoinHandle<()>, FirehoseProducer) { 350 - let (tx, mut rx) = tokio::sync::mpsc::channel(1000); 351 - let handle = tokio::spawn(async move { 352 - fn time_since_inception() -> u64 { 353 - chrono::Utc::now() 354 - .timestamp_micros() 355 - .checked_sub(1_743_442_000_000_000) 356 - .expect("should not wrap") 357 - .unsigned_abs() 358 - } 359 - let mut clients: Vec<WebSocket> = Vec::new(); 360 - let mut history = VecDeque::with_capacity(1000); 361 - let mut seq = time_since_inception(); 362 - 363 - loop { 364 - if let Ok(msg) = tokio::time::timeout(Duration::from_secs(30), rx.recv()).await { 365 - match msg { 366 - Some(FirehoseMessage::Broadcast(msg)) => { 367 - let (ty, by) = serialize_message(seq, msg.clone()); 368 - 369 - history.push_back((seq, ty, msg)); 370 - gauge!(FIREHOSE_HISTORY).set( 371 - convert_usize_f64(history.len()).expect("should find history length"), 372 - ); 373 - 374 - info!( 375 - "Broadcasting message {} {} to {} clients", 376 - seq, 377 - ty, 378 - clients.len() 379 - ); 380 - 381 - counter!(FIREHOSE_SEQUENCE).absolute(seq); 382 - let now = time_since_inception(); 383 - if now > seq { 384 - seq = now; 385 - } else { 386 - seq = seq.checked_add(1).expect("should not wrap"); 387 - } 388 - 389 - drop(broadcast_message(&mut clients, Message::binary(by)).await); 390 - } 391 - Some(FirehoseMessage::Connect(ws_cursor)) => { 392 - let (ws, cursor) = *ws_cursor; 393 - match handle_connect(ws, seq, &history, cursor).await { 394 - Ok(r) => { 395 - gauge!(FIREHOSE_LISTENERS).increment(1_i32); 396 - clients.push(r); 397 - } 398 - Err(e) => { 399 - error!("failed to connect new client: {e}"); 400 - } 401 - } 402 - } 403 - // All producers have been destroyed. 404 - None => break, 405 - } 406 - } else { 407 - if clients.is_empty() { 408 - reconnect_relays(&client, &config).await; 409 - } 410 - 411 - let contents = rand::thread_rng() 412 - .sample_iter(rand::distributions::Alphanumeric) 413 - .take(15) 414 - .map(char::from) 415 - .collect::<String>(); 416 - 417 - // Send a websocket ping message. 
418 - // Reference: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers#pings_and_pongs_the_heartbeat_of_websockets 419 - let message = Message::Ping(axum::body::Bytes::from_owner(contents)); 420 - drop(broadcast_message(&mut clients, message).await); 421 - } 422 - } 423 - }); 424 - 425 - (handle, FirehoseProducer { tx }) 426 - }
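The hand-rolled firehose above is replaced by rsky-pds's `Sequencer`, which owns both event emission and upstream crawl requests. A minimal sketch of the new wiring, condensed from the `src/serve.rs` and `src/lib.rs` changes later in this diff (the hostname and relay values are illustrative placeholders):

```
use std::sync::Arc;

use rsky_pds::{crawlers::Crawlers, sequencer::Sequencer};
use tokio::sync::RwLock;

async fn wire_sequencer() {
    // Hypothetical values; the real ones come from AppConfig.
    let hostname = "pds.example.com".to_owned();
    let relays = vec!["https://bsky.network".to_owned()];

    // One sequencer, shared across handlers behind an async RwLock.
    let sequencer = Arc::new(RwLock::new(Sequencer::new(
        Crawlers::new(hostname, relays),
        None,
    )));

    // Once the server is listening, a background task drives event emission
    // and crawl requests, replacing the old spawn/reconnect_relays loop.
    let mut background = sequencer.write().await.clone();
    drop(tokio::spawn(async move { background.start().await }));
}
```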
+4 -438
src/lib.rs
··· 8 8 mod db; 9 9 mod did; 10 10 pub mod error; 11 - mod firehose; 12 11 mod metrics; 13 - mod mmap; 14 12 mod models; 15 13 mod oauth; 16 - mod plc; 14 + mod pipethrough; 17 15 mod schema; 16 + mod serve; 18 17 mod service_proxy; 19 - #[cfg(test)] 20 - mod tests; 21 18 22 - use account_manager::{AccountManager, SharedAccountManager}; 23 - use anyhow::{Context as _, anyhow}; 24 - use atrium_api::types::string::Did; 25 - use atrium_crypto::keypair::{Export as _, Secp256k1Keypair}; 26 - use auth::AuthenticatedUser; 27 - use axum::{ 28 - Router, 29 - body::Body, 30 - extract::{FromRef, Request, State}, 31 - http::{self, HeaderMap, Response, StatusCode, Uri}, 32 - response::IntoResponse, 33 - routing::get, 34 - }; 35 - use azure_core::credentials::TokenCredential; 36 - use clap::Parser; 37 - use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter}; 38 - use config::AppConfig; 39 - use db::establish_pool; 40 - use deadpool_diesel::sqlite::Pool; 41 - use diesel::prelude::*; 42 - use diesel_migrations::{EmbeddedMigrations, embed_migrations}; 43 - pub use error::Error; 44 - use figment::{Figment, providers::Format as _}; 45 - use firehose::FirehoseProducer; 46 - use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager}; 47 - use rand::Rng as _; 48 - use rsky_pds::{crawlers::Crawlers, sequencer::Sequencer}; 49 - use serde::{Deserialize, Serialize}; 50 - use service_proxy::service_proxy; 51 - use std::{ 52 - net::{IpAddr, Ipv4Addr, SocketAddr}, 53 - path::PathBuf, 54 - str::FromStr as _, 55 - sync::Arc, 56 - }; 57 - use tokio::{net::TcpListener, sync::RwLock}; 58 - use tower_http::{cors::CorsLayer, trace::TraceLayer}; 59 - use tracing::{info, warn}; 60 - use uuid::Uuid; 61 - 62 - /// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`. 63 - pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); 64 - 65 - /// Embedded migrations 66 - pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations"); 67 - 68 - /// The application-wide result type. 69 - pub type Result<T> = std::result::Result<T, Error>; 70 - /// The reqwest client type with middleware. 71 - pub type Client = reqwest_middleware::ClientWithMiddleware; 72 - 73 - /// The Shared Sequencer which requests crawls from upstream relays and emits events to the firehose. 74 - pub struct SharedSequencer { 75 - /// The sequencer instance. 76 - pub sequencer: RwLock<Sequencer>, 77 - } 78 - 79 - #[expect( 80 - clippy::arbitrary_source_item_ordering, 81 - reason = "serialized data might be structured" 82 - )] 83 - #[derive(Serialize, Deserialize, Debug, Clone)] 84 - /// The key data structure. 85 - struct KeyData { 86 - /// Primary signing key for all repo operations. 87 - skey: Vec<u8>, 88 - /// Primary signing (rotation) key for all PLC operations. 89 - rkey: Vec<u8>, 90 - } 91 - 92 - // FIXME: We should use P256Keypair instead. SecP256K1 is primarily used for cryptocurrencies, 93 - // and the implementations of this algorithm are much more limited as compared to P256. 94 - // 95 - // Reference: https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/ 96 - #[derive(Clone)] 97 - /// The signing key for PLC/DID operations. 98 - pub struct SigningKey(Arc<Secp256k1Keypair>); 99 - #[derive(Clone)] 100 - /// The rotation key for PLC operations. 
101 - pub struct RotationKey(Arc<Secp256k1Keypair>); 102 - 103 - impl std::ops::Deref for SigningKey { 104 - type Target = Secp256k1Keypair; 105 - 106 - fn deref(&self) -> &Self::Target { 107 - &self.0 108 - } 109 - } 110 - 111 - impl SigningKey { 112 - /// Import from a private key. 113 - pub fn import(key: &[u8]) -> Result<Self> { 114 - let key = Secp256k1Keypair::import(key).context("failed to import signing key")?; 115 - Ok(Self(Arc::new(key))) 116 - } 117 - } 118 - 119 - impl std::ops::Deref for RotationKey { 120 - type Target = Secp256k1Keypair; 121 - 122 - fn deref(&self) -> &Self::Target { 123 - &self.0 124 - } 125 - } 126 - 127 - #[derive(Parser, Debug, Clone)] 128 - /// Command line arguments. 129 - pub struct Args { 130 - /// Path to the configuration file 131 - #[arg(short, long, default_value = "default.toml")] 132 - pub config: PathBuf, 133 - /// The verbosity level. 134 - #[command(flatten)] 135 - pub verbosity: Verbosity<InfoLevel>, 136 - } 137 - 138 - /// The actor pools for the database connections. 139 - pub struct ActorPools { 140 - /// The database connection pool for the actor's repository. 141 - pub repo: Pool, 142 - /// The database connection pool for the actor's blobs. 143 - pub blob: Pool, 144 - } 145 - 146 - impl Clone for ActorPools { 147 - fn clone(&self) -> Self { 148 - Self { 149 - repo: self.repo.clone(), 150 - blob: self.blob.clone(), 151 - } 152 - } 153 - } 154 - 155 - #[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")] 156 - #[derive(Clone, FromRef)] 157 - pub struct AppState { 158 - /// The application configuration. 159 - pub config: AppConfig, 160 - /// The main database connection pool. Used for common PDS data, like invite codes. 161 - pub db: Pool, 162 - /// Actor-specific database connection pools. Hashed by DID. 163 - pub db_actors: std::collections::HashMap<String, ActorPools>, 164 - 165 - /// The HTTP client with middleware. 166 - pub client: Client, 167 - /// The simple HTTP client. 168 - pub simple_client: reqwest::Client, 169 - /// The firehose producer. 170 - pub sequencer: Arc<SharedSequencer>, 171 - /// The account manager. 172 - pub account_manager: Arc<SharedAccountManager>, 173 - 174 - /// The signing key. 175 - pub signing_key: SigningKey, 176 - /// The rotation key. 177 - pub rotation_key: RotationKey, 178 - } 19 + pub use serve::run; 179 20 180 21 /// The index (/) route. 181 - async fn index() -> impl IntoResponse { 22 + async fn index() -> impl axum::response::IntoResponse { 182 23 r" 183 24 __ __ 184 25 /\ \__ /\ \__ ··· 199 40 Protocol: https://atproto.com 200 41 " 201 42 } 202 - 203 - /// The main application entry point. 204 - #[expect( 205 - clippy::cognitive_complexity, 206 - clippy::too_many_lines, 207 - unused_qualifications, 208 - reason = "main function has high complexity" 209 - )] 210 - pub async fn run() -> anyhow::Result<()> { 211 - let args = Args::parse(); 212 - 213 - // Set up trace logging to console and account for the user-provided verbosity flag. 
214 - if args.verbosity.log_level_filter() != LevelFilter::Off { 215 - let lvl = match args.verbosity.log_level_filter() { 216 - LevelFilter::Error => tracing::Level::ERROR, 217 - LevelFilter::Warn => tracing::Level::WARN, 218 - LevelFilter::Info | LevelFilter::Off => tracing::Level::INFO, 219 - LevelFilter::Debug => tracing::Level::DEBUG, 220 - LevelFilter::Trace => tracing::Level::TRACE, 221 - }; 222 - tracing_subscriber::fmt().with_max_level(lvl).init(); 223 - } 224 - 225 - if !args.config.exists() { 226 - // Throw up a warning if the config file does not exist. 227 - // 228 - // This is not fatal because users can specify all configuration settings via 229 - // the environment, but the most likely scenario here is that a user accidentally 230 - // omitted the config file for some reason (e.g. forgot to mount it into Docker). 231 - warn!( 232 - "configuration file {} does not exist", 233 - args.config.display() 234 - ); 235 - } 236 - 237 - // Read and parse the user-provided configuration. 238 - let config: AppConfig = Figment::new() 239 - .admerge(figment::providers::Toml::file(args.config)) 240 - .admerge(figment::providers::Env::prefixed("BLUEPDS_")) 241 - .extract() 242 - .context("failed to load configuration")?; 243 - 244 - if config.test { 245 - warn!("BluePDS starting up in TEST mode."); 246 - warn!("This means the application will not federate with the rest of the network."); 247 - warn!( 248 - "If you want to turn this off, either set `test` to false in the config or define `BLUEPDS_TEST = false`" 249 - ); 250 - } 251 - 252 - // Initialize metrics reporting. 253 - metrics::setup(config.metrics.as_ref()).context("failed to set up metrics exporter")?; 254 - 255 - // Create a reqwest client that will be used for all outbound requests. 256 - let simple_client = reqwest::Client::builder() 257 - .user_agent(APP_USER_AGENT) 258 - .build() 259 - .context("failed to build requester client")?; 260 - let client = reqwest_middleware::ClientBuilder::new(simple_client.clone()) 261 - .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache { 262 - mode: CacheMode::Default, 263 - manager: MokaManager::default(), 264 - options: HttpCacheOptions::default(), 265 - })) 266 - .build(); 267 - 268 - tokio::fs::create_dir_all(&config.key.parent().context("should have parent")?) 269 - .await 270 - .context("failed to create key directory")?; 271 - 272 - // Check if crypto keys exist. If not, create new ones. 
273 - let (skey, rkey) = if let Ok(f) = std::fs::File::open(&config.key) { 274 - let keys: KeyData = serde_ipld_dagcbor::from_reader(std::io::BufReader::new(f)) 275 - .context("failed to deserialize crypto keys")?; 276 - 277 - let skey = Secp256k1Keypair::import(&keys.skey).context("failed to import signing key")?; 278 - let rkey = Secp256k1Keypair::import(&keys.rkey).context("failed to import rotation key")?; 279 - 280 - (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 281 - } else { 282 - info!("signing keys not found, generating new ones"); 283 - 284 - let skey = Secp256k1Keypair::create(&mut rand::thread_rng()); 285 - let rkey = Secp256k1Keypair::create(&mut rand::thread_rng()); 286 - 287 - let keys = KeyData { 288 - skey: skey.export(), 289 - rkey: rkey.export(), 290 - }; 291 - 292 - let mut f = std::fs::File::create(&config.key).context("failed to create key file")?; 293 - serde_ipld_dagcbor::to_writer(&mut f, &keys).context("failed to serialize crypto keys")?; 294 - 295 - (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 296 - }; 297 - 298 - tokio::fs::create_dir_all(&config.repo.path).await?; 299 - tokio::fs::create_dir_all(&config.plc.path).await?; 300 - tokio::fs::create_dir_all(&config.blob.path).await?; 301 - 302 - // Create a database connection manager and pool for the main database. 303 - let pool = 304 - establish_pool(&config.db).context("failed to establish database connection pool")?; 305 - // Create a dictionary of database connection pools for each actor. 306 - let mut actor_pools = std::collections::HashMap::new(); 307 - // let mut actor_blob_pools = std::collections::HashMap::new(); 308 - // We'll determine actors by looking in the data/repo dir for .db files. 309 - let mut actor_dbs = tokio::fs::read_dir(&config.repo.path) 310 - .await 311 - .context("failed to read repo directory")?; 312 - while let Some(entry) = actor_dbs 313 - .next_entry() 314 - .await 315 - .context("failed to read repo dir")? 316 - { 317 - let path = entry.path(); 318 - if path.extension().and_then(|s| s.to_str()) == Some("db") { 319 - let did_path = path 320 - .file_stem() 321 - .and_then(|s| s.to_str()) 322 - .context("failed to get actor DID")?; 323 - let did = Did::from_str(&format!("did:plc:{}", did_path)) 324 - .expect("should be able to parse actor DID"); 325 - 326 - // Create a new database connection manager and pool for the actor. 327 - // The path for the SQLite connection needs to look like "sqlite://data/repo/<actor>.db" 328 - let path_repo = format!("sqlite://{}", did_path); 329 - let actor_repo_pool = 330 - establish_pool(&path_repo).context("failed to create database connection pool")?; 331 - // Create a new database connection manager and pool for the actor blobs. 
332 - // The path for the SQLite connection needs to look like "sqlite://data/blob/<actor>.db" 333 - let path_blob = path_repo.replace("repo", "blob"); 334 - let actor_blob_pool = 335 - establish_pool(&path_blob).context("failed to create database connection pool")?; 336 - drop(actor_pools.insert( 337 - did.to_string(), 338 - ActorPools { 339 - repo: actor_repo_pool, 340 - blob: actor_blob_pool, 341 - }, 342 - )); 343 - } 344 - } 345 - // Apply pending migrations 346 - // let conn = pool.get().await?; 347 - // conn.run_pending_migrations(MIGRATIONS) 348 - // .expect("should be able to run migrations"); 349 - 350 - let hostname = config.host_name.clone(); 351 - let crawlers: Vec<String> = config 352 - .firehose 353 - .relays 354 - .iter() 355 - .map(|s| s.to_string()) 356 - .collect(); 357 - let sequencer = Arc::new(SharedSequencer { 358 - sequencer: RwLock::new(Sequencer::new( 359 - Crawlers::new(hostname, crawlers.clone()), 360 - None, 361 - )), 362 - }); 363 - let account_manager = SharedAccountManager { 364 - account_manager: RwLock::new(AccountManager::new(pool.clone())), 365 - }; 366 - 367 - let addr = config 368 - .listen_address 369 - .unwrap_or(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)); 370 - 371 - let app = Router::new() 372 - .route("/", get(index)) 373 - .merge(oauth::routes()) 374 - .nest( 375 - "/xrpc", 376 - apis::routes() 377 - .merge(actor_endpoints::routes()) 378 - .fallback(service_proxy), 379 - ) 380 - // .layer(RateLimitLayer::new(30, Duration::from_secs(30))) 381 - .layer(CorsLayer::permissive()) 382 - .layer(TraceLayer::new_for_http()) 383 - .with_state(AppState { 384 - config: config.clone(), 385 - db: pool.clone(), 386 - db_actors: actor_pools.clone(), 387 - client: client.clone(), 388 - simple_client, 389 - sequencer: sequencer.clone(), 390 - account_manager: Arc::new(account_manager), 391 - signing_key: skey, 392 - rotation_key: rkey, 393 - }); 394 - 395 - info!("listening on {addr}"); 396 - info!("connect to: http://127.0.0.1:{}", addr.port()); 397 - 398 - // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created). 399 - // If so, create an invite code and share it via the console. 
400 - let conn = pool.get().await.context("failed to get db connection")?; 401 - 402 - #[derive(QueryableByName)] 403 - struct TotalCount { 404 - #[diesel(sql_type = diesel::sql_types::Integer)] 405 - total_count: i32, 406 - } 407 - 408 - let result = conn.interact(move |conn| { 409 - diesel::sql_query( 410 - "SELECT (SELECT COUNT(*) FROM account) + (SELECT COUNT(*) FROM invite_code) AS total_count", 411 - ) 412 - .get_result::<TotalCount>(conn) 413 - }) 414 - .await 415 - .expect("should be able to query database")?; 416 - 417 - let c = result.total_count; 418 - 419 - #[expect(clippy::print_stdout)] 420 - if c == 0 { 421 - let uuid = Uuid::new_v4().to_string(); 422 - 423 - use crate::models::pds as models; 424 - use crate::schema::pds::invite_code::dsl as InviteCode; 425 - let uuid_clone = uuid.clone(); 426 - drop( 427 - conn.interact(move |conn| { 428 - diesel::insert_into(InviteCode::invite_code) 429 - .values(models::InviteCode { 430 - code: uuid_clone, 431 - available_uses: 1, 432 - disabled: 0, 433 - for_account: "None".to_owned(), 434 - created_by: "None".to_owned(), 435 - created_at: "None".to_owned(), 436 - }) 437 - .execute(conn) 438 - .context("failed to create new invite code") 439 - }) 440 - .await 441 - .expect("should be able to create invite code"), 442 - ); 443 - 444 - // N.B: This is a sensitive message, so we're bypassing `tracing` here and 445 - // logging it directly to console. 446 - println!("====================================="); 447 - println!(" FIRST STARTUP "); 448 - println!("====================================="); 449 - println!("Use this code to create an account:"); 450 - println!("{uuid}"); 451 - println!("====================================="); 452 - } 453 - 454 - let listener = TcpListener::bind(&addr) 455 - .await 456 - .context("failed to bind address")?; 457 - 458 - // Serve the app, and request crawling from upstream relays. 459 - let serve = tokio::spawn(async move { 460 - axum::serve(listener, app.into_make_service()) 461 - .await 462 - .context("failed to serve app") 463 - }); 464 - 465 - // Now that the app is live, request a crawl from upstream relays. 466 - let mut background_sequencer = sequencer.sequencer.write().await.clone(); 467 - drop(tokio::spawn( 468 - async move { background_sequencer.start().await }, 469 - )); 470 - 471 - serve 472 - .await 473 - .map_err(Into::into) 474 - .and_then(|r| r) 475 - .context("failed to serve app") 476 - }
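First-run bootstrapping survives the move into `src/serve.rs`: when no accounts or invite codes exist yet, a single-use code is printed to the console. A hedged sketch of redeeming that code against the standard `com.atproto.server.createAccount` XRPC endpoint (the handle, email, and password are placeholders; the port matches the default listen address above):

```
use anyhow::Result;

async fn redeem_invite(code: &str) -> Result<()> {
    let res = reqwest::Client::new()
        .post("http://127.0.0.1:8000/xrpc/com.atproto.server.createAccount")
        .json(&serde_json::json!({
            "handle": "alice.example.com",
            "email": "alice@example.com",
            "password": "correct-horse-battery-staple",
            "inviteCode": code,
        }))
        .send()
        .await?
        .error_for_status()?;
    println!("created account: {}", res.text().await?);
    Ok(())
}
```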
+1 -3
src/main.rs
··· 1 1 //! BluePDS binary entry point. 2 2 3 3 use anyhow::Context as _; 4 - use clap::Parser; 5 4 6 5 #[tokio::main(flavor = "multi_thread")] 7 6 async fn main() -> anyhow::Result<()> { 8 - // Parse command line arguments and call into the library's run function 9 7 bluepds::run().await.context("failed to run application") 10 - } 8 + }
-274
src/mmap.rs
··· 1 - #![allow(clippy::arbitrary_source_item_ordering)] 2 - use std::io::{ErrorKind, Read as _, Seek as _, Write as _}; 3 - 4 - #[cfg(unix)] 5 - use std::os::fd::AsRawFd as _; 6 - #[cfg(windows)] 7 - use std::os::windows::io::AsRawHandle; 8 - 9 - use memmap2::{MmapMut, MmapOptions}; 10 - 11 - pub(crate) struct MappedFile { 12 - /// The underlying file handle. 13 - file: std::fs::File, 14 - /// The length of the file. 15 - len: u64, 16 - /// The mapped memory region. 17 - map: MmapMut, 18 - /// Our current offset into the file. 19 - off: u64, 20 - } 21 - 22 - impl MappedFile { 23 - pub(crate) fn new(mut f: std::fs::File) -> std::io::Result<Self> { 24 - let len = f.seek(std::io::SeekFrom::End(0))?; 25 - 26 - #[cfg(windows)] 27 - let raw = f.as_raw_handle(); 28 - #[cfg(unix)] 29 - let raw = f.as_raw_fd(); 30 - 31 - #[expect(unsafe_code)] 32 - Ok(Self { 33 - // SAFETY: 34 - // All file-backed memory map constructors are marked \ 35 - // unsafe because of the potential for Undefined Behavior (UB) \ 36 - // using the map if the underlying file is subsequently modified, in or out of process. 37 - map: unsafe { MmapOptions::new().map_mut(raw)? }, 38 - file: f, 39 - len, 40 - off: 0, 41 - }) 42 - } 43 - 44 - /// Resize the memory-mapped file. This will reallocate the memory mapping. 45 - #[expect(unsafe_code)] 46 - fn resize(&mut self, len: u64) -> std::io::Result<()> { 47 - // Resize the file. 48 - self.file.set_len(len)?; 49 - 50 - #[cfg(windows)] 51 - let raw = self.file.as_raw_handle(); 52 - #[cfg(unix)] 53 - let raw = self.file.as_raw_fd(); 54 - 55 - // SAFETY: 56 - // All file-backed memory map constructors are marked \ 57 - // unsafe because of the potential for Undefined Behavior (UB) \ 58 - // using the map if the underlying file is subsequently modified, in or out of process. 59 - self.map = unsafe { MmapOptions::new().map_mut(raw)? }; 60 - self.len = len; 61 - 62 - Ok(()) 63 - } 64 - } 65 - 66 - impl std::io::Read for MappedFile { 67 - fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { 68 - if self.off == self.len { 69 - // If we're at EOF, return an EOF error code. `Ok(0)` tends to trip up some implementations. 70 - return Err(std::io::Error::new(ErrorKind::UnexpectedEof, "eof")); 71 - } 72 - 73 - // Calculate the number of bytes we're going to read. 74 - let remaining_bytes = self.len.saturating_sub(self.off); 75 - let buf_len = u64::try_from(buf.len()).unwrap_or(u64::MAX); 76 - let len = usize::try_from(std::cmp::min(remaining_bytes, buf_len)).unwrap_or(usize::MAX); 77 - 78 - let off = usize::try_from(self.off).map_err(|e| { 79 - std::io::Error::new( 80 - ErrorKind::InvalidInput, 81 - format!("offset too large for this platform: {e}"), 82 - ) 83 - })?; 84 - 85 - if let (Some(dest), Some(src)) = ( 86 - buf.get_mut(..len), 87 - self.map.get(off..off.saturating_add(len)), 88 - ) { 89 - dest.copy_from_slice(src); 90 - self.off = self.off.saturating_add(u64::try_from(len).unwrap_or(0)); 91 - Ok(len) 92 - } else { 93 - Err(std::io::Error::new( 94 - ErrorKind::InvalidInput, 95 - "invalid buffer range", 96 - )) 97 - } 98 - } 99 - } 100 - 101 - impl std::io::Write for MappedFile { 102 - fn flush(&mut self) -> std::io::Result<()> { 103 - // This is done by the system. 104 - Ok(()) 105 - } 106 - fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { 107 - // Determine if we need to resize the file. 
108 - let buf_len = u64::try_from(buf.len()).map_err(|e| { 109 - std::io::Error::new( 110 - ErrorKind::InvalidInput, 111 - format!("buffer length too large for this platform: {e}"), 112 - ) 113 - })?; 114 - 115 - if self.off.saturating_add(buf_len) >= self.len { 116 - self.resize(self.off.saturating_add(buf_len))?; 117 - } 118 - 119 - let off = usize::try_from(self.off).map_err(|e| { 120 - std::io::Error::new( 121 - ErrorKind::InvalidInput, 122 - format!("offset too large for this platform: {e}"), 123 - ) 124 - })?; 125 - let len = buf.len(); 126 - 127 - if let Some(dest) = self.map.get_mut(off..off.saturating_add(len)) { 128 - dest.copy_from_slice(buf); 129 - self.off = self.off.saturating_add(buf_len); 130 - Ok(len) 131 - } else { 132 - Err(std::io::Error::new( 133 - ErrorKind::InvalidInput, 134 - "invalid buffer range", 135 - )) 136 - } 137 - } 138 - } 139 - 140 - impl std::io::Seek for MappedFile { 141 - fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> { 142 - let off = match pos { 143 - std::io::SeekFrom::Start(i) => i, 144 - std::io::SeekFrom::End(i) => { 145 - if i <= 0 { 146 - // If i is negative or zero, we're seeking backwards from the end 147 - // or exactly at the end 148 - self.len.saturating_sub(i.unsigned_abs()) 149 - } else { 150 - // If i is positive, we're seeking beyond the end, which is allowed 151 - // but requires extending the file 152 - self.len.saturating_add(i.unsigned_abs()) 153 - } 154 - } 155 - std::io::SeekFrom::Current(i) => { 156 - if i >= 0 { 157 - self.off.saturating_add(i.unsigned_abs()) 158 - } else { 159 - self.off.saturating_sub(i.unsigned_abs()) 160 - } 161 - } 162 - }; 163 - 164 - // If the offset is beyond EOF, extend the file to the new size. 165 - if off > self.len { 166 - self.resize(off)?; 167 - } 168 - 169 - self.off = off; 170 - Ok(off) 171 - } 172 - } 173 - 174 - impl tokio::io::AsyncRead for MappedFile { 175 - fn poll_read( 176 - mut self: std::pin::Pin<&mut Self>, 177 - _cx: &mut std::task::Context<'_>, 178 - buf: &mut tokio::io::ReadBuf<'_>, 179 - ) -> std::task::Poll<std::io::Result<()>> { 180 - let wbuf = buf.initialize_unfilled(); 181 - let len = wbuf.len(); 182 - 183 - std::task::Poll::Ready(match self.read(wbuf) { 184 - Ok(_) => { 185 - buf.advance(len); 186 - Ok(()) 187 - } 188 - Err(e) => Err(e), 189 - }) 190 - } 191 - } 192 - 193 - impl tokio::io::AsyncWrite for MappedFile { 194 - fn poll_flush( 195 - self: std::pin::Pin<&mut Self>, 196 - _cx: &mut std::task::Context<'_>, 197 - ) -> std::task::Poll<Result<(), std::io::Error>> { 198 - std::task::Poll::Ready(Ok(())) 199 - } 200 - 201 - fn poll_shutdown( 202 - self: std::pin::Pin<&mut Self>, 203 - _cx: &mut std::task::Context<'_>, 204 - ) -> std::task::Poll<Result<(), std::io::Error>> { 205 - std::task::Poll::Ready(Ok(())) 206 - } 207 - 208 - fn poll_write( 209 - mut self: std::pin::Pin<&mut Self>, 210 - _cx: &mut std::task::Context<'_>, 211 - buf: &[u8], 212 - ) -> std::task::Poll<Result<usize, std::io::Error>> { 213 - std::task::Poll::Ready(self.write(buf)) 214 - } 215 - } 216 - 217 - impl tokio::io::AsyncSeek for MappedFile { 218 - fn poll_complete( 219 - self: std::pin::Pin<&mut Self>, 220 - _cx: &mut std::task::Context<'_>, 221 - ) -> std::task::Poll<std::io::Result<u64>> { 222 - std::task::Poll::Ready(Ok(self.off)) 223 - } 224 - 225 - fn start_seek( 226 - mut self: std::pin::Pin<&mut Self>, 227 - position: std::io::SeekFrom, 228 - ) -> std::io::Result<()> { 229 - self.seek(position).map(|_p| ()) 230 - } 231 - } 232 - 233 - #[cfg(test)] 234 - mod test { 235 - 
use rand::Rng as _; 236 - use std::io::Write as _; 237 - 238 - use super::*; 239 - 240 - #[test] 241 - fn basic_rw() { 242 - let tmp = std::env::temp_dir().join( 243 - rand::thread_rng() 244 - .sample_iter(rand::distributions::Alphanumeric) 245 - .take(10) 246 - .map(char::from) 247 - .collect::<String>(), 248 - ); 249 - 250 - let mut m = MappedFile::new( 251 - std::fs::File::options() 252 - .create(true) 253 - .truncate(true) 254 - .read(true) 255 - .write(true) 256 - .open(&tmp) 257 - .expect("Failed to open temporary file"), 258 - ) 259 - .expect("Failed to create MappedFile"); 260 - 261 - m.write_all(b"abcd123").expect("Failed to write data"); 262 - let _: u64 = m 263 - .seek(std::io::SeekFrom::Start(0)) 264 - .expect("Failed to seek to start"); 265 - 266 - let mut buf = [0_u8; 7]; 267 - m.read_exact(&mut buf).expect("Failed to read data"); 268 - 269 - assert_eq!(&buf, b"abcd123"); 270 - 271 - drop(m); 272 - std::fs::remove_file(tmp).expect("Failed to remove temporary file"); 273 - } 274 - }
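With `MappedFile` gone, nothing in the tree does raw memory-mapped file I/O anymore; per-actor repository state lives in SQLite pools opened through the same `establish_pool` helper used for the main database. A rough sketch of the per-actor layout this diff converges on, assuming `establish_pool`'s signature from `src/db.rs` and the `data/repo` path named in the old `src/lib.rs` comments:

```
use anyhow::{Context as _, Result};
use deadpool_diesel::sqlite::Pool;

use crate::db::establish_pool;

fn open_actor_repo(did_stem: &str) -> Result<Pool> {
    // e.g. data/repo/<did>.db: one SQLite database per actor repository.
    establish_pool(&format!("sqlite://data/repo/{did_stem}.db"))
        .context("failed to create actor repo pool")
}
```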
+4 -10
src/oauth.rs
··· 1 1 //! OAuth endpoints 2 2 #![allow(unnameable_types, unused_qualifications)] 3 + use crate::config::AppConfig; 4 + use crate::error::Error; 3 5 use crate::metrics::AUTH_FAILED; 4 - use crate::{AppConfig, AppState, Client, Error, Result, SigningKey}; 6 + use crate::serve::{AppState, Client, Result, SigningKey}; 5 7 use anyhow::{Context as _, anyhow}; 6 8 use argon2::{Argon2, PasswordHash, PasswordVerifier as _}; 7 9 use atrium_crypto::keypair::Did as _; ··· 365 367 let response_type = response_type.to_owned(); 366 368 let code_challenge = code_challenge.to_owned(); 367 369 let code_challenge_method = code_challenge_method.to_owned(); 368 - let state = state.map(|s| s.to_owned()); 369 - let login_hint = login_hint.map(|s| s.to_owned()); 370 - let scope = scope.map(|s| s.to_owned()); 371 - let redirect_uri = redirect_uri.map(|s| s.to_owned()); 372 - let response_mode = response_mode.map(|s| s.to_owned()); 373 - let display = display.map(|s| s.to_owned()); 374 - let created_at = created_at; 375 - let expires_at = expires_at; 376 370 _ = db 377 371 .get() 378 372 .await ··· 587 581 .interact(move |conn| { 588 582 AccountSchema::account 589 583 .filter(AccountSchema::email.eq(username_clone)) 590 - .first::<rsky_pds::models::Account>(conn) 584 + .first::<crate::models::pds::Account>(conn) 591 585 .optional() 592 586 }) 593 587 .await
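The lookup above now deserializes into the crate-local `Account` model rather than rsky's. Pulled out on its own, the diesel pattern is: filter on email, take the first row, and let `.optional()` turn diesel's `NotFound` into `Ok(None)`. A sketch, with the schema import path assumed from the crate layout:

```
use diesel::prelude::*;
use diesel::sqlite::SqliteConnection;

use crate::models::pds::Account;
use crate::schema::pds::account::dsl as AccountSchema;

fn find_account_by_email(
    conn: &mut SqliteConnection,
    email_addr: &str,
) -> QueryResult<Option<Account>> {
    AccountSchema::account
        .filter(AccountSchema::email.eq(email_addr))
        .first::<Account>(conn)
        .optional() // NotFound becomes Ok(None) instead of an error
}
```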
+606
src/pipethrough.rs
··· 1 + //! Based on https://github.com/blacksky-algorithms/rsky/blob/main/rsky-pds/src/pipethrough.rs 2 + //! blacksky-algorithms/rsky is licensed under the Apache License 2.0 3 + //! 4 + //! Modified for Axum instead of Rocket 5 + 6 + use anyhow::{Result, bail}; 7 + use axum::extract::{FromRequestParts, State}; 8 + use rsky_identity::IdResolver; 9 + use rsky_pds::apis::ApiError; 10 + use rsky_pds::auth_verifier::{AccessOutput, AccessStandard}; 11 + use rsky_pds::config::{ServerConfig, ServiceConfig, env_to_cfg}; 12 + use rsky_pds::pipethrough::{OverrideOpts, ProxyHeader, UrlAndAud}; 13 + use rsky_pds::xrpc_server::types::{HandlerPipeThrough, InvalidRequestError, XRPCError}; 14 + use rsky_pds::{APP_USER_AGENT, SharedIdResolver, context}; 15 + // use lazy_static::lazy_static; 16 + use reqwest::header::{CONTENT_TYPE, HeaderValue}; 17 + use reqwest::{Client, Method, RequestBuilder, Response}; 18 + // use rocket::data::ToByteUnit; 19 + // use rocket::http::{Method, Status}; 20 + // use rocket::request::{FromRequest, Outcome, Request}; 21 + // use rocket::{Data, State}; 22 + use axum::{ 23 + body::Bytes, 24 + http::{self, HeaderMap}, 25 + }; 26 + use rsky_common::{GetServiceEndpointOpts, get_service_endpoint}; 27 + use rsky_repo::types::Ids; 28 + use serde::de::DeserializeOwned; 29 + use serde_json::Value as JsonValue; 30 + use std::collections::{BTreeMap, HashSet}; 31 + use std::str::FromStr; 32 + use std::sync::Arc; 33 + use std::time::Duration; 34 + use ubyte::ToByteUnit as _; 35 + use url::Url; 36 + 37 + use crate::serve::AppState; 38 + 39 + // pub struct OverrideOpts { 40 + // pub aud: Option<String>, 41 + // pub lxm: Option<String>, 42 + // } 43 + 44 + // pub struct UrlAndAud { 45 + // pub url: Url, 46 + // pub aud: String, 47 + // pub lxm: String, 48 + // } 49 + 50 + // pub struct ProxyHeader { 51 + // pub did: String, 52 + // pub service_url: String, 53 + // } 54 + 55 + pub struct ProxyRequest { 56 + pub headers: BTreeMap<String, String>, 57 + pub query: Option<String>, 58 + pub path: String, 59 + pub method: Method, 60 + pub id_resolver: Arc<tokio::sync::RwLock<rsky_identity::IdResolver>>, 61 + pub cfg: ServerConfig, 62 + } 63 + impl FromRequestParts<AppState> for ProxyRequest { 64 + // type Rejection = ApiError; 65 + type Rejection = axum::response::Response; 66 + 67 + async fn from_request_parts( 68 + parts: &mut axum::http::request::Parts, 69 + state: &AppState, 70 + ) -> Result<Self, Self::Rejection> { 71 + let headers = parts 72 + .headers 73 + .iter() 74 + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) 75 + .collect::<BTreeMap<String, String>>(); 76 + let query = parts.uri.query().map(|s| s.to_string()); 77 + let path = parts.uri.path().to_string(); 78 + let method = parts.method.clone(); 79 + let id_resolver = state.id_resolver.clone(); 80 + // let cfg = state.cfg.clone(); 81 + let cfg = env_to_cfg(); // TODO: use state.cfg.clone(); 82 + 83 + Ok(Self { 84 + headers, 85 + query, 86 + path, 87 + method, 88 + id_resolver, 89 + cfg, 90 + }) 91 + } 92 + } 93 + 94 + // #[rocket::async_trait] 95 + // impl<'r> FromRequest<'r> for HandlerPipeThrough { 96 + // type Error = anyhow::Error; 97 + 98 + // #[tracing::instrument(skip_all)] 99 + // async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> { 100 + // match AccessStandard::from_request(req).await { 101 + // Outcome::Success(output) => { 102 + // let AccessOutput { credentials, .. 
} = output.access; 103 + // let requester: Option<String> = match credentials { 104 + // None => None, 105 + // Some(credentials) => credentials.did, 106 + // }; 107 + // let headers = req.headers().clone().into_iter().fold( 108 + // BTreeMap::new(), 109 + // |mut acc: BTreeMap<String, String>, cur| { 110 + // let _ = acc.insert(cur.name().to_string(), cur.value().to_string()); 111 + // acc 112 + // }, 113 + // ); 114 + // let proxy_req = ProxyRequest { 115 + // headers, 116 + // query: match req.uri().query() { 117 + // None => None, 118 + // Some(query) => Some(query.to_string()), 119 + // }, 120 + // path: req.uri().path().to_string(), 121 + // method: req.method(), 122 + // id_resolver: req.guard::<&State<SharedIdResolver>>().await.unwrap(), 123 + // cfg: req.guard::<&State<ServerConfig>>().await.unwrap(), 124 + // }; 125 + // match pipethrough( 126 + // &proxy_req, 127 + // requester, 128 + // OverrideOpts { 129 + // aud: None, 130 + // lxm: None, 131 + // }, 132 + // ) 133 + // .await 134 + // { 135 + // Ok(res) => Outcome::Success(res), 136 + // Err(error) => match error.downcast_ref() { 137 + // Some(InvalidRequestError::XRPCError(xrpc)) => { 138 + // if let XRPCError::FailedResponse { 139 + // status, 140 + // error, 141 + // message, 142 + // headers, 143 + // } = xrpc 144 + // { 145 + // tracing::error!( 146 + // "@LOG: XRPC ERROR Status:{status}; Message: {message:?}; Error: {error:?}; Headers: {headers:?}" 147 + // ); 148 + // } 149 + // req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string()))); 150 + // Outcome::Error((Status::BadRequest, error)) 151 + // } 152 + // _ => { 153 + // req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string()))); 154 + // Outcome::Error((Status::BadRequest, error)) 155 + // } 156 + // }, 157 + // } 158 + // } 159 + // Outcome::Error(err) => { 160 + // req.local_cache(|| Some(ApiError::RuntimeError)); 161 + // Outcome::Error(( 162 + // Status::BadRequest, 163 + // anyhow::Error::new(InvalidRequestError::AuthError(err.1)), 164 + // )) 165 + // } 166 + // _ => panic!("Unexpected outcome during Pipethrough"), 167 + // } 168 + // } 169 + // } 170 + 171 + // #[rocket::async_trait] 172 + // impl<'r> FromRequest<'r> for ProxyRequest<'r> { 173 + // type Error = anyhow::Error; 174 + 175 + // async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> { 176 + // let headers = req.headers().clone().into_iter().fold( 177 + // BTreeMap::new(), 178 + // |mut acc: BTreeMap<String, String>, cur| { 179 + // let _ = acc.insert(cur.name().to_string(), cur.value().to_string()); 180 + // acc 181 + // }, 182 + // ); 183 + // Outcome::Success(Self { 184 + // headers, 185 + // query: match req.uri().query() { 186 + // None => None, 187 + // Some(query) => Some(query.to_string()), 188 + // }, 189 + // path: req.uri().path().to_string(), 190 + // method: req.method(), 191 + // id_resolver: req.guard::<&State<SharedIdResolver>>().await.unwrap(), 192 + // cfg: req.guard::<&State<ServerConfig>>().await.unwrap(), 193 + // }) 194 + // } 195 + // } 196 + 197 + pub async fn pipethrough( 198 + req: &ProxyRequest, 199 + requester: Option<String>, 200 + override_opts: OverrideOpts, 201 + ) -> Result<HandlerPipeThrough> { 202 + let UrlAndAud { 203 + url, 204 + aud, 205 + lxm: nsid, 206 + } = format_url_and_aud(req, override_opts.aud).await?; 207 + let lxm = override_opts.lxm.unwrap_or(nsid); 208 + let headers = format_headers(req, aud, lxm, requester).await?; 209 + let req_init = format_req_init(req, url, headers, None)?; 210 + let res = 
make_request(req_init).await?;
211 + parse_proxy_res(res).await
212 + }
213 +
214 + pub async fn pipethrough_procedure<T: serde::Serialize>(
215 + req: &ProxyRequest,
216 + requester: Option<String>,
217 + body: Option<T>,
218 + ) -> Result<HandlerPipeThrough> {
219 + let UrlAndAud {
220 + url,
221 + aud,
222 + lxm: nsid,
223 + } = format_url_and_aud(req, None).await?;
224 + let headers = format_headers(req, aud, nsid, requester).await?;
225 + let encoded_body: Option<Vec<u8>> = match body {
226 + None => None,
227 + Some(body) => Some(serde_json::to_string(&body)?.into_bytes()),
228 + };
229 + let req_init = format_req_init(req, url, headers, encoded_body)?;
230 + let res = make_request(req_init).await?;
231 + parse_proxy_res(res).await
232 + }
233 +
234 + #[tracing::instrument(skip_all)]
235 + pub async fn pipethrough_procedure_post(
236 + req: &ProxyRequest,
237 + requester: Option<String>,
238 + body: Option<Bytes>,
239 + ) -> Result<HandlerPipeThrough, ApiError> {
240 + let UrlAndAud {
241 + url,
242 + aud,
243 + lxm: nsid,
244 + } = format_url_and_aud(req, None).await?;
245 + let headers = format_headers(req, aud, nsid, requester).await?;
246 + let encoded_body: Option<JsonValue>;
247 + match body {
248 + None => encoded_body = None,
249 + Some(body) => {
250 + // let res = match body.open(50.megabytes()).into_string().await {
251 + // Ok(res1) => {
252 + // tracing::info!(res1.value);
253 + // res1.value
254 + // }
255 + // Err(error) => {
256 + // tracing::error!("{error}");
257 + // return Err(ApiError::RuntimeError);
258 + // }
259 + // };
260 + let res = String::from_utf8(body.to_vec()).map_err(|_| ApiError::InvalidRequest("request body is not valid UTF-8".to_owned()))?;
261 +
262 + match serde_json::from_str(res.as_str()) {
263 + Ok(res) => {
264 + encoded_body = Some(res);
265 + }
266 + Err(error) => {
267 + tracing::error!("{error}");
268 + return Err(ApiError::RuntimeError);
269 + }
270 + }
271 + }
272 + };
273 + let req_init = format_req_init_with_value(req, url, headers, encoded_body)?;
274 + let res = make_request(req_init).await?;
275 + Ok(parse_proxy_res(res).await?)
276 + }
277 +
278 + // Request setup/formatting
279 + // -------------------
280 +
281 + const REQ_HEADERS_TO_FORWARD: [&str; 4] = [
282 + "accept-language",
283 + "content-type",
284 + "atproto-accept-labelers",
285 + "x-bsky-topics",
286 + ];
287 +
288 + #[tracing::instrument(skip_all)]
289 + pub async fn format_url_and_aud(
290 + req: &ProxyRequest,
291 + aud_override: Option<String>,
292 + ) -> Result<UrlAndAud> {
293 + let proxy_to = parse_proxy_header(req).await?;
294 + let nsid = parse_req_nsid(req);
295 + let default_proxy = default_service(req, &nsid).await;
296 + let service_url = match proxy_to {
297 + Some(ref proxy_to) => {
298 + tracing::info!(
299 + "@LOG: format_url_and_aud() proxy_to: {:?}",
300 + proxy_to.service_url
301 + );
302 + Some(proxy_to.service_url.clone())
303 + }
304 + None => match default_proxy {
305 + Some(ref default_proxy) => Some(default_proxy.url.clone()),
306 + None => None,
307 + },
308 + };
309 + let aud = match aud_override {
310 + Some(_) => aud_override,
311 + None => match proxy_to {
312 + Some(proxy_to) => Some(proxy_to.did),
313 + None => match default_proxy {
314 + Some(default_proxy) => Some(default_proxy.did),
315 + None => None,
316 + },
317 + },
318 + };
319 + match (service_url, aud) {
320 + (Some(service_url), Some(aud)) => {
321 + let mut url = Url::parse(format!("{0}{1}", service_url, req.path).as_str())?;
322 + if let Some(ref params) = req.query {
323 + url.set_query(Some(params.as_str()));
324 + }
325 + if !req.cfg.service.dev_mode && !is_safe_url(url.clone()) {
326 + bail!(InvalidRequestError::InvalidServiceUrl(url.to_string()));
327 + }
328 + Ok(UrlAndAud {
329 + url,
330 + aud,
331 + lxm: nsid,
332 + })
333 + }
334 + _ => bail!(InvalidRequestError::NoServiceConfigured(req.path.clone())),
335 + }
336 + }
337 +
338 + pub async fn format_headers(
339 + req: &ProxyRequest,
340 + aud: String,
341 + lxm: String,
342 + requester: Option<String>,
343 + ) -> Result<HeaderMap> {
344 + let mut headers: HeaderMap = match requester {
345 + Some(requester) => context::service_auth_headers(&requester, &aud, &lxm).await?,
346 + None => HeaderMap::new(),
347 + };
348 + // Forward select headers to upstream services.
349 + for header in REQ_HEADERS_TO_FORWARD {
350 + let val = req.headers.get(header);
351 + if let Some(val) = val {
352 + headers.insert(header, HeaderValue::from_str(val)?);
353 + }
354 + }
355 + Ok(headers)
356 + }
357 +
358 + pub fn format_req_init(
359 + req: &ProxyRequest,
360 + url: Url,
361 + headers: HeaderMap,
362 + body: Option<Vec<u8>>,
363 + ) -> Result<RequestBuilder> {
364 + match req.method {
365 + Method::GET => {
366 + let client = Client::builder()
367 + .user_agent(APP_USER_AGENT)
368 + .http2_keep_alive_while_idle(true)
369 + .http2_keep_alive_timeout(Duration::from_secs(5))
370 + .default_headers(headers)
371 + .build()?;
372 + Ok(client.get(url))
373 + }
374 + Method::HEAD => {
375 + let client = Client::builder()
376 + .user_agent(APP_USER_AGENT)
377 + .http2_keep_alive_while_idle(true)
378 + .http2_keep_alive_timeout(Duration::from_secs(5))
379 + .default_headers(headers)
380 + .build()?;
381 + Ok(client.head(url))
382 + }
383 + Method::POST => {
384 + let client = Client::builder()
385 + .user_agent(APP_USER_AGENT)
386 + .http2_keep_alive_while_idle(true)
387 + .http2_keep_alive_timeout(Duration::from_secs(5))
388 + .default_headers(headers)
389 + .build()?;
390 + Ok(client.post(url).body(body.unwrap_or_default())) // send an empty body instead of panicking when none was provided
391 + }
392 + _ => bail!(InvalidRequestError::MethodNotFound),
393 + }
394 + }
395 +
396 + pub fn
format_req_init_with_value(
397 + req: &ProxyRequest,
398 + url: Url,
399 + headers: HeaderMap,
400 + body: Option<JsonValue>,
401 + ) -> Result<RequestBuilder> {
402 + match req.method {
403 + Method::GET => {
404 + let client = Client::builder()
405 + .user_agent(APP_USER_AGENT)
406 + .http2_keep_alive_while_idle(true)
407 + .http2_keep_alive_timeout(Duration::from_secs(5))
408 + .default_headers(headers)
409 + .build()?;
410 + Ok(client.get(url))
411 + }
412 + Method::HEAD => {
413 + let client = Client::builder()
414 + .user_agent(APP_USER_AGENT)
415 + .http2_keep_alive_while_idle(true)
416 + .http2_keep_alive_timeout(Duration::from_secs(5))
417 + .default_headers(headers)
418 + .build()?;
419 + Ok(client.head(url))
420 + }
421 + Method::POST => {
422 + let client = Client::builder()
423 + .user_agent(APP_USER_AGENT)
424 + .http2_keep_alive_while_idle(true)
425 + .http2_keep_alive_timeout(Duration::from_secs(5))
426 + .default_headers(headers)
427 + .build()?;
428 + Ok(client.post(url).json(&body.unwrap_or_default())) // fall back to a JSON null body instead of panicking
429 + }
430 + _ => bail!(InvalidRequestError::MethodNotFound),
431 + }
432 + }
433 +
434 + pub async fn parse_proxy_header(req: &ProxyRequest) -> Result<Option<ProxyHeader>> {
435 + let headers = &req.headers;
436 + let proxy_to: Option<&String> = headers.get("atproto-proxy");
437 + match proxy_to {
438 + None => Ok(None),
439 + Some(proxy_to) => {
440 + let parts: Vec<&str> = proxy_to.split('#').collect();
441 + match (parts.first(), parts.get(1), parts.get(2)) {
442 + (Some(did), Some(service_id), None) => {
443 + let did = did.to_string();
444 + let mut lock = req.id_resolver.write().await;
445 + match lock.did.resolve(did.clone(), None).await? {
446 + None => bail!(InvalidRequestError::CannotResolveProxyDid),
447 + Some(did_doc) => {
448 + match get_service_endpoint(
449 + did_doc,
450 + GetServiceEndpointOpts {
451 + id: format!("#{service_id}"),
452 + r#type: None,
453 + },
454 + ) {
455 + None => bail!(InvalidRequestError::CannotResolveServiceUrl),
456 + Some(service_url) => Ok(Some(ProxyHeader { did, service_url })),
457 + }
458 + }
459 + }
460 + }
461 + (_, None, _) => bail!(InvalidRequestError::NoServiceId),
462 + _ => bail!("error parsing atproto-proxy header"),
463 + }
464 + }
465 + }
466 + }
467 +
468 + pub fn parse_req_nsid(req: &ProxyRequest) -> String {
469 + let nsid = req.path.as_str().replace("/xrpc/", "");
470 + match nsid.ends_with('/') {
471 + false => nsid,
472 + true => nsid
473 + .trim_end_matches('/')
474 + .to_string(),
475 + }
476 + }
477 +
478 + // Sending request
479 + // -------------------
480 + #[tracing::instrument(skip_all)]
481 + pub async fn make_request(req_init: RequestBuilder) -> Result<Response> {
482 + let res = req_init.send().await;
483 + match res {
484 + Err(e) => {
485 + tracing::error!("@LOG WARN: pipethrough network error {}", e.to_string());
486 + bail!(InvalidRequestError::XRPCError(XRPCError::UpstreamFailure))
487 + }
488 + Ok(res) => match res.error_for_status_ref() {
489 + Ok(_) => Ok(res),
490 + Err(_) => {
491 + let status = res.status().to_string();
492 + let headers = res.headers().clone();
493 + let error_body = res.json::<JsonValue>().await?;
494 + bail!(InvalidRequestError::XRPCError(XRPCError::FailedResponse {
495 + status,
496 + headers,
497 + error: match error_body["error"].as_str() {
498 + None => None,
499 + Some(error_body_error) => Some(error_body_error.to_string()),
500 + },
501 + message: match error_body["message"].as_str() {
502 + None => None,
503 +
Some(error_body_message) => Some(error_body_message.to_string()),
504 + }
505 + }))
506 + }
507 + },
508 + }
509 + }
510 +
511 + // Response parsing/forwarding
512 + // -------------------
513 +
514 + const RES_HEADERS_TO_FORWARD: [&str; 4] = [
515 + "content-type",
516 + "content-language",
517 + "atproto-repo-rev",
518 + "atproto-content-labelers",
519 + ];
520 +
521 + pub async fn parse_proxy_res(res: Response) -> Result<HandlerPipeThrough> {
522 + let encoding = match res.headers().get(CONTENT_TYPE) {
523 + Some(content_type) => content_type.to_str()?,
524 + None => "application/json",
525 + };
526 + // Release borrow
527 + let encoding = encoding.to_string();
528 + let res_headers = RES_HEADERS_TO_FORWARD.into_iter().fold(
529 + BTreeMap::new(),
530 + |mut acc: BTreeMap<String, String>, cur| {
531 + let _ = match res.headers().get(cur).and_then(|v| v.to_str().ok()) {
532 + Some(res_header_val) => acc.insert(
533 + cur.to_string(),
534 + res_header_val.to_string(), // skip, rather than panic on, values that are not valid UTF-8
535 + ),
536 + None => None,
537 + };
538 + acc
539 + },
540 + );
541 + let buffer = read_array_buffer_res(res).await?;
542 + Ok(HandlerPipeThrough {
543 + encoding,
544 + buffer,
545 + headers: Some(res_headers),
546 + })
547 + }
548 +
549 + // Utils
550 + // -------------------
551 +
552 + pub async fn default_service(req: &ProxyRequest, nsid: &str) -> Option<ServiceConfig> {
553 + let cfg = req.cfg.clone();
554 + match Ids::from_str(nsid) {
555 + Ok(Ids::ToolsOzoneTeamAddMember) => cfg.mod_service,
556 + Ok(Ids::ToolsOzoneTeamDeleteMember) => cfg.mod_service,
557 + Ok(Ids::ToolsOzoneTeamUpdateMember) => cfg.mod_service,
558 + Ok(Ids::ToolsOzoneTeamListMembers) => cfg.mod_service,
559 + Ok(Ids::ToolsOzoneCommunicationCreateTemplate) => cfg.mod_service,
560 + Ok(Ids::ToolsOzoneCommunicationDeleteTemplate) => cfg.mod_service,
561 + Ok(Ids::ToolsOzoneCommunicationUpdateTemplate) => cfg.mod_service,
562 + Ok(Ids::ToolsOzoneCommunicationListTemplates) => cfg.mod_service,
563 + Ok(Ids::ToolsOzoneModerationEmitEvent) => cfg.mod_service,
564 + Ok(Ids::ToolsOzoneModerationGetEvent) => cfg.mod_service,
565 + Ok(Ids::ToolsOzoneModerationGetRecord) => cfg.mod_service,
566 + Ok(Ids::ToolsOzoneModerationGetRepo) => cfg.mod_service,
567 + Ok(Ids::ToolsOzoneModerationQueryEvents) => cfg.mod_service,
568 + Ok(Ids::ToolsOzoneModerationQueryStatuses) => cfg.mod_service,
569 + Ok(Ids::ToolsOzoneModerationSearchRepos) => cfg.mod_service,
570 + Ok(Ids::ComAtprotoModerationCreateReport) => cfg.report_service,
571 + _ => cfg.bsky_app_view,
572 + }
573 + }
574 +
575 + pub fn parse_res<T: DeserializeOwned>(_nsid: String, res: HandlerPipeThrough) -> Result<T> {
576 + let buffer = res.buffer;
577 + let record = serde_json::from_slice::<T>(buffer.as_slice())?;
578 + Ok(record)
579 + }
580 +
581 + #[tracing::instrument(skip_all)]
582 + pub async fn read_array_buffer_res(res: Response) -> Result<Vec<u8>> {
583 + match res.bytes().await {
584 + Ok(bytes) => Ok(bytes.to_vec()),
585 + Err(err) => {
586 + tracing::error!("@LOG WARN: pipethrough network error {}", err.to_string());
587 + bail!("UpstreamFailure")
588 + }
589 + }
590 + }
591 +
592 + pub fn is_safe_url(url: Url) -> bool {
593 + if url.scheme() != "https" {
594 + return false;
595 + }
596 + match url.host_str() {
597 + None => false,
598 + Some(hostname) if hostname == "localhost" => false,
599 + Some(hostname) => {
600 + if std::net::IpAddr::from_str(hostname).is_ok() {
601 + return false;
602 + }
603 + true
604 + }
605 + }
606 + }
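A sketch of how a catch-all XRPC route can lean on this module: the `ProxyRequest` extractor gathers headers, path, and config; `pipethrough` forwards the call; and the resulting `HandlerPipeThrough` payload is echoed back with its upstream content type. The handler name and error mapping here are illustrative, not part of this diff:

```
use axum::http::{StatusCode, header::CONTENT_TYPE};
use axum::response::IntoResponse;
use rsky_pds::pipethrough::OverrideOpts;

async fn xrpc_fallback(req: ProxyRequest) -> Result<impl IntoResponse, StatusCode> {
    let res = pipethrough(
        &req,
        None, // no authenticated requester in this sketch
        OverrideOpts { aud: None, lxm: None },
    )
    .await
    .map_err(|_| StatusCode::BAD_GATEWAY)?;

    // Forward the proxied bytes with the upstream content type.
    Ok(([(CONTENT_TYPE, res.encoding)], res.buffer))
}
```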
-114
src/plc.rs
··· 1 - //! PLC operations. 2 - use std::collections::HashMap; 3 - 4 - use anyhow::{Context as _, bail}; 5 - use base64::Engine as _; 6 - use serde::{Deserialize, Serialize}; 7 - use tracing::debug; 8 - 9 - use crate::{Client, RotationKey}; 10 - 11 - /// The URL of the public PLC directory. 12 - const PLC_DIRECTORY: &str = "https://plc.directory/"; 13 - 14 - #[derive(Debug, Deserialize, Serialize, Clone)] 15 - #[serde(rename_all = "camelCase", tag = "type")] 16 - /// A PLC service. 17 - pub(crate) enum PlcService { 18 - #[serde(rename = "AtprotoPersonalDataServer")] 19 - /// A personal data server. 20 - Pds { 21 - /// The URL of the PDS. 22 - endpoint: String, 23 - }, 24 - } 25 - 26 - #[expect( 27 - clippy::arbitrary_source_item_ordering, 28 - reason = "serialized data might be structured" 29 - )] 30 - #[derive(Debug, Deserialize, Serialize, Clone)] 31 - #[serde(rename_all = "camelCase")] 32 - pub(crate) struct PlcOperation { 33 - #[serde(rename = "type")] 34 - pub typ: String, 35 - pub rotation_keys: Vec<String>, 36 - pub verification_methods: HashMap<String, String>, 37 - pub also_known_as: Vec<String>, 38 - pub services: HashMap<String, PlcService>, 39 - pub prev: Option<String>, 40 - } 41 - 42 - impl PlcOperation { 43 - /// Sign an operation with the provided signature. 44 - pub(crate) fn sign(self, sig: Vec<u8>) -> SignedPlcOperation { 45 - SignedPlcOperation { 46 - typ: self.typ, 47 - rotation_keys: self.rotation_keys, 48 - verification_methods: self.verification_methods, 49 - also_known_as: self.also_known_as, 50 - services: self.services, 51 - prev: self.prev, 52 - sig: base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sig), 53 - } 54 - } 55 - } 56 - 57 - #[expect( 58 - clippy::arbitrary_source_item_ordering, 59 - reason = "serialized data might be structured" 60 - )] 61 - #[derive(Debug, Deserialize, Serialize, Clone)] 62 - #[serde(rename_all = "camelCase")] 63 - /// A signed PLC operation. 64 - pub(crate) struct SignedPlcOperation { 65 - #[serde(rename = "type")] 66 - pub typ: String, 67 - pub rotation_keys: Vec<String>, 68 - pub verification_methods: HashMap<String, String>, 69 - pub also_known_as: Vec<String>, 70 - pub services: HashMap<String, PlcService>, 71 - pub prev: Option<String>, 72 - pub sig: String, 73 - } 74 - 75 - pub(crate) fn sign_op(rkey: &RotationKey, op: PlcOperation) -> anyhow::Result<SignedPlcOperation> { 76 - let bytes = serde_ipld_dagcbor::to_vec(&op).context("failed to encode op")?; 77 - let bytes = rkey.sign(&bytes).context("failed to sign op")?; 78 - 79 - Ok(op.sign(bytes)) 80 - } 81 - 82 - /// Submit a PLC operation to the public directory. 83 - pub(crate) async fn submit( 84 - client: &Client, 85 - did: &str, 86 - op: &SignedPlcOperation, 87 - ) -> anyhow::Result<()> { 88 - debug!( 89 - "submitting {} {}", 90 - did, 91 - serde_json::to_string(&op).context("should serialize")? 92 - ); 93 - 94 - let res = client 95 - .post(format!("{PLC_DIRECTORY}{did}")) 96 - .json(&op) 97 - .send() 98 - .await 99 - .context("failed to send directory request")?; 100 - 101 - if res.status().is_success() { 102 - Ok(()) 103 - } else { 104 - let e = res 105 - .json::<serde_json::Value>() 106 - .await 107 - .context("failed to read error response")?; 108 - 109 - bail!( 110 - "error from PLC directory: {}", 111 - serde_json::to_string(&e).context("should serialize")? 112 - ); 113 - } 114 - }
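The bespoke PLC directory client goes away with this file; identity lookups now route through the shared `rsky-identity` resolver held in `AppState`. A minimal sketch of resolving a DID document through it, mirroring the usage in `parse_proxy_header` above (the function and its error handling are illustrative):

```
use std::sync::Arc;

use anyhow::{Result, bail};
use rsky_identity::IdResolver;
use tokio::sync::RwLock;

async fn resolve_did(id_resolver: &Arc<RwLock<IdResolver>>, did: String) -> Result<()> {
    let mut lock = id_resolver.write().await;
    match lock.did.resolve(did.clone(), None).await? {
        Some(did_doc) => {
            // Inspect service endpoints, verification methods, etc. here.
            let _ = did_doc;
            Ok(())
        }
        None => bail!("could not resolve {did}"),
    }
}
```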
+429
src/serve.rs
··· 1 + use super::account_manager::AccountManager; 2 + use super::config::AppConfig; 3 + use super::db::establish_pool; 4 + pub use super::error::Error; 5 + use super::service_proxy::service_proxy; 6 + use anyhow::Context as _; 7 + use atrium_api::types::string::Did; 8 + use atrium_crypto::keypair::{Export as _, Secp256k1Keypair}; 9 + use axum::{Router, extract::FromRef, routing::get}; 10 + use clap::Parser; 11 + use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter}; 12 + use deadpool_diesel::sqlite::Pool; 13 + use diesel::prelude::*; 14 + use diesel_migrations::{EmbeddedMigrations, embed_migrations}; 15 + use figment::{Figment, providers::Format as _}; 16 + use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager}; 17 + use rsky_common::env::env_list; 18 + use rsky_identity::IdResolver; 19 + use rsky_identity::types::{DidCache, IdentityResolverOpts}; 20 + use rsky_pds::{crawlers::Crawlers, sequencer::Sequencer}; 21 + use serde::{Deserialize, Serialize}; 22 + use std::env; 23 + use std::{ 24 + net::{IpAddr, Ipv4Addr, SocketAddr}, 25 + path::PathBuf, 26 + str::FromStr as _, 27 + sync::Arc, 28 + }; 29 + use tokio::{net::TcpListener, sync::RwLock}; 30 + use tower_http::{cors::CorsLayer, trace::TraceLayer}; 31 + use tracing::{info, warn}; 32 + use uuid::Uuid; 33 + 34 + /// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`. 35 + pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); 36 + 37 + /// Embedded migrations 38 + pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations"); 39 + pub const MIGRATIONS_ACTOR: EmbeddedMigrations = embed_migrations!("./migrations_actor"); 40 + 41 + /// The application-wide result type. 42 + pub type Result<T> = std::result::Result<T, Error>; 43 + /// The reqwest client type with middleware. 44 + pub type Client = reqwest_middleware::ClientWithMiddleware; 45 + 46 + #[expect( 47 + clippy::arbitrary_source_item_ordering, 48 + reason = "serialized data might be structured" 49 + )] 50 + #[derive(Serialize, Deserialize, Debug, Clone)] 51 + /// The key data structure. 52 + struct KeyData { 53 + /// Primary signing key for all repo operations. 54 + skey: Vec<u8>, 55 + /// Primary signing (rotation) key for all PLC operations. 56 + rkey: Vec<u8>, 57 + } 58 + 59 + // FIXME: We should use P256Keypair instead. SecP256K1 is primarily used for cryptocurrencies, 60 + // and the implementations of this algorithm are much more limited as compared to P256. 61 + // 62 + // Reference: https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/ 63 + #[derive(Clone)] 64 + /// The signing key for PLC/DID operations. 65 + pub struct SigningKey(Arc<Secp256k1Keypair>); 66 + #[derive(Clone)] 67 + /// The rotation key for PLC operations. 68 + pub struct RotationKey(Arc<Secp256k1Keypair>); 69 + 70 + impl std::ops::Deref for SigningKey { 71 + type Target = Secp256k1Keypair; 72 + 73 + fn deref(&self) -> &Self::Target { 74 + &self.0 75 + } 76 + } 77 + 78 + impl SigningKey { 79 + /// Import from a private key. 80 + pub fn import(key: &[u8]) -> Result<Self> { 81 + let key = Secp256k1Keypair::import(key).context("failed to import signing key")?; 82 + Ok(Self(Arc::new(key))) 83 + } 84 + } 85 + 86 + impl std::ops::Deref for RotationKey { 87 + type Target = Secp256k1Keypair; 88 + 89 + fn deref(&self) -> &Self::Target { 90 + &self.0 91 + } 92 + } 93 + 94 + #[derive(Parser, Debug, Clone)] 95 + /// Command line arguments. 
96 + pub struct Args { 97 + /// Path to the configuration file. 98 + #[arg(short, long, default_value = "default.toml")] 99 + pub config: PathBuf, 100 + /// The verbosity level. 101 + #[command(flatten)] 102 + pub verbosity: Verbosity<InfoLevel>, 103 + } 104 + 105 + /// Per-actor storage handles: a database pool for the actor's repo plus the actor's blob storage path. 106 + pub struct ActorStorage { 107 + /// The database connection pool for the actor's repository. 108 + pub repo: Pool, 109 + /// The file storage path for the actor's blobs. 110 + pub blob: PathBuf, 111 + } 112 + 113 + impl Clone for ActorStorage { 114 + fn clone(&self) -> Self { 115 + Self { 116 + repo: self.repo.clone(), 117 + blob: self.blob.clone(), 118 + } 119 + } 120 + } 121 + 122 + #[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")] 123 + #[derive(Clone, FromRef)] 124 + /// The application state, shared across all routes. 125 + pub struct AppState { 126 + /// The application configuration. 127 + pub(crate) config: AppConfig, 128 + /// The main database connection pool. Used for common PDS data, like invite codes. 129 + pub db: Pool, 130 + /// Actor-specific database connection pools. Keyed by DID. 131 + pub db_actors: std::collections::HashMap<String, ActorStorage>, 132 + 133 + /// The HTTP client with middleware. 134 + pub client: Client, 135 + /// The simple HTTP client. 136 + pub simple_client: reqwest::Client, 137 + /// The firehose event sequencer. 138 + pub sequencer: Arc<RwLock<Sequencer>>, 139 + /// The account manager. 140 + pub account_manager: Arc<RwLock<AccountManager>>, 141 + /// The ID resolver. 142 + pub id_resolver: Arc<RwLock<IdResolver>>, 143 + 144 + /// The signing key. 145 + pub signing_key: SigningKey, 146 + /// The rotation key. 147 + pub rotation_key: RotationKey, 148 + } 149 + 150 + /// The main application entry point. 151 + #[expect( 152 + clippy::cognitive_complexity, 153 + clippy::too_many_lines, 154 + unused_qualifications, 155 + reason = "main function has high complexity" 156 + )] 157 + pub async fn run() -> anyhow::Result<()> { 158 + let args = Args::parse(); 159 + 160 + // Set up trace logging to console and account for the user-provided verbosity flag. 161 + if args.verbosity.log_level_filter() != LevelFilter::Off { 162 + let lvl = match args.verbosity.log_level_filter() { 163 + LevelFilter::Error => tracing::Level::ERROR, 164 + LevelFilter::Warn => tracing::Level::WARN, 165 + LevelFilter::Info | LevelFilter::Off => tracing::Level::INFO, 166 + LevelFilter::Debug => tracing::Level::DEBUG, 167 + LevelFilter::Trace => tracing::Level::TRACE, 168 + }; 169 + tracing_subscriber::fmt().with_max_level(lvl).init(); 170 + } 171 + 172 + if !args.config.exists() { 173 + // Throw up a warning if the config file does not exist. 174 + // 175 + // This is not fatal because users can specify all configuration settings via 176 + // the environment, but the most likely scenario here is that a user accidentally 177 + // omitted the config file for some reason (e.g. forgot to mount it into Docker). 178 + warn!( 179 + "configuration file {} does not exist", 180 + args.config.display() 181 + ); 182 + } 183 + 184 + // Read and parse the user-provided configuration.
185 + let config: AppConfig = Figment::new() 186 + .admerge(figment::providers::Toml::file(args.config)) 187 + .admerge(figment::providers::Env::prefixed("BLUEPDS_")) 188 + .extract() 189 + .context("failed to load configuration")?; 190 + 191 + if config.test { 192 + warn!("BluePDS starting up in TEST mode."); 193 + warn!("This means the application will not federate with the rest of the network."); 194 + warn!( 195 + "If you want to turn this off, either set `test` to false in the config or define `BLUEPDS_TEST = false`" 196 + ); 197 + } 198 + 199 + // Initialize metrics reporting. 200 + super::metrics::setup(config.metrics.as_ref()).context("failed to set up metrics exporter")?; 201 + 202 + // Create a reqwest client that will be used for all outbound requests. 203 + let simple_client = reqwest::Client::builder() 204 + .user_agent(APP_USER_AGENT) 205 + .build() 206 + .context("failed to build requester client")?; 207 + let client = reqwest_middleware::ClientBuilder::new(simple_client.clone()) 208 + .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache { 209 + mode: CacheMode::Default, 210 + manager: MokaManager::default(), 211 + options: HttpCacheOptions::default(), 212 + })) 213 + .build(); 214 + 215 + tokio::fs::create_dir_all(&config.key.parent().context("should have parent")?) 216 + .await 217 + .context("failed to create key directory")?; 218 + 219 + // Check if crypto keys exist. If not, create new ones. 220 + let (skey, rkey) = if let Ok(f) = std::fs::File::open(&config.key) { 221 + let keys: KeyData = serde_ipld_dagcbor::from_reader(std::io::BufReader::new(f)) 222 + .context("failed to deserialize crypto keys")?; 223 + 224 + let skey = Secp256k1Keypair::import(&keys.skey).context("failed to import signing key")?; 225 + let rkey = Secp256k1Keypair::import(&keys.rkey).context("failed to import rotation key")?; 226 + 227 + (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 228 + } else { 229 + info!("signing keys not found, generating new ones"); 230 + 231 + let skey = Secp256k1Keypair::create(&mut rand::thread_rng()); 232 + let rkey = Secp256k1Keypair::create(&mut rand::thread_rng()); 233 + 234 + let keys = KeyData { 235 + skey: skey.export(), 236 + rkey: rkey.export(), 237 + }; 238 + 239 + let mut f = std::fs::File::create(&config.key).context("failed to create key file")?; 240 + serde_ipld_dagcbor::to_writer(&mut f, &keys).context("failed to serialize crypto keys")?; 241 + 242 + (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 243 + }; 244 + 245 + tokio::fs::create_dir_all(&config.repo.path).await?; 246 + tokio::fs::create_dir_all(&config.plc.path).await?; 247 + tokio::fs::create_dir_all(&config.blob.path).await?; 248 + 249 + // Create a database connection manager and pool for the main database. 250 + let pool = 251 + establish_pool(&config.db).context("failed to establish database connection pool")?; 252 + 253 + // Create a dictionary of database connection pools for each actor. 254 + let mut actor_pools = std::collections::HashMap::new(); 255 + // We'll determine actors by looking in the data/repo dir for .db files. 256 + let mut actor_dbs = tokio::fs::read_dir(&config.repo.path) 257 + .await 258 + .context("failed to read repo directory")?; 259 + while let Some(entry) = actor_dbs 260 + .next_entry() 261 + .await 262 + .context("failed to read repo dir")? 
263 + { 264 + let path = entry.path(); 265 + if path.extension().and_then(|s| s.to_str()) == Some("db") { 266 + let actor_repo_pool = establish_pool(&format!("sqlite://{}", path.display())) 267 + .context("failed to create database connection pool")?; 268 + 269 + let did = Did::from_str(&format!( 270 + "did:plc:{}", 271 + path.file_stem() 272 + .and_then(|s| s.to_str()) 273 + .context("failed to get actor DID")? 274 + )) 275 + .expect("should be able to parse actor DID") 276 + .to_string(); 277 + let blob_path = config.blob.path.to_path_buf(); 278 + let actor_storage = ActorStorage { 279 + repo: actor_repo_pool, 280 + blob: blob_path.clone(), 281 + }; 282 + drop(actor_pools.insert(did, actor_storage)); 283 + } 284 + } 285 + // Apply pending migrations 286 + // let conn = pool.get().await?; 287 + // conn.run_pending_migrations(MIGRATIONS) 288 + // .expect("should be able to run migrations"); 289 + 290 + let hostname = config.host_name.clone(); 291 + let crawlers: Vec<String> = config 292 + .firehose 293 + .relays 294 + .iter() 295 + .map(|s| s.to_string()) 296 + .collect(); 297 + let sequencer = Arc::new(RwLock::new(Sequencer::new( 298 + Crawlers::new(hostname, crawlers.clone()), 299 + None, 300 + ))); 301 + let account_manager = Arc::new(RwLock::new(AccountManager::new(pool.clone()))); 302 + let plc_url = if cfg!(debug_assertions) { 303 + "http://localhost:8000".to_owned() // dummy for debug 304 + } else { 305 + env::var("PDS_DID_PLC_URL").unwrap_or("https://plc.directory".to_owned()) // TODO: toml config 306 + }; 307 + let id_resolver = Arc::new(RwLock::new(IdResolver::new(IdentityResolverOpts { 308 + timeout: None, 309 + plc_url: Some(plc_url), 310 + did_cache: Some(DidCache::new(None, None)), 311 + backup_nameservers: Some(env_list("PDS_HANDLE_BACKUP_NAMESERVERS")), 312 + }))); 313 + 314 + let addr = config 315 + .listen_address 316 + .unwrap_or(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)); 317 + 318 + let app = Router::new() 319 + .route("/", get(super::index)) 320 + .merge(super::oauth::routes()) 321 + .nest( 322 + "/xrpc", 323 + super::apis::routes() 324 + .merge(super::actor_endpoints::routes()) 325 + .fallback(service_proxy), 326 + ) 327 + // .layer(RateLimitLayer::new(30, Duration::from_secs(30))) 328 + .layer(CorsLayer::permissive()) 329 + .layer(TraceLayer::new_for_http()) 330 + .with_state(AppState { 331 + config: config.clone(), 332 + db: pool.clone(), 333 + db_actors: actor_pools.clone(), 334 + client: client.clone(), 335 + simple_client, 336 + sequencer: sequencer.clone(), 337 + account_manager, 338 + id_resolver, 339 + signing_key: skey, 340 + rotation_key: rkey, 341 + }); 342 + 343 + info!("listening on {addr}"); 344 + info!("connect to: http://127.0.0.1:{}", addr.port()); 345 + 346 + // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created). 347 + // If so, create an invite code and share it via the console. 
348 + let conn = pool.get().await.context("failed to get db connection")?; 349 + 350 + #[derive(QueryableByName)] 351 + struct TotalCount { 352 + #[diesel(sql_type = diesel::sql_types::Integer)] 353 + total_count: i32, 354 + } 355 + 356 + let result = conn.interact(move |conn| { 357 + diesel::sql_query( 358 + "SELECT (SELECT COUNT(*) FROM account) + (SELECT COUNT(*) FROM invite_code) AS total_count", 359 + ) 360 + .get_result::<TotalCount>(conn) 361 + }) 362 + .await 363 + .expect("should be able to query database")?; 364 + 365 + let c = result.total_count; 366 + 367 + #[expect(clippy::print_stdout)] 368 + if c == 0 { 369 + let uuid = Uuid::new_v4().to_string(); 370 + 371 + use crate::models::pds as models; 372 + use crate::schema::pds::invite_code::dsl as InviteCode; 373 + let uuid_clone = uuid.clone(); 374 + drop( 375 + conn.interact(move |conn| { 376 + diesel::insert_into(InviteCode::invite_code) 377 + .values(models::InviteCode { 378 + code: uuid_clone, 379 + available_uses: 1, 380 + disabled: 0, 381 + for_account: "None".to_owned(), 382 + created_by: "None".to_owned(), 383 + created_at: "None".to_owned(), 384 + }) 385 + .execute(conn) 386 + .context("failed to create new invite code") 387 + }) 388 + .await 389 + .expect("should be able to create invite code"), 390 + ); 391 + 392 + // N.B: This is a sensitive message, so we're bypassing `tracing` here and 393 + // logging it directly to console. 394 + println!("====================================="); 395 + println!(" FIRST STARTUP "); 396 + println!("====================================="); 397 + println!("Use this code to create an account:"); 398 + println!("{uuid}"); 399 + println!("====================================="); 400 + } 401 + 402 + let listener = TcpListener::bind(&addr) 403 + .await 404 + .context("failed to bind address")?; 405 + 406 + // Serve the app, and request crawling from upstream relays. 407 + let serve = tokio::spawn(async move { 408 + axum::serve(listener, app.into_make_service()) 409 + .await 410 + .context("failed to serve app") 411 + }); 412 + 413 + // Now that the app is live, request a crawl from upstream relays. 414 + if cfg!(debug_assertions) { 415 + info!("debug mode: not requesting crawl"); 416 + } else { 417 + info!("requesting crawl from upstream relays"); 418 + let mut background_sequencer = sequencer.write().await.clone(); 419 + drop(tokio::spawn( 420 + async move { background_sequencer.start().await }, 421 + )); 422 + } 423 + 424 + serve 425 + .await 426 + .map_err(Into::into) 427 + .and_then(|r| r) 428 + .context("failed to serve app") 429 + }
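A note on the startup scan above: per-actor databases are discovered once at boot by listing `*.db` files under `config.repo.path`, with each file stem treated as the suffix of a `did:plc` identifier. Below is a minimal sketch of the inverse operation, registering storage for a newly created actor, reusing the `establish_pool` and `ActorStorage` items from this file. The `<suffix>.db` layout is an assumption inferred from the scan loop, and `open_actor_storage` is a hypothetical helper, not part of this patch:

```
use anyhow::Context as _;
use std::path::Path;

/// Hypothetical helper: build an `ActorStorage` entry for a new actor,
/// mirroring the boot-time scan (which maps `<suffix>.db` to `did:plc:<suffix>`).
fn open_actor_storage(repo_dir: &Path, blob_dir: &Path, plc_suffix: &str) -> anyhow::Result<ActorStorage> {
    // Same sqlite:// URL scheme the scan uses when opening existing actor databases.
    let db_path = repo_dir.join(format!("{plc_suffix}.db"));
    let repo = establish_pool(&format!("sqlite://{}", db_path.display()))
        .context("failed to create actor database pool")?;
    // All actors currently share one blob root; per-actor layout is handled elsewhere.
    Ok(ActorStorage { repo, blob: blob_dir.to_path_buf() })
}
```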
+6 -26
src/service_proxy.rs
··· 3 3 /// Reference: <https://atproto.com/specs/xrpc#service-proxying> 4 4 use anyhow::{Context as _, anyhow}; 5 5 use atrium_api::types::string::Did; 6 - use atrium_crypto::keypair::{Export as _, Secp256k1Keypair}; 7 6 use axum::{ 8 - Router, 9 7 body::Body, 10 - extract::{FromRef, Request, State}, 8 + extract::{Request, State}, 11 9 http::{self, HeaderMap, Response, StatusCode, Uri}, 12 - response::IntoResponse, 13 - routing::get, 14 10 }; 15 - use azure_core::credentials::TokenCredential; 16 - use clap::Parser; 17 - use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter}; 18 - use deadpool_diesel::sqlite::Pool; 19 - use diesel::prelude::*; 20 - use diesel_migrations::{EmbeddedMigrations, embed_migrations}; 21 - use figment::{Figment, providers::Format as _}; 22 - use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager}; 23 11 use rand::Rng as _; 24 - use serde::{Deserialize, Serialize}; 25 - use std::{ 26 - net::{IpAddr, Ipv4Addr, SocketAddr}, 27 - path::PathBuf, 28 - str::FromStr as _, 29 - sync::Arc, 30 - }; 31 - use tokio::net::TcpListener; 32 - use tower_http::{cors::CorsLayer, trace::TraceLayer}; 33 - use tracing::{info, warn}; 34 - use uuid::Uuid; 12 + use std::str::FromStr as _; 35 13 36 - use super::{Client, Error, Result}; 37 - use crate::{AuthenticatedUser, SigningKey}; 14 + use super::{ 15 + auth::AuthenticatedUser, 16 + serve::{Client, Error, Result, SigningKey}, 17 + }; 38 18 39 19 pub(super) async fn service_proxy( 40 20 uri: Uri,
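Per the spec reference at the top of `service_proxy.rs`, proxy targets arrive in the `atproto-proxy` request header as `<did>#<service-id>` (e.g. `did:web:api.bsky.app#bsky_appview`). A sketch of that parse, following the spec's format rather than this handler's actual internals (the handler also imports `SigningKey` and `AuthenticatedUser` for the service-auth side, which this sketch omits):

```
/// Illustrative only: split an `atproto-proxy` header value
/// (`<did>#<service-id>`) into its DID and service-id parts.
fn parse_proxy_target(value: &str) -> anyhow::Result<(&str, &str)> {
    let (did, service) = value
        .split_once('#')
        .ok_or_else(|| anyhow::anyhow!("expected `<did>#<service-id>` in atproto-proxy header"))?;
    anyhow::ensure!(did.starts_with("did:"), "proxy target must be a DID");
    Ok((did, service))
}
```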
-459
src/tests.rs
··· 1 - //! Testing utilities for the PDS. 2 - #![expect(clippy::arbitrary_source_item_ordering)] 3 - use std::{ 4 - net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener}, 5 - path::PathBuf, 6 - time::{Duration, Instant}, 7 - }; 8 - 9 - use anyhow::Result; 10 - use atrium_api::{ 11 - com::atproto::server, 12 - types::string::{AtIdentifier, Did, Handle, Nsid, RecordKey}, 13 - }; 14 - use figment::{Figment, providers::Format as _}; 15 - use futures::future::join_all; 16 - use serde::{Deserialize, Serialize}; 17 - use tokio::sync::OnceCell; 18 - use uuid::Uuid; 19 - 20 - use crate::config::AppConfig; 21 - 22 - /// Global test state, created once for all tests. 23 - pub(crate) static TEST_STATE: OnceCell<TestState> = OnceCell::const_new(); 24 - 25 - /// A temporary test directory that will be cleaned up when the struct is dropped. 26 - struct TempDir { 27 - /// The path to the directory. 28 - path: PathBuf, 29 - } 30 - 31 - impl TempDir { 32 - /// Create a new temporary directory. 33 - fn new() -> Result<Self> { 34 - let path = std::env::temp_dir().join(format!("bluepds-test-{}", Uuid::new_v4())); 35 - std::fs::create_dir_all(&path)?; 36 - Ok(Self { path }) 37 - } 38 - 39 - /// Get the path to the directory. 40 - fn path(&self) -> &PathBuf { 41 - &self.path 42 - } 43 - } 44 - 45 - impl Drop for TempDir { 46 - fn drop(&mut self) { 47 - drop(std::fs::remove_dir_all(&self.path)); 48 - } 49 - } 50 - 51 - /// Test state for the application. 52 - pub(crate) struct TestState { 53 - /// The address the test server is listening on. 54 - address: SocketAddr, 55 - /// The HTTP client. 56 - client: reqwest::Client, 57 - /// The application configuration. 58 - config: AppConfig, 59 - /// The temporary directory for test data. 60 - #[expect(dead_code)] 61 - temp_dir: TempDir, 62 - } 63 - 64 - impl TestState { 65 - /// Get a base URL for the test server. 66 - pub(crate) fn base_url(&self) -> String { 67 - format!("http://{}", self.address) 68 - } 69 - 70 - /// Create a test account. 71 - pub(crate) async fn create_test_account(&self) -> Result<TestAccount> { 72 - // Create the account 73 - let handle = "test.handle"; 74 - let response = self 75 - .client 76 - .post(format!( 77 - "http://{}/xrpc/com.atproto.server.createAccount", 78 - self.address 79 - )) 80 - .json(&server::create_account::InputData { 81 - did: None, 82 - verification_code: None, 83 - verification_phone: None, 84 - email: Some(format!("{}@example.com", &handle)), 85 - handle: Handle::new(handle.to_owned()).expect("should be able to create handle"), 86 - password: Some("password123".to_owned()), 87 - invite_code: None, 88 - recovery_key: None, 89 - plc_op: None, 90 - }) 91 - .send() 92 - .await?; 93 - 94 - let account: server::create_account::Output = response.json().await?; 95 - 96 - Ok(TestAccount { 97 - handle: handle.to_owned(), 98 - did: account.did.to_string(), 99 - access_token: account.access_jwt.clone(), 100 - refresh_token: account.refresh_jwt.clone(), 101 - }) 102 - } 103 - 104 - /// Create a new test state. 
105 - #[expect(clippy::unused_async)] 106 - async fn new() -> Result<Self> { 107 - // Configure the test app 108 - #[derive(Serialize, Deserialize)] 109 - struct TestConfigInput { 110 - db: Option<String>, 111 - host_name: Option<String>, 112 - key: Option<PathBuf>, 113 - listen_address: Option<SocketAddr>, 114 - test: Option<bool>, 115 - } 116 - // Create a temporary directory for test data 117 - let temp_dir = TempDir::new()?; 118 - 119 - // Find a free port 120 - let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?; 121 - let address = listener.local_addr()?; 122 - drop(listener); 123 - 124 - let test_config = TestConfigInput { 125 - db: Some(format!("sqlite://{}/test.db", temp_dir.path().display())), 126 - host_name: Some(format!("localhost:{}", address.port())), 127 - key: Some(temp_dir.path().join("test.key")), 128 - listen_address: Some(address), 129 - test: Some(true), 130 - }; 131 - 132 - let config: AppConfig = Figment::new() 133 - .admerge(figment::providers::Toml::file("default.toml")) 134 - .admerge(figment::providers::Env::prefixed("BLUEPDS_")) 135 - .merge(figment::providers::Serialized::defaults(test_config)) 136 - .merge( 137 - figment::providers::Toml::string( 138 - r#" 139 - [firehose] 140 - relays = [] 141 - 142 - [repo] 143 - path = "repo" 144 - 145 - [plc] 146 - path = "plc" 147 - 148 - [blob] 149 - path = "blob" 150 - limit = 10485760 # 10 MB 151 - "#, 152 - ) 153 - .nested(), 154 - ) 155 - .extract()?; 156 - 157 - // Create directories 158 - std::fs::create_dir_all(temp_dir.path().join("repo"))?; 159 - std::fs::create_dir_all(temp_dir.path().join("plc"))?; 160 - std::fs::create_dir_all(temp_dir.path().join("blob"))?; 161 - 162 - // Create client 163 - let client = reqwest::Client::builder() 164 - .timeout(Duration::from_secs(30)) 165 - .build()?; 166 - 167 - Ok(Self { 168 - address, 169 - client, 170 - config, 171 - temp_dir, 172 - }) 173 - } 174 - 175 - /// Start the application in a background task. 
176 - async fn start_app(&self) -> Result<()> { 177 - // // Get a reference to the config that can be moved into the task 178 - // let config = self.config.clone(); 179 - // let address = self.address; 180 - 181 - // // Start the application in a background task 182 - // let _handle = tokio::spawn(async move { 183 - // // Set up the application 184 - // use crate::*; 185 - 186 - // // Initialize metrics (noop in test mode) 187 - // drop(metrics::setup(None)); 188 - 189 - // // Create client 190 - // let simple_client = reqwest::Client::builder() 191 - // .user_agent(APP_USER_AGENT) 192 - // .build() 193 - // .context("failed to build requester client")?; 194 - // let client = reqwest_middleware::ClientBuilder::new(simple_client.clone()) 195 - // .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache { 196 - // mode: CacheMode::Default, 197 - // manager: MokaManager::default(), 198 - // options: HttpCacheOptions::default(), 199 - // })) 200 - // .build(); 201 - 202 - // // Create a test keypair 203 - // std::fs::create_dir_all(config.key.parent().context("should have parent")?)?; 204 - // let (skey, rkey) = { 205 - // let skey = Secp256k1Keypair::create(&mut rand::thread_rng()); 206 - // let rkey = Secp256k1Keypair::create(&mut rand::thread_rng()); 207 - 208 - // let keys = KeyData { 209 - // skey: skey.export(), 210 - // rkey: rkey.export(), 211 - // }; 212 - 213 - // let mut f = 214 - // std::fs::File::create(&config.key).context("failed to create key file")?; 215 - // serde_ipld_dagcbor::to_writer(&mut f, &keys) 216 - // .context("failed to serialize crypto keys")?; 217 - 218 - // (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 219 - // }; 220 - 221 - // // Set up database 222 - // let opts = SqliteConnectOptions::from_str(&config.db) 223 - // .context("failed to parse database options")? 224 - // .create_if_missing(true); 225 - // let db = SqliteDbConn::connect_with(opts).await?; 226 - 227 - // sqlx::migrate!() 228 - // .run(&db) 229 - // .await 230 - // .context("failed to apply migrations")?; 231 - 232 - // // Create firehose 233 - // let (_fh, fhp) = firehose::spawn(client.clone(), config.clone()); 234 - 235 - // // Create the application state 236 - // let app_state = AppState { 237 - // cred: azure_identity::DefaultAzureCredential::new()?, 238 - // config: config.clone(), 239 - // db: db.clone(), 240 - // client: client.clone(), 241 - // simple_client, 242 - // firehose: fhp, 243 - // signing_key: skey, 244 - // rotation_key: rkey, 245 - // }; 246 - 247 - // // Create the router 248 - // let app = Router::new() 249 - // .route("/", get(index)) 250 - // .merge(oauth::routes()) 251 - // .nest( 252 - // "/xrpc", 253 - // endpoints::routes() 254 - // .merge(actor_endpoints::routes()) 255 - // .fallback(service_proxy), 256 - // ) 257 - // .layer(CorsLayer::permissive()) 258 - // .layer(TraceLayer::new_for_http()) 259 - // .with_state(app_state); 260 - 261 - // // Listen for connections 262 - // let listener = TcpListener::bind(&address) 263 - // .await 264 - // .context("failed to bind address")?; 265 - 266 - // axum::serve(listener, app.into_make_service()) 267 - // .await 268 - // .context("failed to serve app") 269 - // }); 270 - 271 - // // Give the server a moment to start 272 - // tokio::time::sleep(Duration::from_millis(500)).await; 273 - 274 - Ok(()) 275 - } 276 - } 277 - 278 - /// A test account that can be used for testing. 279 - pub(crate) struct TestAccount { 280 - /// The access token for the account. 
281 - pub(crate) access_token: String, 282 - /// The account DID. 283 - pub(crate) did: String, 284 - /// The account handle. 285 - pub(crate) handle: String, 286 - /// The refresh token for the account. 287 - #[expect(dead_code)] 288 - pub(crate) refresh_token: String, 289 - } 290 - 291 - /// Initialize the test state. 292 - pub(crate) async fn init_test_state() -> Result<&'static TestState> { 293 - async fn init_test_state() -> std::result::Result<TestState, anyhow::Error> { 294 - let state = TestState::new().await?; 295 - state.start_app().await?; 296 - Ok(state) 297 - } 298 - TEST_STATE.get_or_try_init(init_test_state).await 299 - } 300 - 301 - /// Create a record benchmark that creates records and measures the time it takes. 302 - #[expect( 303 - clippy::arithmetic_side_effects, 304 - clippy::integer_division, 305 - clippy::integer_division_remainder_used, 306 - clippy::use_debug, 307 - clippy::print_stdout 308 - )] 309 - pub(crate) async fn create_record_benchmark(count: usize, concurrent: usize) -> Result<Duration> { 310 - // Initialize the test state 311 - let state = init_test_state().await?; 312 - 313 - // Create a test account 314 - let account = state.create_test_account().await?; 315 - 316 - // Create the client with authorization 317 - let client = reqwest::Client::builder() 318 - .timeout(Duration::from_secs(30)) 319 - .build()?; 320 - 321 - let start = Instant::now(); 322 - 323 - // Split the work into batches 324 - let mut handles = Vec::new(); 325 - for batch_idx in 0..concurrent { 326 - let batch_size = count / concurrent; 327 - let client = client.clone(); 328 - let base_url = state.base_url(); 329 - let account_did = account.did.clone(); 330 - let account_handle = account.handle.clone(); 331 - let access_token = account.access_token.clone(); 332 - 333 - let handle = tokio::spawn(async move { 334 - let mut results = Vec::new(); 335 - 336 - for i in 0..batch_size { 337 - let request_start = Instant::now(); 338 - let record_idx = batch_idx * batch_size + i; 339 - 340 - let result = client 341 - .post(format!("{base_url}/xrpc/com.atproto.repo.createRecord")) 342 - .header("Authorization", format!("Bearer {access_token}")) 343 - .json(&atrium_api::com::atproto::repo::create_record::InputData { 344 - repo: AtIdentifier::Did(Did::new(account_did.clone()).expect("valid DID")), 345 - collection: Nsid::new("app.bsky.feed.post".to_owned()).expect("valid NSID"), 346 - rkey: Some( 347 - RecordKey::new(format!("test-{record_idx}")).expect("valid record key"), 348 - ), 349 - validate: None, 350 - record: serde_json::from_str( 351 - &serde_json::json!({ 352 - "$type": "app.bsky.feed.post", 353 - "text": format!("Test post {record_idx} from {account_handle}"), 354 - "createdAt": chrono::Utc::now().to_rfc3339(), 355 - }) 356 - .to_string(), 357 - ) 358 - .expect("valid JSON record"), 359 - swap_commit: None, 360 - }) 361 - .send() 362 - .await; 363 - 364 - // Fetch the record we just created 365 - let get_response = client 366 - .get(format!( 367 - "{base_url}/xrpc/com.atproto.sync.getRecord?did={account_did}&collection=app.bsky.feed.post&rkey={record_idx}" 368 - )) 369 - .header("Authorization", format!("Bearer {access_token}")) 370 - .send() 371 - .await; 372 - if get_response.is_err() { 373 - println!("Failed to fetch record {record_idx}: {get_response:?}"); 374 - results.push(get_response); 375 - continue; 376 - } 377 - 378 - let request_duration = request_start.elapsed(); 379 - if record_idx % 10 == 0 { 380 - println!("Created record {record_idx} in {request_duration:?}"); 381 - } 
382 - results.push(result); 383 - } 384 - 385 - results 386 - }); 387 - 388 - handles.push(handle); 389 - } 390 - 391 - // Wait for all batches to complete 392 - let results = join_all(handles).await; 393 - 394 - // Check for errors 395 - for batch_result in results { 396 - let batch_responses = batch_result?; 397 - for response_result in batch_responses { 398 - match response_result { 399 - Ok(response) => { 400 - if !response.status().is_success() { 401 - return Err(anyhow::anyhow!( 402 - "Failed to create record: {}", 403 - response.status() 404 - )); 405 - } 406 - } 407 - Err(err) => { 408 - return Err(anyhow::anyhow!("Failed to create record: {}", err)); 409 - } 410 - } 411 - } 412 - } 413 - 414 - let duration = start.elapsed(); 415 - Ok(duration) 416 - } 417 - 418 - #[cfg(test)] 419 - #[expect(clippy::module_inception, clippy::use_debug, clippy::print_stdout)] 420 - mod tests { 421 - use super::*; 422 - use anyhow::anyhow; 423 - 424 - #[tokio::test] 425 - async fn test_create_account() -> Result<()> { 426 - return Ok(()); 427 - #[expect(unreachable_code, reason = "Disabled")] 428 - let state = init_test_state().await?; 429 - let account = state.create_test_account().await?; 430 - 431 - println!("Created test account: {}", account.handle); 432 - if account.handle.is_empty() { 433 - return Err(anyhow::anyhow!("Account handle is empty")); 434 - } 435 - if account.did.is_empty() { 436 - return Err(anyhow::anyhow!("Account DID is empty")); 437 - } 438 - if account.access_token.is_empty() { 439 - return Err(anyhow::anyhow!("Account access token is empty")); 440 - } 441 - 442 - Ok(()) 443 - } 444 - 445 - #[tokio::test] 446 - async fn test_create_record_benchmark() -> Result<()> { 447 - return Ok(()); 448 - #[expect(unreachable_code, reason = "Disabled")] 449 - let duration = create_record_benchmark(100, 1).await?; 450 - 451 - println!("Created 100 records in {duration:?}"); 452 - 453 - if duration.as_secs() >= 10 { 454 - return Err(anyhow!("Benchmark took too long")); 455 - } 456 - 457 - Ok(()) 458 - } 459 - }
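The deleted `src/tests.rs` harness spun the pre-fork application state up in-process (note the commented-out `azure_identity::DefaultAzureCredential` and `firehose::spawn` wiring), none of which matches the new `serve::run()` entry point. A minimal hedged sketch of a replacement smoke test, assuming an instance is already listening on the default `127.0.0.1:8000` address; `describe_server_smoke` is illustrative and not part of this patch:

```
/// Sketch of a replacement smoke test; assumes a running instance,
/// hence `#[ignore]` by default.
#[tokio::test]
#[ignore = "requires a running PDS on 127.0.0.1:8000"]
async fn describe_server_smoke() -> anyhow::Result<()> {
    let body: serde_json::Value =
        reqwest::get("http://127.0.0.1:8000/xrpc/com.atproto.server.describeServer")
            .await?
            .error_for_status()?
            .json()
            .await?;
    // `availableUserDomains` is a required field of the describeServer output.
    assert!(body.get("availableUserDomains").is_some());
    Ok(())
}
```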