Alternative ATProto PDS implementation

Compare changes

Choose any two refs to compare.

Changed files
+6915 -6414
.nix
.sqlx
migrations
src
+3 -2
.nix/flake.nix
··· 26 26 git 27 27 nixd 28 28 direnv 29 + libpq 29 30 ]; 30 31 overlays = [ (import rust-overlay) ]; 31 32 pkgs = import nixpkgs { ··· 41 42 nativeBuildInputs = with pkgs; [ rust pkg-config ]; 42 43 in 43 44 with pkgs; 44 - { 45 + { 45 46 devShells.default = mkShell { 46 47 inherit buildInputs nativeBuildInputs; 47 48 LD_LIBRARY_PATH = nixpkgs.legacyPackages.x86_64-linux.lib.makeLibraryPath buildInputs; ··· 49 50 DATABASE_URL = "sqlite://data/sqlite.db"; 50 51 }; 51 52 }); 52 - } 53 + }
-20
.sqlx/query-02a5737bb92665ef0a3dac013eb03366ab6b31a5c4ab856e6458a52704b86e23.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT COUNT(*) FROM oauth_used_jtis WHERE jti = ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "COUNT(*)", 8 - "ordinal": 0, 9 - "type_info": "Integer" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 1 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "02a5737bb92665ef0a3dac013eb03366ab6b31a5c4ab856e6458a52704b86e23" 20 - }
-12
.sqlx/query-19dc08b9f2f609e0610b6bd1e4908fc5d7922cc95b13de3214a055bf36b80284.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n INSERT INTO invites (id, did, count, created_at)\n VALUES (?, NULL, 1, datetime('now'))\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 1 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "19dc08b9f2f609e0610b6bd1e4908fc5d7922cc95b13de3214a055bf36b80284" 12 - }
-20
.sqlx/query-1db52857493a1e8a7004872eaff6e8fe5dec41579dd57d696008385b8d23788d.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT data FROM blocks WHERE cid = ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "data", 8 - "ordinal": 0, 9 - "type_info": "Blob" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 1 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "1db52857493a1e8a7004872eaff6e8fe5dec41579dd57d696008385b8d23788d" 20 - }
-20
.sqlx/query-22c1e98ac038509ad16ce437e6670a59d3fc97a05ea8b0f1f80dba0157c53e13.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT name FROM actor_migration", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "name", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 0 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "22c1e98ac038509ad16ce437e6670a59d3fc97a05ea8b0f1f80dba0157c53e13" 20 - }
-62
.sqlx/query-243e2127a5181657d5e08c981a7a6d395fb2112ebf7a1a676d57c33866310add.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n SELECT * FROM oauth_refresh_tokens\n WHERE token = ? AND client_id = ? AND expires_at > ? AND revoked = FALSE AND dpop_thumbprint = ?\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "token", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "client_id", 13 - "ordinal": 1, 14 - "type_info": "Text" 15 - }, 16 - { 17 - "name": "subject", 18 - "ordinal": 2, 19 - "type_info": "Text" 20 - }, 21 - { 22 - "name": "dpop_thumbprint", 23 - "ordinal": 3, 24 - "type_info": "Text" 25 - }, 26 - { 27 - "name": "scope", 28 - "ordinal": 4, 29 - "type_info": "Text" 30 - }, 31 - { 32 - "name": "created_at", 33 - "ordinal": 5, 34 - "type_info": "Integer" 35 - }, 36 - { 37 - "name": "expires_at", 38 - "ordinal": 6, 39 - "type_info": "Integer" 40 - }, 41 - { 42 - "name": "revoked", 43 - "ordinal": 7, 44 - "type_info": "Bool" 45 - } 46 - ], 47 - "parameters": { 48 - "Right": 4 49 - }, 50 - "nullable": [ 51 - false, 52 - false, 53 - false, 54 - false, 55 - true, 56 - false, 57 - false, 58 - false 59 - ] 60 - }, 61 - "hash": "243e2127a5181657d5e08c981a7a6d395fb2112ebf7a1a676d57c33866310add" 62 - }
-12
.sqlx/query-2918ecf03675a789568c777904966911ca63e991dede42a2d7d87e174799ea46.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "INSERT INTO blocks (cid, data, multicodec, multihash) VALUES (?, ?, ?, ?)", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 4 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "2918ecf03675a789568c777904966911ca63e991dede42a2d7d87e174799ea46" 12 - }
-20
.sqlx/query-2e13e052dfc64f29d9da1bce2bf844cbb918ad3bb01e386801d3b0d3be246573.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT COUNT(*) FROM oauth_refresh_tokens WHERE dpop_thumbprint = ? AND client_id = ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "COUNT(*)", 8 - "ordinal": 0, 9 - "type_info": "Integer" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 2 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "2e13e052dfc64f29d9da1bce2bf844cbb918ad3bb01e386801d3b0d3be246573" 20 - }
-32
.sqlx/query-3516a6de0f3aa40b301d60479f5c34d0fd21a800328a05458ecc3ac688d016e6.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n SELECT a.email, a.status, (\n SELECT h.handle\n FROM handles h\n WHERE h.did = a.did\n ORDER BY h.created_at ASC\n LIMIT 1\n ) AS handle\n FROM accounts a\n WHERE a.did = ?\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "email", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "status", 13 - "ordinal": 1, 14 - "type_info": "Text" 15 - }, 16 - { 17 - "name": "handle", 18 - "ordinal": 2, 19 - "type_info": "Text" 20 - } 21 - ], 22 - "parameters": { 23 - "Right": 1 24 - }, 25 - "nullable": [ 26 - false, 27 - false, 28 - false 29 - ] 30 - }, 31 - "hash": "3516a6de0f3aa40b301d60479f5c34d0fd21a800328a05458ecc3ac688d016e6" 32 - }
-20
.sqlx/query-3b4745208f268678a84401e522c3836e0632ca34a0f23bbae5297d076610f0ab.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT content FROM repo_block WHERE cid = ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "content", 8 - "ordinal": 0, 9 - "type_info": "Blob" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 1 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "3b4745208f268678a84401e522c3836e0632ca34a0f23bbae5297d076610f0ab" 20 - }
-20
.sqlx/query-3d1a877177899665c37393beae31a399054b7c02d3871c6c5d317923fec8442e.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT did FROM handles WHERE handle = ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "did", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 1 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "3d1a877177899665c37393beae31a399054b7c02d3871c6c5d317923fec8442e" 20 - }
-20
.sqlx/query-4198b96804f3a0a805e441857b452e84a083d80dca12ce95c545dc9eadbac0c3.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT plc_root FROM accounts WHERE did = ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "plc_root", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 1 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "4198b96804f3a0a805e441857b452e84a083d80dca12ce95c545dc9eadbac0c3" 20 - }
-12
.sqlx/query-459be26080e3497b3807d22e86377eee9e19366709864e3369c867cef01c83bb.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n INSERT INTO repo_block (cid, repoRev, size, content)\n VALUES (?, ?, ?, ?)\n ON CONFLICT DO NOTHING\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 4 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "459be26080e3497b3807d22e86377eee9e19366709864e3369c867cef01c83bb" 12 - }
-26
.sqlx/query-50a7b5f57df41d06a8c11c8268d8dbef4c76bcf92c6b47b6316bf5e39fb889a7.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n SELECT a.status, h.handle\n FROM accounts a\n JOIN handles h ON a.did = h.did\n WHERE a.did = ?\n ORDER BY h.created_at ASC\n LIMIT 1\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "status", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "handle", 13 - "ordinal": 1, 14 - "type_info": "Text" 15 - } 16 - ], 17 - "parameters": { 18 - "Right": 1 19 - }, 20 - "nullable": [ 21 - false, 22 - false 23 - ] 24 - }, 25 - "hash": "50a7b5f57df41d06a8c11c8268d8dbef4c76bcf92c6b47b6316bf5e39fb889a7" 26 - }
-12
.sqlx/query-51f7f9d5bf4cbfe372a8fa130f4cabcb57766638792d61297df2fb91c2fe2937.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n INSERT INTO repo_root (did, cid, rev, indexedAt)\n VALUES (?, ?, ?, ?)\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 4 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "51f7f9d5bf4cbfe372a8fa130f4cabcb57766638792d61297df2fb91c2fe2937" 12 - }
-12
.sqlx/query-5bbf8300ca519576e4f60074cf16756bc1dca79f43e1e89c5a08b8c9d95d241f.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n INSERT INTO repo_block (cid, repoRev, size, content)\n VALUES (?, ?, ?, ?)\n ON CONFLICT DO NOTHING\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 4 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "5bbf8300ca519576e4f60074cf16756bc1dca79f43e1e89c5a08b8c9d95d241f" 12 - }
-12
.sqlx/query-5d4586821dff3ed0fd1e352946751c3bb66610a472d8c42a7bfa3a565fccc30a.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n INSERT INTO oauth_authorization_codes (\n code, client_id, subject, code_challenge, code_challenge_method,\n redirect_uri, scope, created_at, expires_at, used\n ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 10 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "5d4586821dff3ed0fd1e352946751c3bb66610a472d8c42a7bfa3a565fccc30a" 12 - }
-12
.sqlx/query-5ea8376fbbe3077b2fc62187cc29a2d03eda91fa468c7fe63306f04e160ecb5d.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "INSERT INTO actor_migration (name, appliedAt) VALUES (?, ?)", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 2 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "5ea8376fbbe3077b2fc62187cc29a2d03eda91fa468c7fe63306f04e160ecb5d" 12 - }
-26
.sqlx/query-5f17a390750b52886f8c3ba80cb16776f3430bc91c4158aafb3012a7812a97cc.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT rev, status FROM accounts WHERE did = ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "rev", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "status", 13 - "ordinal": 1, 14 - "type_info": "Text" 15 - } 16 - ], 17 - "parameters": { 18 - "Right": 1 19 - }, 20 - "nullable": [ 21 - false, 22 - false 23 - ] 24 - }, 25 - "hash": "5f17a390750b52886f8c3ba80cb16776f3430bc91c4158aafb3012a7812a97cc" 26 - }
-32
.sqlx/query-6b0a871527c5c37663ee17ec6f5ec4f97521900f45e549b0b065004a4e2e6207.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n WITH LatestHandles AS (\n SELECT did, handle\n FROM handles\n WHERE (did, created_at) IN (\n SELECT did, MAX(created_at) AS max_created_at\n FROM handles\n GROUP BY did\n )\n )\n SELECT a.did, a.password, h.handle\n FROM accounts a\n LEFT JOIN LatestHandles h ON a.did = h.did\n WHERE h.handle = ?\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "did", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "password", 13 - "ordinal": 1, 14 - "type_info": "Text" 15 - }, 16 - { 17 - "name": "handle", 18 - "ordinal": 2, 19 - "type_info": "Text" 20 - } 21 - ], 22 - "parameters": { 23 - "Right": 1 24 - }, 25 - "nullable": [ 26 - false, 27 - false, 28 - false 29 - ] 30 - }, 31 - "hash": "6b0a871527c5c37663ee17ec6f5ec4f97521900f45e549b0b065004a4e2e6207" 32 - }
-20
.sqlx/query-73fd3e30b7694c92cf9309751d186fe622fa7d99fdf56dde7e60c3696581116c.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT COUNT(*) FROM blocks WHERE cid = ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "COUNT(*)", 8 - "ordinal": 0, 9 - "type_info": "Integer" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 1 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "73fd3e30b7694c92cf9309751d186fe622fa7d99fdf56dde7e60c3696581116c" 20 - }
-32
.sqlx/query-7eb22fdfc107b33361c599fcd4ae3a4a4fafef8438c41e1fdc6d4f7fd44f1094.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT did, root, rev FROM accounts LIMIT ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "did", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "root", 13 - "ordinal": 1, 14 - "type_info": "Text" 15 - }, 16 - { 17 - "name": "rev", 18 - "ordinal": 2, 19 - "type_info": "Text" 20 - } 21 - ], 22 - "parameters": { 23 - "Right": 1 24 - }, 25 - "nullable": [ 26 - false, 27 - false, 28 - false 29 - ] 30 - }, 31 - "hash": "7eb22fdfc107b33361c599fcd4ae3a4a4fafef8438c41e1fdc6d4f7fd44f1094" 32 - }
-20
.sqlx/query-813409fb7218c548ee3e8b1226559686cd40aa81ac1b68659b087276cbb0137d.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT cid FROM blob_ref WHERE did = ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "cid", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 1 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "813409fb7218c548ee3e8b1226559686cd40aa81ac1b68659b087276cbb0137d" 20 - }
-20
.sqlx/query-865f757ca7c8b15357622bf0d1a25745288f87ad6ace019c1f4316a4ba1efb34.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT revoked FROM oauth_refresh_tokens WHERE token = ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "revoked", 8 - "ordinal": 0, 9 - "type_info": "Bool" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 1 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "865f757ca7c8b15357622bf0d1a25745288f87ad6ace019c1f4316a4ba1efb34" 20 - }
-12
.sqlx/query-87cbc4f5bb615163ff62234e0de0c69b543179cffcdaf79fcae5fd6fdc7e14c7.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "UPDATE oauth_refresh_tokens SET revoked = TRUE WHERE token = ?", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 1 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "87cbc4f5bb615163ff62234e0de0c69b543179cffcdaf79fcae5fd6fdc7e14c7" 12 - }
-74
.sqlx/query-92858ad9b0a35c3b8d4be795f88325aa4a1995f53fc90ef455ef9a499335f088.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n SELECT * FROM oauth_authorization_codes\n WHERE code = ? AND client_id = ? AND redirect_uri = ? AND expires_at > ? AND used = FALSE\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "code", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "client_id", 13 - "ordinal": 1, 14 - "type_info": "Text" 15 - }, 16 - { 17 - "name": "subject", 18 - "ordinal": 2, 19 - "type_info": "Text" 20 - }, 21 - { 22 - "name": "code_challenge", 23 - "ordinal": 3, 24 - "type_info": "Text" 25 - }, 26 - { 27 - "name": "code_challenge_method", 28 - "ordinal": 4, 29 - "type_info": "Text" 30 - }, 31 - { 32 - "name": "redirect_uri", 33 - "ordinal": 5, 34 - "type_info": "Text" 35 - }, 36 - { 37 - "name": "scope", 38 - "ordinal": 6, 39 - "type_info": "Text" 40 - }, 41 - { 42 - "name": "created_at", 43 - "ordinal": 7, 44 - "type_info": "Integer" 45 - }, 46 - { 47 - "name": "expires_at", 48 - "ordinal": 8, 49 - "type_info": "Integer" 50 - }, 51 - { 52 - "name": "used", 53 - "ordinal": 9, 54 - "type_info": "Bool" 55 - } 56 - ], 57 - "parameters": { 58 - "Right": 4 59 - }, 60 - "nullable": [ 61 - false, 62 - false, 63 - false, 64 - false, 65 - false, 66 - false, 67 - true, 68 - false, 69 - false, 70 - false 71 - ] 72 - }, 73 - "hash": "92858ad9b0a35c3b8d4be795f88325aa4a1995f53fc90ef455ef9a499335f088" 74 - }
-26
.sqlx/query-9890e97761e6ed1256ed32775ad4f394e199b5a3588a711ea8ad672cf666eee4.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT cid, rev FROM repo_root WHERE did = ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "cid", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "rev", 13 - "ordinal": 1, 14 - "type_info": "Text" 15 - } 16 - ], 17 - "parameters": { 18 - "Right": 1 19 - }, 20 - "nullable": [ 21 - false, 22 - false 23 - ] 24 - }, 25 - "hash": "9890e97761e6ed1256ed32775ad4f394e199b5a3588a711ea8ad672cf666eee4" 26 - }
-12
.sqlx/query-9a04bdf627ee146ddaac6cdd1bacf2106b22bc215ef22ab400cd62b4353f414b.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "UPDATE accounts SET private_prefs = ? WHERE did = ?", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 2 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "9a04bdf627ee146ddaac6cdd1bacf2106b22bc215ef22ab400cd62b4353f414b" 12 - }
-26
.sqlx/query-9b6ac33211a2231754650bb0daca5ffb980c9e530ea47dd892aa06fab1450a05.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n SELECT cid, content\n FROM repo_block\n WHERE repoRev = ?\n LIMIT 15\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "cid", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "content", 13 - "ordinal": 1, 14 - "type_info": "Blob" 15 - } 16 - ], 17 - "parameters": { 18 - "Right": 1 19 - }, 20 - "nullable": [ 21 - false, 22 - false 23 - ] 24 - }, 25 - "hash": "9b6ac33211a2231754650bb0daca5ffb980c9e530ea47dd892aa06fab1450a05" 26 - }
-38
.sqlx/query-a16bb62753f6568238cab50d3a597d279db5564d3bcc1f8606850d5442aaf20a.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n WITH LatestHandles AS (\n SELECT did, handle\n FROM handles\n WHERE (did, created_at) IN (\n SELECT did, MAX(created_at) AS max_created_at\n FROM handles\n GROUP BY did\n )\n )\n SELECT a.did, a.email, a.password, h.handle\n FROM accounts a\n LEFT JOIN LatestHandles h ON a.did = h.did\n WHERE h.handle = ?\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "did", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "email", 13 - "ordinal": 1, 14 - "type_info": "Text" 15 - }, 16 - { 17 - "name": "password", 18 - "ordinal": 2, 19 - "type_info": "Text" 20 - }, 21 - { 22 - "name": "handle", 23 - "ordinal": 3, 24 - "type_info": "Text" 25 - } 26 - ], 27 - "parameters": { 28 - "Right": 1 29 - }, 30 - "nullable": [ 31 - false, 32 - false, 33 - false, 34 - false 35 - ] 36 - }, 37 - "hash": "a16bb62753f6568238cab50d3a597d279db5564d3bcc1f8606850d5442aaf20a" 38 - }
-12
.sqlx/query-a527a1863a9a2f5ba129c1f5ee9d0cdc78e0c69de43c7da1f9a936222c17c4bf.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n INSERT INTO accounts (did, email, password, root, plc_root, rev, created_at)\n VALUES (?, ?, ?, ?, ?, ?, datetime('now'));\n\n INSERT INTO handles (did, handle, created_at)\n VALUES (?, ?, datetime('now'));\n\n -- Cleanup stale invite codes\n DELETE FROM invites\n WHERE count <= 0;\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 8 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "a527a1863a9a2f5ba129c1f5ee9d0cdc78e0c69de43c7da1f9a936222c17c4bf" 12 - }
-12
.sqlx/query-a9fbd43dbd50907f550a2221dab552ff5a00d7f00d7223b4cee745354f77c532.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n UPDATE repo_root\n SET cid = ?, rev = ?, indexedAt = ?\n WHERE did = ?\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 4 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "a9fbd43dbd50907f550a2221dab552ff5a00d7f00d7223b4cee745354f77c532" 12 - }
-92
.sqlx/query-b4e6da72ee82515d2ff739c805e1c0ccb837d06c62d338dd782a3ea375f7eee3.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n SELECT * FROM oauth_par_requests\n WHERE request_uri = ? AND client_id = ? AND expires_at > ?\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "request_uri", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "client_id", 13 - "ordinal": 1, 14 - "type_info": "Text" 15 - }, 16 - { 17 - "name": "response_type", 18 - "ordinal": 2, 19 - "type_info": "Text" 20 - }, 21 - { 22 - "name": "code_challenge", 23 - "ordinal": 3, 24 - "type_info": "Text" 25 - }, 26 - { 27 - "name": "code_challenge_method", 28 - "ordinal": 4, 29 - "type_info": "Text" 30 - }, 31 - { 32 - "name": "state", 33 - "ordinal": 5, 34 - "type_info": "Text" 35 - }, 36 - { 37 - "name": "login_hint", 38 - "ordinal": 6, 39 - "type_info": "Text" 40 - }, 41 - { 42 - "name": "scope", 43 - "ordinal": 7, 44 - "type_info": "Text" 45 - }, 46 - { 47 - "name": "redirect_uri", 48 - "ordinal": 8, 49 - "type_info": "Text" 50 - }, 51 - { 52 - "name": "response_mode", 53 - "ordinal": 9, 54 - "type_info": "Text" 55 - }, 56 - { 57 - "name": "display", 58 - "ordinal": 10, 59 - "type_info": "Text" 60 - }, 61 - { 62 - "name": "created_at", 63 - "ordinal": 11, 64 - "type_info": "Integer" 65 - }, 66 - { 67 - "name": "expires_at", 68 - "ordinal": 12, 69 - "type_info": "Integer" 70 - } 71 - ], 72 - "parameters": { 73 - "Right": 3 74 - }, 75 - "nullable": [ 76 - false, 77 - false, 78 - false, 79 - false, 80 - false, 81 - true, 82 - true, 83 - true, 84 - true, 85 - true, 86 - true, 87 - false, 88 - false 89 - ] 90 - }, 91 - "hash": "b4e6da72ee82515d2ff739c805e1c0ccb837d06c62d338dd782a3ea375f7eee3" 92 - }
-12
.sqlx/query-bcef1b9aeaf0db7ac4b2e8f4b3ec40b425e48af26cf91496208c04e31239f7c6.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "DELETE FROM oauth_used_jtis WHERE expires_at < ?", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 1 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "bcef1b9aeaf0db7ac4b2e8f4b3ec40b425e48af26cf91496208c04e31239f7c6" 12 - }
-12
.sqlx/query-c51b4c9de70b5be51a6e0a5fd744387ae804e8ba978b61c4d04d74b1f8de2614.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "UPDATE oauth_refresh_tokens SET revoked = TRUE\n WHERE client_id = ? AND dpop_thumbprint = ?", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 2 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "c51b4c9de70b5be51a6e0a5fd744387ae804e8ba978b61c4d04d74b1f8de2614" 12 - }
-20
.sqlx/query-cc1c5a90cfd95024cb03fe579941f296b1ac1230cce5819ae9f6eb03c8b19398.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n SELECT\n (SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites)\n AS total_count\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "total_count", 8 - "ordinal": 0, 9 - "type_info": "Integer" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 0 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "cc1c5a90cfd95024cb03fe579941f296b1ac1230cce5819ae9f6eb03c8b19398" 20 - }
-12
.sqlx/query-cd91f7a134089bb77cac221a9bcc489b6d6860123f755c1ee2068e32dc687301.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n INSERT INTO oauth_refresh_tokens (\n token, client_id, subject, dpop_thumbprint, scope, created_at, expires_at, revoked\n ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 8 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "cd91f7a134089bb77cac221a9bcc489b6d6860123f755c1ee2068e32dc687301" 12 - }
-12
.sqlx/query-d1408c77d790337a265891b5502a59a62a5d1d01e787dea74b753b1fab794b3a.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n INSERT INTO oauth_par_requests (\n request_uri, client_id, response_type, code_challenge, code_challenge_method,\n state, login_hint, scope, redirect_uri, response_mode, display,\n created_at, expires_at\n ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 13 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "d1408c77d790337a265891b5502a59a62a5d1d01e787dea74b753b1fab794b3a" 12 - }
-26
.sqlx/query-d1c3ea6ebc19b0362851ebd0b8c8a0b9c87d5cddf4f03670636d29ba5ceb9435.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n SELECT cid, rev\n FROM repo_root\n WHERE did = ?\n LIMIT 1\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "cid", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "rev", 13 - "ordinal": 1, 14 - "type_info": "Text" 15 - } 16 - ], 17 - "parameters": { 18 - "Right": 1 19 - }, 20 - "nullable": [ 21 - false, 22 - false 23 - ] 24 - }, 25 - "hash": "d1c3ea6ebc19b0362851ebd0b8c8a0b9c87d5cddf4f03670636d29ba5ceb9435" 26 - }
-12
.sqlx/query-d39b83ec2f091556e6fb5e4d729b8e6fa1cc966855f934e2b1611d8a26614849.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "UPDATE accounts SET plc_root = ? WHERE did = ?", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 2 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "d39b83ec2f091556e6fb5e4d729b8e6fa1cc966855f934e2b1611d8a26614849" 12 - }
-12
.sqlx/query-d6ddbce18d6a78a78e8713a0f0b1499517aae7ab9f49744a4cf8a722e03f82fa.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n INSERT INTO oauth_used_jtis (jti, issuer, created_at, expires_at)\n VALUES (?, ?, ?, ?)\n ", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 4 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "d6ddbce18d6a78a78e8713a0f0b1499517aae7ab9f49744a4cf8a722e03f82fa" 12 - }
-20
.sqlx/query-dbedb512e10704bc9f0e571314ff68724edf10b76a62071bd1ef04a68c708890.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n INSERT INTO invites (id, did, count, created_at)\n VALUES (?, ?, ?, datetime('now'))\n RETURNING id\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "id", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 3 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "dbedb512e10704bc9f0e571314ff68724edf10b76a62071bd1ef04a68c708890" 20 - }
-20
.sqlx/query-dc444d99848fff3578add45fb464004c0797ef7d455652cb92f2c7de8a7f8cc4.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT status FROM accounts WHERE did = ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "status", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 1 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "dc444d99848fff3578add45fb464004c0797ef7d455652cb92f2c7de8a7f8cc4" 20 - }
-20
.sqlx/query-e26b7c36a34130e350f3f3e06b3200c56a0e3330ac0b658de6bbdb39b5497fab.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n UPDATE invites\n SET count = count - 1\n WHERE id = ?\n AND count > 0\n RETURNING id\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "id", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 1 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "e26b7c36a34130e350f3f3e06b3200c56a0e3330ac0b658de6bbdb39b5497fab" 20 - }
-12
.sqlx/query-e4bd80a305f929229b234b79b1e9e90a36af0e630c8c7530b6d935c6e32d381f.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "UPDATE oauth_authorization_codes SET used = TRUE WHERE code = ?", 4 - "describe": { 5 - "columns": [], 6 - "parameters": { 7 - "Right": 1 8 - }, 9 - "nullable": [] 10 - }, 11 - "hash": "e4bd80a305f929229b234b79b1e9e90a36af0e630c8c7530b6d935c6e32d381f" 12 - }
-20
.sqlx/query-e6007f29d6b7681d7a1f5029d1bf635250ac4449494b925e67735513edfcbdb3.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "\n SELECT root FROM accounts\n WHERE did = ?\n ", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "root", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - } 11 - ], 12 - "parameters": { 13 - "Right": 1 14 - }, 15 - "nullable": [ 16 - false 17 - ] 18 - }, 19 - "hash": "e6007f29d6b7681d7a1f5029d1bf635250ac4449494b925e67735513edfcbdb3" 20 - }
-32
.sqlx/query-fdd74b27ee260f2cc6fa9102f5c216b86436bb6ccf9bf707118c12b0bd393922.json
··· 1 - { 2 - "db_name": "SQLite", 3 - "query": "SELECT did, root, rev FROM accounts WHERE did > ? LIMIT ?", 4 - "describe": { 5 - "columns": [ 6 - { 7 - "name": "did", 8 - "ordinal": 0, 9 - "type_info": "Text" 10 - }, 11 - { 12 - "name": "root", 13 - "ordinal": 1, 14 - "type_info": "Text" 15 - }, 16 - { 17 - "name": "rev", 18 - "ordinal": 2, 19 - "type_info": "Text" 20 - } 21 - ], 22 - "parameters": { 23 - "Right": 2 24 - }, 25 - "nullable": [ 26 - false, 27 - false, 28 - false 29 - ] 30 - }, 31 - "hash": "fdd74b27ee260f2cc6fa9102f5c216b86436bb6ccf9bf707118c12b0bd393922" 32 - }
+2 -22
Cargo.lock
··· 1282 1282 dependencies = [ 1283 1283 "anyhow", 1284 1284 "argon2", 1285 - "async-trait", 1286 1285 "atrium-api 0.25.3", 1287 1286 "atrium-crypto", 1288 1287 "atrium-repo", 1289 - "atrium-xrpc", 1290 - "atrium-xrpc-client", 1291 1288 "axum", 1292 1289 "azure_core", 1293 1290 "azure_identity", ··· 1306 1303 "futures", 1307 1304 "hex", 1308 1305 "http-cache-reqwest", 1309 - "ipld-core", 1310 - "k256", 1311 - "lazy_static", 1312 1306 "memmap2", 1313 1307 "metrics", 1314 1308 "metrics-exporter-prometheus", 1315 - "multihash 0.19.3", 1316 - "r2d2", 1317 1309 "rand 0.8.5", 1318 - "regex", 1319 1310 "reqwest 0.12.15", 1320 1311 "reqwest-middleware", 1321 1312 "rsky-common", 1313 + "rsky-identity", 1322 1314 "rsky-lexicon", 1323 1315 "rsky-pds", 1324 1316 "rsky-repo", 1325 1317 "rsky-syntax", 1326 1318 "secp256k1", 1327 1319 "serde", 1328 - "serde_bytes", 1329 1320 "serde_ipld_dagcbor", 1330 - "serde_ipld_dagjson", 1331 1321 "serde_json", 1332 1322 "sha2", 1333 1323 "thiserror 2.0.12", ··· 1336 1326 "tower-http", 1337 1327 "tracing", 1338 1328 "tracing-subscriber", 1329 + "ubyte", 1339 1330 "url", 1340 1331 "urlencoding", 1341 1332 "uuid 1.16.0", ··· 5839 5830 "ipld-core", 5840 5831 "scopeguard", 5841 5832 "serde", 5842 - ] 5843 - 5844 - [[package]] 5845 - name = "serde_ipld_dagjson" 5846 - version = "0.2.0" 5847 - source = "registry+https://github.com/rust-lang/crates.io-index" 5848 - checksum = "3359b47ba7f4a306ef5984665e10539e212e97217afa489437d533208eecda36" 5849 - dependencies = [ 5850 - "ipld-core", 5851 - "serde", 5852 - "serde_json", 5853 5833 ] 5854 5834 5855 5835 [[package]]
+26 -18
Cargo.toml
··· 1 + # cargo-features = ["codegen-backend"] 2 + 1 3 [package] 2 4 name = "bluepds" 3 5 version = "0.0.0" ··· 13 15 14 16 [profile.dev.package."*"] 15 17 opt-level = 3 18 + # codegen-backend = "cranelift" 16 19 17 20 [profile.dev] 18 21 opt-level = 1 22 + # codegen-backend = "cranelift" 19 23 20 24 [profile.release] 21 25 opt-level = "s" # Slightly slows compile times, great improvements to file size and runtime performance. ··· 36 40 rust-2021-compatibility = { level = "warn", priority = -1 } # Lints used to transition code from the 2018 edition to 2021 37 41 rust-2018-idioms = { level = "warn", priority = -1 } # Lints to nudge you toward idiomatic features of Rust 2018 38 42 rust-2024-compatibility = { level = "warn", priority = -1 } # Lints used to transition code from the 2021 edition to 2024 39 - unused = { level = "warn", priority = -1 } # Lints that detect things being declared but not used, or excess syntax 43 + # unused = { level = "warn", priority = -1 } # Lints that detect things being declared but not used, or excess syntax 40 44 ## Individual 41 45 ambiguous_negative_literals = "warn" # checks for cases that are confusing between a negative literal and a negation that's not part of the literal. 42 46 closure_returning_async_block = "warn" # detects cases where users write a closure that returns an async block. # nightly ··· 62 66 unit_bindings = "warn" 63 67 unnameable_types = "warn" 64 68 # unqualified_local_imports = "warn" # unstable 65 - unreachable_pub = "warn" 69 + # unreachable_pub = "warn" 66 70 unsafe_code = "warn" 67 71 unstable_features = "warn" 68 72 # unused_crate_dependencies = "warn" ··· 73 77 variant_size_differences = "warn" 74 78 elided_lifetimes_in_paths = "allow" 75 79 # unstable-features = "allow" 80 + # # Temporary Allows 81 + dead_code = "allow" 82 + # unused_imports = "allow" 76 83 77 84 [lints.clippy] 78 85 # Groups 79 86 nursery = { level = "warn", priority = -1 } 80 87 correctness = { level = "warn", priority = -1 } 81 88 suspicious = { level = "warn", priority = -1 } 82 - complexity = { level = "warn", priority = -1 } 83 - perf = { level = "warn", priority = -1 } 84 - style = { level = "warn", priority = -1 } 85 - pedantic = { level = "warn", priority = -1 } 86 - restriction = { level = "warn", priority = -1 } 89 + # complexity = { level = "warn", priority = -1 } 90 + # perf = { level = "warn", priority = -1 } 91 + # style = { level = "warn", priority = -1 } 92 + # pedantic = { level = "warn", priority = -1 } 93 + # restriction = { level = "warn", priority = -1 } 87 94 cargo = { level = "warn", priority = -1 } 88 95 # Temporary Allows 89 96 multiple_crate_versions = "allow" # triggered by lib ··· 128 135 # expect_used = "deny" 129 136 130 137 [dependencies] 131 - multihash = "0.19.3" 138 + # multihash = "0.19.3" 132 139 diesel = { version = "2.1.5", features = [ 133 140 "chrono", 134 141 "sqlite", ··· 136 143 "returning_clauses_for_sqlite_3_35", 137 144 ] } 138 145 diesel_migrations = { version = "2.1.0" } 139 - r2d2 = "0.8.10" 146 + # r2d2 = "0.8.10" 140 147 141 148 atrium-repo = "0.1" 142 149 atrium-api = "0.25" 143 150 # atrium-common = { version = "0.1.2", path = "atrium-common" } 144 151 atrium-crypto = "0.1" 145 152 # atrium-identity = { version = "0.1.4", path = "atrium-identity" } 146 - atrium-xrpc = "0.12" 147 - atrium-xrpc-client = "0.5" 153 + # atrium-xrpc = "0.12" 154 + # atrium-xrpc-client = "0.5" 148 155 # bsky-sdk = { version = "0.1.19", path = "bsky-sdk" } 149 156 rsky-syntax = { git = "https://github.com/blacksky-algorithms/rsky.git" } 150 157 rsky-repo = { git = "https://github.com/blacksky-algorithms/rsky.git" } 151 158 rsky-pds = { git = "https://github.com/blacksky-algorithms/rsky.git" } 152 159 rsky-common = { git = "https://github.com/blacksky-algorithms/rsky.git" } 153 160 rsky-lexicon = { git = "https://github.com/blacksky-algorithms/rsky.git" } 161 + rsky-identity = { git = "https://github.com/blacksky-algorithms/rsky.git" } 154 162 155 163 # async in streams 156 164 # async-stream = "0.3" 157 165 158 166 # DAG-CBOR codec 159 - ipld-core = "0.4.2" 167 + # ipld-core = "0.4.2" 160 168 serde_ipld_dagcbor = { version = "0.6.2", default-features = false, features = [ 161 169 "std", 162 170 ] } 163 - serde_ipld_dagjson = "0.2.0" 171 + # serde_ipld_dagjson = "0.2.0" 164 172 cidv10 = { version = "0.10.1", package = "cid" } 165 173 166 174 # Parsing and validation ··· 169 177 hex = "0.4.3" 170 178 # langtag = "0.3" 171 179 # multibase = "0.9.1" 172 - regex = "1.11.1" 180 + # regex = "1.11.1" 173 181 serde = { version = "1.0.218", features = ["derive"] } 174 - serde_bytes = "0.11.17" 182 + # serde_bytes = "0.11.17" 175 183 # serde_html_form = "0.2.6" 176 184 serde_json = "1.0.139" 177 185 # unsigned-varint = "0.8" ··· 181 189 # elliptic-curve = "0.13.6" 182 190 # jose-jwa = "0.1.2" 183 191 # jose-jwk = { version = "0.1.2", default-features = false } 184 - k256 = "0.13.4" 192 + # k256 = "0.13.4" 185 193 # p256 = { version = "0.13.2", default-features = false } 186 194 rand = "0.8.5" 187 195 sha2 = "0.10.8" ··· 253 261 url = "2.5.4" 254 262 uuid = { version = "1.14.0", features = ["v4"] } 255 263 urlencoding = "2.1.3" 256 - async-trait = "0.1.88" 257 - lazy_static = "1.5.0" 264 + # lazy_static = "1.5.0" 258 265 secp256k1 = "0.28.2" 259 266 dotenvy = "0.15.7" 260 267 deadpool-diesel = { version = "0.6.1", features = [ ··· 262 269 "sqlite", 263 270 "tracing", 264 271 ] } 272 + ubyte = "0.10.4"
+31 -118
README.md
··· 11 11 \/_/ 12 12 ``` 13 13 14 - This is an implementation of an ATProto PDS, built with [Axum](https://github.com/tokio-rs/axum) and [Atrium](https://github.com/sugyan/atrium). 15 - This PDS implementation uses a SQLite database to store private account information and file storage to store canonical user data. 14 + This is an implementation of an ATProto PDS, built with [Axum](https://github.com/tokio-rs/axum), [rsky](https://github.com/blacksky-algorithms/rsky/) and [Atrium](https://github.com/sugyan/atrium). 15 + This PDS implementation uses a SQLite database and [diesel.rs](https://diesel.rs/) ORM to store canonical user data, and file system storage to store user blobs. 16 16 17 17 Heavily inspired by David Buchanan's [millipds](https://github.com/DavidBuchanan314/millipds). 18 - This implementation forked from the [azure-rust-app](https://github.com/DrChat/azure-rust-app) starter template and the upstream [DrChat/bluepds](https://github.com/DrChat/bluepds). 19 - See TODO below for this fork's changes from upstream. 18 + This implementation forked from [DrChat/bluepds](https://github.com/DrChat/bluepds), and now makes heavy use of the [rsky-repo](https://github.com/blacksky-algorithms/rsky/tree/main/rsky-repo) repository implementation. 19 + The `actor_store` and `account_manager` modules have been reimplemented from [rsky-pds](https://github.com/blacksky-algorithms/rsky/tree/main/rsky-pds) to use a SQLite backend and file storage, which are themselves adapted from the [original Bluesky implementation](https://github.com/bluesky-social/atproto) using SQLite in Typescript. 20 + 20 21 21 22 If you want to see this fork in action, there is a live account hosted by this PDS at [@teq.shatteredsky.net](https://bsky.app/profile/teq.shatteredsky.net)! 22 23 23 24 > [!WARNING] 24 - > This PDS is undergoing heavy development. Do _NOT_ use this to host your primary account or any important data! 25 + > This PDS is undergoing heavy development, and this branch is not at an operable release. Do _NOT_ use this to host your primary account or any important data! 25 26 26 27 ## Quick Start 27 28 ``` ··· 43 44 - Size: 47 GB 44 45 - VPUs/GB: 10 45 46 46 - This is about half of the 3,000 OCPU hours and 18,000 GB hours available per month for free on the VM.Standard.A1.Flex shape. This is _without_ optimizing for costs. The PDS can likely be made much cheaper. 47 - 48 - ## Code map 49 - ``` 50 - * migrations/ - SQLite database migrations 51 - * src/ 52 - * endpoints/ - ATProto API endpoints 53 - * auth.rs - Authentication primitives 54 - * config.rs - Application configuration 55 - * did.rs - Decentralized Identifier helpers 56 - * error.rs - Axum error helpers 57 - * firehose.rs - ATProto firehose producer 58 - * main.rs - Main entrypoint 59 - * metrics.rs - Definitions for telemetry instruments 60 - * oauth.rs - OAuth routes 61 - * plc.rs - Functionality to access the Public Ledger of Credentials 62 - * storage.rs - Helpers to access user repository storage 63 - ``` 47 + This is about half of the 3,000 OCPU hours and 18,000 GB hours available per month for free on the VM.Standard.A1.Flex shape. This is _without_ optimizing for costs. The PDS can likely be made to run on much less resources. 64 48 65 49 ## To-do 66 - ### Teq's fork 67 - - [ ] OAuth 68 - - [X] `/.well-known/oauth-protected-resource` - Authorization Server Metadata 69 - - [X] `/.well-known/oauth-authorization-server` 70 - - [X] `/par` - Pushed Authorization Request 71 - - [X] `/client-metadata.json` - Client metadata discovery 72 - - [X] `/oauth/authorize` 73 - - [X] `/oauth/authorize/sign-in` 74 - - [X] `/oauth/token` 75 - - [ ] Authorization flow - Backend client 76 - - [X] Authorization flow - Serverless browser app 77 - - [ ] DPoP-Nonce 78 - - [ ] Verify JWT signature with JWK 79 - - [ ] Email verification 80 - - [ ] 2FA 81 - - [ ] Admin endpoints 82 - - [ ] App passwords 83 - - [X] `listRecords` fixes 84 - - [X] Fix collection prefixing (terminate with `/`) 85 - - [X] Fix cursor handling (return `cid` instead of `key`) 86 - - [X] Session management (JWT) 87 - - [X] Match token fields to reference implementation 88 - - [X] RefreshSession from Bluesky Client 89 - - [X] Respond with JSON error message `ExpiredToken` 90 - - [X] Cursor handling 91 - - [X] Implement time-based unix microsecond sequences 92 - - [X] Startup with present cursor 93 - - [X] Respond `RecordNotFound`, required for: 94 - - [X] app.bsky.feed.postgate 95 - - [X] app.bsky.feed.threadgate 96 - - [ ] app.bsky... (profile creation?) 97 - - [X] Linting 98 - - [X] Rustfmt 99 - - [X] warnings 100 - - [X] deprecated-safe 101 - - [X] future-incompatible 102 - - [X] keyword-idents 103 - - [X] let-underscore 104 - - [X] nonstandard-style 105 - - [X] refining-impl-trait 106 - - [X] rust-2018-idioms 107 - - [X] rust-2018/2021/2024-compatibility 108 - - [X] ungrouped 109 - - [X] Clippy 110 - - [X] nursery 111 - - [X] correctness 112 - - [X] suspicious 113 - - [X] complexity 114 - - [X] perf 115 - - [X] style 116 - - [X] pedantic 117 - - [X] cargo 118 - - [X] ungrouped 119 - 120 - ### High-level features 121 - - [ ] Storage backend abstractions 122 - - [ ] Azure blob storage backend 123 - - [ ] Backblaze b2(?) 124 - - [ ] Telemetry 125 - - [X] [Metrics](https://github.com/metrics-rs/metrics) (counters/gauges/etc) 126 - - [X] Exporters for common backends (Prometheus/etc) 127 - 128 50 ### APIs 129 - - [X] [Service proxying](https://atproto.com/specs/xrpc#service-proxying) 130 - - [X] UG /xrpc/_health (undocumented, but impl by reference PDS) 51 + - [ ] [Service proxying](https://atproto.com/specs/xrpc#service-proxying) 52 + - [ ] UG /xrpc/_health (undocumented, but impl by reference PDS) 131 53 <!-- - [ ] xx /xrpc/app.bsky.notification.registerPush 132 54 - app.bsky.actor 133 - - [X] AG /xrpc/app.bsky.actor.getPreferences 55 + - [ ] AG /xrpc/app.bsky.actor.getPreferences 134 56 - [ ] xx /xrpc/app.bsky.actor.getProfile 135 57 - [ ] xx /xrpc/app.bsky.actor.getProfiles 136 - - [X] AP /xrpc/app.bsky.actor.putPreferences 58 + - [ ] AP /xrpc/app.bsky.actor.putPreferences 137 59 - app.bsky.feed 138 60 - [ ] xx /xrpc/app.bsky.feed.getActorLikes 139 61 - [ ] xx /xrpc/app.bsky.feed.getAuthorFeed ··· 157 79 - com.atproto.identity 158 80 - [ ] xx /xrpc/com.atproto.identity.getRecommendedDidCredentials 159 81 - [ ] AP /xrpc/com.atproto.identity.requestPlcOperationSignature 160 - - [X] UG /xrpc/com.atproto.identity.resolveHandle 82 + - [ ] UG /xrpc/com.atproto.identity.resolveHandle 161 83 - [ ] AP /xrpc/com.atproto.identity.signPlcOperation 162 84 - [ ] xx /xrpc/com.atproto.identity.submitPlcOperation 163 - - [X] AP /xrpc/com.atproto.identity.updateHandle 85 + - [ ] AP /xrpc/com.atproto.identity.updateHandle 164 86 <!-- - com.atproto.moderation 165 87 - [ ] xx /xrpc/com.atproto.moderation.createReport --> 166 88 - com.atproto.repo ··· 169 91 - [X] AP /xrpc/com.atproto.repo.deleteRecord 170 92 - [X] UG /xrpc/com.atproto.repo.describeRepo 171 93 - [X] UG /xrpc/com.atproto.repo.getRecord 172 - - [ ] xx /xrpc/com.atproto.repo.importRepo 173 - - [ ] xx /xrpc/com.atproto.repo.listMissingBlobs 94 + - [X] xx /xrpc/com.atproto.repo.importRepo 95 + - [X] xx /xrpc/com.atproto.repo.listMissingBlobs 174 96 - [X] UG /xrpc/com.atproto.repo.listRecords 175 97 - [X] AP /xrpc/com.atproto.repo.putRecord 176 98 - [X] AP /xrpc/com.atproto.repo.uploadBlob ··· 178 100 - [ ] xx /xrpc/com.atproto.server.activateAccount 179 101 - [ ] xx /xrpc/com.atproto.server.checkAccountStatus 180 102 - [ ] xx /xrpc/com.atproto.server.confirmEmail 181 - - [X] UP /xrpc/com.atproto.server.createAccount 103 + - [ ] UP /xrpc/com.atproto.server.createAccount 182 104 - [ ] xx /xrpc/com.atproto.server.createAppPassword 183 - - [X] AP /xrpc/com.atproto.server.createInviteCode 105 + - [ ] AP /xrpc/com.atproto.server.createInviteCode 184 106 - [ ] xx /xrpc/com.atproto.server.createInviteCodes 185 - - [X] UP /xrpc/com.atproto.server.createSession 107 + - [ ] UP /xrpc/com.atproto.server.createSession 186 108 - [ ] xx /xrpc/com.atproto.server.deactivateAccount 187 109 - [ ] xx /xrpc/com.atproto.server.deleteAccount 188 110 - [ ] xx /xrpc/com.atproto.server.deleteSession 189 - - [X] UG /xrpc/com.atproto.server.describeServer 111 + - [ ] UG /xrpc/com.atproto.server.describeServer 190 112 - [ ] xx /xrpc/com.atproto.server.getAccountInviteCodes 191 - - [X] AG /xrpc/com.atproto.server.getServiceAuth 192 - - [X] AG /xrpc/com.atproto.server.getSession 113 + - [ ] AG /xrpc/com.atproto.server.getServiceAuth 114 + - [ ] AG /xrpc/com.atproto.server.getSession 193 115 - [ ] xx /xrpc/com.atproto.server.listAppPasswords 194 116 - [ ] xx /xrpc/com.atproto.server.refreshSession 195 117 - [ ] xx /xrpc/com.atproto.server.requestAccountDelete ··· 201 123 - [ ] xx /xrpc/com.atproto.server.revokeAppPassword 202 124 - [ ] xx /xrpc/com.atproto.server.updateEmail 203 125 - com.atproto.sync 204 - - [X] UG /xrpc/com.atproto.sync.getBlob 205 - - [X] UG /xrpc/com.atproto.sync.getBlocks 206 - - [X] UG /xrpc/com.atproto.sync.getLatestCommit 207 - - [X] UG /xrpc/com.atproto.sync.getRecord 208 - - [X] UG /xrpc/com.atproto.sync.getRepo 209 - - [X] UG /xrpc/com.atproto.sync.getRepoStatus 210 - - [X] UG /xrpc/com.atproto.sync.listBlobs 211 - - [X] UG /xrpc/com.atproto.sync.listRepos 212 - - [X] UG /xrpc/com.atproto.sync.subscribeRepos 126 + - [ ] UG /xrpc/com.atproto.sync.getBlob 127 + - [ ] UG /xrpc/com.atproto.sync.getBlocks 128 + - [ ] UG /xrpc/com.atproto.sync.getLatestCommit 129 + - [ ] UG /xrpc/com.atproto.sync.getRecord 130 + - [ ] UG /xrpc/com.atproto.sync.getRepo 131 + - [ ] UG /xrpc/com.atproto.sync.getRepoStatus 132 + - [ ] UG /xrpc/com.atproto.sync.listBlobs 133 + - [ ] UG /xrpc/com.atproto.sync.listRepos 134 + - [ ] UG /xrpc/com.atproto.sync.subscribeRepos 213 135 214 - ## Quick Deployment (Azure CLI) 215 - ``` 216 - az group create --name "webapp" --location southcentralus 217 - az deployment group create --resource-group "webapp" --template-file .\deployment.bicep --parameters webAppName=testapp 218 - 219 - az acr login --name <insert name of ACR resource here> 220 - docker build -t <ACR>.azurecr.io/testapp:latest . 221 - docker push <ACR>.azurecr.io/testapp:latest 222 - ``` 223 - ## Quick Deployment (NixOS) 136 + ## Deployment (NixOS) 224 137 ```nix 225 138 { 226 139 inputs = {
-182
deployment.bicep
··· 1 - param webAppName string 2 - param location string = resourceGroup().location // Location for all resources 3 - 4 - param sku string = 'B1' // The SKU of App Service Plan 5 - param dockerContainerName string = '${webAppName}:latest' 6 - param repositoryUrl string = 'https://github.com/DrChat/bluepds' 7 - param branch string = 'main' 8 - param customDomain string 9 - 10 - @description('Redeploy hostnames without SSL binding. Just specify `true` if this is the first time you\'re deploying the app.') 11 - param redeployHostnamesHack bool = false 12 - 13 - var acrName = toLower('${webAppName}${uniqueString(resourceGroup().id)}') 14 - var aspName = toLower('${webAppName}-asp') 15 - var webName = toLower('${webAppName}${uniqueString(resourceGroup().id)}') 16 - var sanName = toLower('${webAppName}${uniqueString(resourceGroup().id)}') 17 - 18 - // resource appInsights 'Microsoft.OperationalInsights/workspaces@2023-09-01' = { 19 - // name: '${webAppName}-ai' 20 - // location: location 21 - // properties: { 22 - // publicNetworkAccessForIngestion: 'Enabled' 23 - // workspaceCapping: { 24 - // dailyQuotaGb: 1 25 - // } 26 - // sku: { 27 - // name: 'Standalone' 28 - // } 29 - // } 30 - // } 31 - 32 - // resource appServicePlanDiagnostics 'Microsoft.Insights/diagnosticSettings@2021-05-01-preview' = { 33 - // name: appServicePlan.name 34 - // scope: appServicePlan 35 - // properties: { 36 - // workspaceId: appInsights.id 37 - // metrics: [ 38 - // { 39 - // category: 'AllMetrics' 40 - // enabled: true 41 - // } 42 - // ] 43 - // } 44 - // } 45 - 46 - resource appServicePlan 'Microsoft.Web/serverfarms@2020-06-01' = { 47 - name: aspName 48 - location: location 49 - properties: { 50 - reserved: true 51 - } 52 - sku: { 53 - name: sku 54 - } 55 - kind: 'linux' 56 - } 57 - 58 - resource acrResource 'Microsoft.ContainerRegistry/registries@2023-01-01-preview' = { 59 - name: acrName 60 - location: location 61 - sku: { 62 - name: 'Basic' 63 - } 64 - properties: { 65 - adminUserEnabled: false 66 - } 67 - } 68 - 69 - resource appStorage 'Microsoft.Storage/storageAccounts@2023-05-01' = { 70 - name: sanName 71 - location: location 72 - kind: 'StorageV2' 73 - sku: { 74 - name: 'Standard_LRS' 75 - } 76 - } 77 - 78 - resource fileShare 'Microsoft.Storage/storageAccounts/fileServices/shares@2023-05-01' = { 79 - name: '${appStorage.name}/default/data' 80 - properties: {} 81 - } 82 - 83 - resource appService 'Microsoft.Web/sites@2020-06-01' = { 84 - name: webName 85 - location: location 86 - identity: { 87 - type: 'SystemAssigned' 88 - } 89 - properties: { 90 - httpsOnly: true 91 - serverFarmId: appServicePlan.id 92 - siteConfig: { 93 - // Sigh. This took _far_ too long to figure out. 94 - // We must authenticate to ACR, as no credentials are set up by default 95 - // (the Az CLI will implicitly set them up in the background) 96 - acrUseManagedIdentityCreds: true 97 - appSettings: [ 98 - { 99 - name: 'BLUEPDS_HOST_NAME' 100 - value: empty(customDomain) ? '${webName}.azurewebsites.net' : customDomain 101 - } 102 - { 103 - name: 'BLUEPDS_TEST' 104 - value: 'false' 105 - } 106 - { 107 - name: 'WEBSITES_PORT' 108 - value: '8000' 109 - } 110 - ] 111 - linuxFxVersion: 'DOCKER|${acrName}.azurecr.io/${dockerContainerName}' 112 - } 113 - } 114 - } 115 - 116 - resource hostNameBinding 'Microsoft.Web/sites/hostNameBindings@2024-04-01' = if (redeployHostnamesHack) { 117 - name: customDomain 118 - parent: appService 119 - properties: { 120 - siteName: appService.name 121 - hostNameType: 'Verified' 122 - sslState: 'Disabled' 123 - } 124 - } 125 - 126 - // This stupidity is required because Azure requires a circular dependency in order to define a custom hostname with SSL. 127 - // https://stackoverflow.com/questions/73077972/how-to-deploy-app-service-with-managed-ssl-certificate-using-arm 128 - module certificateBindings './deploymentBindingHack.bicep' = { 129 - name: '${deployment().name}-ssl' 130 - params: { 131 - appServicePlanResourceId: appServicePlan.id 132 - customHostnames: [customDomain] 133 - location: location 134 - webAppName: appService.name 135 - } 136 - dependsOn: [hostNameBinding] 137 - } 138 - 139 - resource appServiceStorageConfig 'Microsoft.Web/sites/config@2024-04-01' = { 140 - name: 'azurestorageaccounts' 141 - parent: appService 142 - properties: { 143 - data: { 144 - type: 'AzureFiles' 145 - shareName: 'data' 146 - mountPath: '/app/data' 147 - accountName: appStorage.name 148 - // WTF? Where's the ability to mount storage via managed identity? 149 - accessKey: appStorage.listKeys().keys[0].value 150 - } 151 - } 152 - } 153 - 154 - @description('This is the built-in AcrPull role. See https://docs.microsoft.com/azure/role-based-access-control/built-in-roles#acrpull') 155 - resource acrPullRoleDefinition 'Microsoft.Authorization/roleDefinitions@2018-01-01-preview' existing = { 156 - scope: subscription() 157 - name: '7f951dda-4ed3-4680-a7ca-43fe172d538d' 158 - } 159 - 160 - resource appServiceAcrPull 'Microsoft.Authorization/roleAssignments@2020-04-01-preview' = { 161 - name: guid(resourceGroup().id, acrResource.id, appService.id, 'AssignAcrPullToAS') 162 - scope: acrResource 163 - properties: { 164 - description: 'Assign AcrPull role to AS' 165 - principalId: appService.identity.principalId 166 - principalType: 'ServicePrincipal' 167 - roleDefinitionId: acrPullRoleDefinition.id 168 - } 169 - } 170 - 171 - resource srcControls 'Microsoft.Web/sites/sourcecontrols@2021-01-01' = { 172 - name: 'web' 173 - parent: appService 174 - properties: { 175 - repoUrl: repositoryUrl 176 - branch: branch 177 - isManualIntegration: true 178 - } 179 - } 180 - 181 - output acr string = acrResource.name 182 - output domain string = appService.properties.hostNames[0]
-30
deploymentBindingHack.bicep
··· 1 - // https://stackoverflow.com/questions/73077972/how-to-deploy-app-service-with-managed-ssl-certificate-using-arm 2 - // 3 - // TLDR: Azure requires a circular dependency in order to define an app service with a custom domain with SSL enabled. 4 - // Terrific user experience. Really makes me love using Azure in my free time. 5 - param webAppName string 6 - param location string 7 - param appServicePlanResourceId string 8 - param customHostnames array 9 - 10 - // Managed certificates can only be created once the hostname is added to the web app. 11 - resource certificates 'Microsoft.Web/certificates@2022-03-01' = [for (fqdn, i) in customHostnames: { 12 - name: '${fqdn}-${webAppName}' 13 - location: location 14 - properties: { 15 - serverFarmId: appServicePlanResourceId 16 - canonicalName: fqdn 17 - } 18 - }] 19 - 20 - // sslState and thumbprint can only be set once the managed certificate is created 21 - @batchSize(1) 22 - resource customHostname 'Microsoft.web/sites/hostnameBindings@2019-08-01' = [for (fqdn, i) in customHostnames: { 23 - name: '${webAppName}/${fqdn}' 24 - properties: { 25 - siteName: webAppName 26 - hostNameType: 'Verified' 27 - sslState: 'SniEnabled' 28 - thumbprint: certificates[i].properties.thumbprint 29 - } 30 - }]
+3 -2
flake.nix
··· 22 22 "rust-analyzer" 23 23 ]; 24 24 })); 25 - 25 + 26 26 inherit (pkgs) lib; 27 27 unfilteredRoot = ./.; # The original, unfiltered source 28 28 src = lib.fileset.toSource { ··· 109 109 git 110 110 nixd 111 111 direnv 112 + libpq 112 113 ]; 113 114 }; 114 115 }) ··· 165 166 }; 166 167 }; 167 168 }); 168 - } 169 + }
+14
migrations/2025-05-15-182818_init_diff/down.sql
··· 1 + DROP TABLE IF EXISTS `repo_seq`; 2 + DROP TABLE IF EXISTS `app_password`; 3 + DROP TABLE IF EXISTS `device_account`; 4 + DROP TABLE IF EXISTS `actor`; 5 + DROP TABLE IF EXISTS `device`; 6 + DROP TABLE IF EXISTS `did_doc`; 7 + DROP TABLE IF EXISTS `email_token`; 8 + DROP TABLE IF EXISTS `invite_code`; 9 + DROP TABLE IF EXISTS `used_refresh_token`; 10 + DROP TABLE IF EXISTS `invite_code_use`; 11 + DROP TABLE IF EXISTS `authorization_request`; 12 + DROP TABLE IF EXISTS `token`; 13 + DROP TABLE IF EXISTS `refresh_token`; 14 + DROP TABLE IF EXISTS `account`;
+122
migrations/2025-05-15-182818_init_diff/up.sql
··· 1 + CREATE TABLE `repo_seq`( 2 + `seq` INT8 NOT NULL PRIMARY KEY, 3 + `did` VARCHAR NOT NULL, 4 + `eventtype` VARCHAR NOT NULL, 5 + `event` BYTEA NOT NULL, 6 + `invalidated` INT2 NOT NULL, 7 + `sequencedat` VARCHAR NOT NULL 8 + ); 9 + 10 + CREATE TABLE `app_password`( 11 + `did` VARCHAR NOT NULL, 12 + `name` VARCHAR NOT NULL, 13 + `password` VARCHAR NOT NULL, 14 + `createdat` VARCHAR NOT NULL, 15 + PRIMARY KEY(`did`, `name`) 16 + ); 17 + 18 + CREATE TABLE `device_account`( 19 + `did` VARCHAR NOT NULL, 20 + `deviceid` VARCHAR NOT NULL, 21 + `authenticatedat` TIMESTAMPTZ NOT NULL, 22 + `remember` BOOL NOT NULL, 23 + `authorizedclients` VARCHAR NOT NULL, 24 + PRIMARY KEY(`deviceId`, `did`) 25 + ); 26 + 27 + CREATE TABLE `actor`( 28 + `did` VARCHAR NOT NULL PRIMARY KEY, 29 + `handle` VARCHAR, 30 + `createdat` VARCHAR NOT NULL, 31 + `takedownref` VARCHAR, 32 + `deactivatedat` VARCHAR, 33 + `deleteafter` VARCHAR 34 + ); 35 + 36 + CREATE TABLE `device`( 37 + `id` VARCHAR NOT NULL PRIMARY KEY, 38 + `sessionid` VARCHAR, 39 + `useragent` VARCHAR, 40 + `ipaddress` VARCHAR NOT NULL, 41 + `lastseenat` TIMESTAMPTZ NOT NULL 42 + ); 43 + 44 + CREATE TABLE `did_doc`( 45 + `did` VARCHAR NOT NULL PRIMARY KEY, 46 + `doc` TEXT NOT NULL, 47 + `updatedat` INT8 NOT NULL 48 + ); 49 + 50 + CREATE TABLE `email_token`( 51 + `purpose` VARCHAR NOT NULL, 52 + `did` VARCHAR NOT NULL, 53 + `token` VARCHAR NOT NULL, 54 + `requestedat` VARCHAR NOT NULL, 55 + PRIMARY KEY(`purpose`, `did`) 56 + ); 57 + 58 + CREATE TABLE `invite_code`( 59 + `code` VARCHAR NOT NULL PRIMARY KEY, 60 + `availableuses` INT4 NOT NULL, 61 + `disabled` INT2 NOT NULL, 62 + `foraccount` VARCHAR NOT NULL, 63 + `createdby` VARCHAR NOT NULL, 64 + `createdat` VARCHAR NOT NULL 65 + ); 66 + 67 + CREATE TABLE `used_refresh_token`( 68 + `refreshtoken` VARCHAR NOT NULL PRIMARY KEY, 69 + `tokenid` VARCHAR NOT NULL 70 + ); 71 + 72 + CREATE TABLE `invite_code_use`( 73 + `code` VARCHAR NOT NULL, 74 + `usedby` VARCHAR NOT NULL, 75 + `usedat` VARCHAR NOT NULL, 76 + PRIMARY KEY(`code`, `usedBy`) 77 + ); 78 + 79 + CREATE TABLE `authorization_request`( 80 + `id` VARCHAR NOT NULL PRIMARY KEY, 81 + `did` VARCHAR, 82 + `deviceid` VARCHAR, 83 + `clientid` VARCHAR NOT NULL, 84 + `clientauth` VARCHAR NOT NULL, 85 + `parameters` VARCHAR NOT NULL, 86 + `expiresat` TIMESTAMPTZ NOT NULL, 87 + `code` VARCHAR 88 + ); 89 + 90 + CREATE TABLE `token`( 91 + `id` VARCHAR NOT NULL PRIMARY KEY, 92 + `did` VARCHAR NOT NULL, 93 + `tokenid` VARCHAR NOT NULL, 94 + `createdat` TIMESTAMPTZ NOT NULL, 95 + `updatedat` TIMESTAMPTZ NOT NULL, 96 + `expiresat` TIMESTAMPTZ NOT NULL, 97 + `clientid` VARCHAR NOT NULL, 98 + `clientauth` VARCHAR NOT NULL, 99 + `deviceid` VARCHAR, 100 + `parameters` VARCHAR NOT NULL, 101 + `details` VARCHAR, 102 + `code` VARCHAR, 103 + `currentrefreshtoken` VARCHAR 104 + ); 105 + 106 + CREATE TABLE `refresh_token`( 107 + `id` VARCHAR NOT NULL PRIMARY KEY, 108 + `did` VARCHAR NOT NULL, 109 + `expiresat` VARCHAR NOT NULL, 110 + `nextid` VARCHAR, 111 + `apppasswordname` VARCHAR 112 + ); 113 + 114 + CREATE TABLE `account`( 115 + `did` VARCHAR NOT NULL PRIMARY KEY, 116 + `email` VARCHAR NOT NULL, 117 + `recoverykey` VARCHAR, 118 + `password` VARCHAR NOT NULL, 119 + `createdat` VARCHAR NOT NULL, 120 + `invitesdisabled` INT2 NOT NULL, 121 + `emailconfirmedat` VARCHAR 122 + );
+4
migrations/2025-05-17-094600_oauth_temp/down.sql
··· 1 + DROP TABLE IF EXISTS `oauth_refresh_tokens`; 2 + DROP TABLE IF EXISTS `oauth_used_jtis`; 3 + DROP TABLE IF EXISTS `oauth_par_requests`; 4 + DROP TABLE IF EXISTS `oauth_authorization_codes`;
+46
migrations/2025-05-17-094600_oauth_temp/up.sql
··· 1 + CREATE TABLE `oauth_refresh_tokens`( 2 + `token` VARCHAR NOT NULL PRIMARY KEY, 3 + `client_id` VARCHAR NOT NULL, 4 + `subject` VARCHAR NOT NULL, 5 + `dpop_thumbprint` VARCHAR NOT NULL, 6 + `scope` VARCHAR, 7 + `created_at` INT8 NOT NULL, 8 + `expires_at` INT8 NOT NULL, 9 + `revoked` BOOL NOT NULL 10 + ); 11 + 12 + CREATE TABLE `oauth_used_jtis`( 13 + `jti` VARCHAR NOT NULL PRIMARY KEY, 14 + `issuer` VARCHAR NOT NULL, 15 + `created_at` INT8 NOT NULL, 16 + `expires_at` INT8 NOT NULL 17 + ); 18 + 19 + CREATE TABLE `oauth_par_requests`( 20 + `request_uri` VARCHAR NOT NULL PRIMARY KEY, 21 + `client_id` VARCHAR NOT NULL, 22 + `response_type` VARCHAR NOT NULL, 23 + `code_challenge` VARCHAR NOT NULL, 24 + `code_challenge_method` VARCHAR NOT NULL, 25 + `state` VARCHAR, 26 + `login_hint` VARCHAR, 27 + `scope` VARCHAR, 28 + `redirect_uri` VARCHAR, 29 + `response_mode` VARCHAR, 30 + `display` VARCHAR, 31 + `created_at` INT8 NOT NULL, 32 + `expires_at` INT8 NOT NULL 33 + ); 34 + 35 + CREATE TABLE `oauth_authorization_codes`( 36 + `code` VARCHAR NOT NULL PRIMARY KEY, 37 + `client_id` VARCHAR NOT NULL, 38 + `subject` VARCHAR NOT NULL, 39 + `code_challenge` VARCHAR NOT NULL, 40 + `code_challenge_method` VARCHAR NOT NULL, 41 + `redirect_uri` VARCHAR NOT NULL, 42 + `scope` VARCHAR, 43 + `created_at` INT8 NOT NULL, 44 + `expires_at` INT8 NOT NULL, 45 + `used` BOOL NOT NULL 46 + );
-7
migrations/20250104202448_init.down.sql
··· 1 - DROP TABLE IF EXISTS invites; 2 - 3 - DROP TABLE IF EXISTS handles; 4 - 5 - DROP TABLE IF EXISTS accounts; 6 - 7 - DROP TABLE IF EXISTS sessions;
-29
migrations/20250104202448_init.up.sql
··· 1 - CREATE TABLE IF NOT EXISTS accounts ( 2 - did TEXT PRIMARY KEY NOT NULL, 3 - email TEXT NOT NULL UNIQUE, 4 - password TEXT NOT NULL, 5 - root TEXT NOT NULL, 6 - rev TEXT NOT NULL, 7 - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP 8 - ); 9 - 10 - CREATE TABLE IF NOT EXISTS handles ( 11 - handle TEXT PRIMARY KEY NOT NULL, 12 - did TEXT NOT NULL, 13 - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 14 - FOREIGN KEY (did) REFERENCES accounts(did) 15 - ); 16 - 17 - CREATE TABLE IF NOT EXISTS invites ( 18 - id TEXT PRIMARY KEY NOT NULL, 19 - did TEXT, 20 - count INTEGER NOT NULL DEFAULT 1, 21 - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP 22 - ); 23 - 24 - CREATE TABLE IF NOT EXISTS sessions ( 25 - id TEXT PRIMARY KEY NOT NULL, 26 - did TEXT NOT NULL, 27 - created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, 28 - FOREIGN KEY (did) REFERENCES accounts(did) 29 - );
-1
migrations/20250217052304_repo_status.down.sql
··· 1 - ALTER TABLE accounts DROP COLUMN status;
-1
migrations/20250217052304_repo_status.up.sql
··· 1 - ALTER TABLE accounts ADD COLUMN status TEXT NOT NULL DEFAULT "active";
-1
migrations/20250219055555_account_plc_root.down.sql
··· 1 - ALTER TABLE accounts DROP COLUMN plc_root;
-1
migrations/20250219055555_account_plc_root.up.sql
··· 1 - ALTER TABLE accounts ADD COLUMN plc_root TEXT NOT NULL;
-1
migrations/20250220235950_private_data.down.sql
··· 1 - ALTER TABLE accounts DROP COLUMN private_prefs;
-1
migrations/20250220235950_private_data.up.sql
··· 1 - ALTER TABLE accounts ADD COLUMN private_prefs JSON;
-1
migrations/20250223015249_blob_ref.down.sql
··· 1 - DROP TABLE blob_ref;
-6
migrations/20250223015249_blob_ref.up.sql
··· 1 - CREATE TABLE IF NOT EXISTS blob_ref ( 2 - -- N.B: There is a hidden `rowid` field inserted by sqlite. 3 - cid TEXT NOT NULL, 4 - did TEXT NOT NULL, 5 - record TEXT 6 - );
-1
migrations/20250330074000_oauth.down.sql
··· 1 - DROP TABLE oauth_par_requests;
-37
migrations/20250330074000_oauth.up.sql
··· 1 - CREATE TABLE IF NOT EXISTS oauth_par_requests ( 2 - request_uri TEXT PRIMARY KEY NOT NULL, 3 - client_id TEXT NOT NULL, 4 - response_type TEXT NOT NULL, 5 - code_challenge TEXT NOT NULL, 6 - code_challenge_method TEXT NOT NULL, 7 - state TEXT, 8 - login_hint TEXT, 9 - scope TEXT, 10 - redirect_uri TEXT, 11 - response_mode TEXT, 12 - display TEXT, 13 - created_at INTEGER NOT NULL, 14 - expires_at INTEGER NOT NULL 15 - ); 16 - CREATE TABLE IF NOT EXISTS oauth_authorization_codes ( 17 - code TEXT PRIMARY KEY NOT NULL, 18 - client_id TEXT NOT NULL, 19 - subject TEXT NOT NULL, 20 - code_challenge TEXT NOT NULL, 21 - code_challenge_method TEXT NOT NULL, 22 - redirect_uri TEXT NOT NULL, 23 - scope TEXT, 24 - created_at INTEGER NOT NULL, 25 - expires_at INTEGER NOT NULL, 26 - used BOOLEAN NOT NULL DEFAULT FALSE 27 - ); 28 - CREATE TABLE IF NOT EXISTS oauth_refresh_tokens ( 29 - token TEXT PRIMARY KEY NOT NULL, 30 - client_id TEXT NOT NULL, 31 - subject TEXT NOT NULL, 32 - dpop_thumbprint TEXT NOT NULL, 33 - scope TEXT, 34 - created_at INTEGER NOT NULL, 35 - expires_at INTEGER NOT NULL, 36 - revoked BOOLEAN NOT NULL DEFAULT FALSE 37 - );
-6
migrations/20250502032700_jti.down.sql
··· 1 - DROP INDEX IF EXISTS idx_jtis_expires_at; 2 - DROP INDEX IF EXISTS idx_refresh_tokens_expires_at; 3 - DROP INDEX IF EXISTS idx_auth_codes_expires_at; 4 - DROP INDEX IF EXISTS idx_par_expires_at; 5 - 6 - DROP TABLE IF EXISTS oauth_used_jtis;
-13
migrations/20250502032700_jti.up.sql
··· 1 - -- Table for tracking used JTIs to prevent replay attacks 2 - CREATE TABLE IF NOT EXISTS oauth_used_jtis ( 3 - jti TEXT PRIMARY KEY NOT NULL, 4 - issuer TEXT NOT NULL, 5 - created_at INTEGER NOT NULL, 6 - expires_at INTEGER NOT NULL 7 - ); 8 - 9 - -- Create indexes for faster lookups and cleanup 10 - CREATE INDEX IF NOT EXISTS idx_par_expires_at ON oauth_par_requests(expires_at); 11 - CREATE INDEX IF NOT EXISTS idx_auth_codes_expires_at ON oauth_authorization_codes(expires_at); 12 - CREATE INDEX IF NOT EXISTS idx_refresh_tokens_expires_at ON oauth_refresh_tokens(expires_at); 13 - CREATE INDEX IF NOT EXISTS idx_jtis_expires_at ON oauth_used_jtis(expires_at);
-16
migrations/20250508251242_actor_store.down.sql
··· 1 - -- Drop indexes 2 - DROP INDEX IF EXISTS idx_backlink_link_to; 3 - DROP INDEX IF EXISTS idx_blob_tempkey; 4 - DROP INDEX IF EXISTS idx_record_repo_rev; 5 - DROP INDEX IF EXISTS idx_record_collection; 6 - DROP INDEX IF EXISTS idx_record_cid; 7 - DROP INDEX IF EXISTS idx_repo_block_repo_rev; 8 - 9 - -- Drop tables 10 - DROP TABLE IF EXISTS account_pref; 11 - DROP TABLE IF EXISTS backlink; 12 - DROP TABLE IF EXISTS record_blob; 13 - DROP TABLE IF EXISTS blob; 14 - DROP TABLE IF EXISTS record; 15 - DROP TABLE IF EXISTS repo_block; 16 - DROP TABLE IF EXISTS repo_root;
-70
migrations/20250508251242_actor_store.up.sql
··· 1 - -- Actor store schema matching TypeScript implementation 2 - 3 - -- Repository root information 4 - CREATE TABLE IF NOT EXISTS repo_root ( 5 - did TEXT PRIMARY KEY NOT NULL, 6 - cid TEXT NOT NULL, 7 - rev TEXT NOT NULL, 8 - indexedAt TEXT NOT NULL 9 - ); 10 - 11 - -- Repository blocks (IPLD blocks) 12 - CREATE TABLE IF NOT EXISTS repo_block ( 13 - cid TEXT PRIMARY KEY NOT NULL, 14 - repoRev TEXT NOT NULL, 15 - size INTEGER NOT NULL, 16 - content BLOB NOT NULL 17 - ); 18 - 19 - -- Record index 20 - CREATE TABLE IF NOT EXISTS record ( 21 - uri TEXT PRIMARY KEY NOT NULL, 22 - cid TEXT NOT NULL, 23 - collection TEXT NOT NULL, 24 - rkey TEXT NOT NULL, 25 - repoRev TEXT NOT NULL, 26 - indexedAt TEXT NOT NULL, 27 - takedownRef TEXT 28 - ); 29 - 30 - -- Blob storage metadata 31 - CREATE TABLE IF NOT EXISTS blob ( 32 - cid TEXT PRIMARY KEY NOT NULL, 33 - mimeType TEXT NOT NULL, 34 - size INTEGER NOT NULL, 35 - tempKey TEXT, 36 - width INTEGER, 37 - height INTEGER, 38 - createdAt TEXT NOT NULL, 39 - takedownRef TEXT 40 - ); 41 - 42 - -- Record-blob associations 43 - CREATE TABLE IF NOT EXISTS record_blob ( 44 - blobCid TEXT NOT NULL, 45 - recordUri TEXT NOT NULL, 46 - PRIMARY KEY (blobCid, recordUri) 47 - ); 48 - 49 - -- Backlinks between records 50 - CREATE TABLE IF NOT EXISTS backlink ( 51 - uri TEXT NOT NULL, 52 - path TEXT NOT NULL, 53 - linkTo TEXT NOT NULL, 54 - PRIMARY KEY (uri, path) 55 - ); 56 - 57 - -- User preferences 58 - CREATE TABLE IF NOT EXISTS account_pref ( 59 - id INTEGER PRIMARY KEY AUTOINCREMENT, 60 - name TEXT NOT NULL, 61 - valueJson TEXT NOT NULL 62 - ); 63 - 64 - -- Create indexes 65 - CREATE INDEX IF NOT EXISTS idx_repo_block_repo_rev ON repo_block(repoRev, cid); 66 - CREATE INDEX IF NOT EXISTS idx_record_cid ON record(cid); 67 - CREATE INDEX IF NOT EXISTS idx_record_collection ON record(collection); 68 - CREATE INDEX IF NOT EXISTS idx_record_repo_rev ON record(repoRev); 69 - CREATE INDEX IF NOT EXISTS idx_blob_tempkey ON blob(tempKey); 70 - CREATE INDEX IF NOT EXISTS idx_backlink_link_to ON backlink(path, linkTo);
-15
migrations/20250508252057_blockstore.up.sql
··· 1 - CREATE TABLE IF NOT EXISTS blocks ( 2 - cid TEXT PRIMARY KEY NOT NULL, 3 - data BLOB NOT NULL, 4 - multicodec INTEGER NOT NULL, 5 - multihash INTEGER NOT NULL 6 - ); 7 - CREATE TABLE IF NOT EXISTS tree_nodes ( 8 - repo_did TEXT NOT NULL, 9 - key TEXT NOT NULL, 10 - value_cid TEXT NOT NULL, 11 - PRIMARY KEY (repo_did, key), 12 - FOREIGN KEY (value_cid) REFERENCES blocks(cid) 13 - ); 14 - CREATE INDEX IF NOT EXISTS idx_blocks_cid ON blocks(cid); 15 - CREATE INDEX IF NOT EXISTS idx_tree_nodes_repo ON tree_nodes(repo_did);
-5
migrations/20250510222500_actor_migration.up.sql
··· 1 - CREATE TABLE IF NOT EXISTS actor_migration ( 2 - id INTEGER PRIMARY KEY AUTOINCREMENT, 3 - name TEXT NOT NULL, 4 - appliedAt TEXT NOT NULL 5 - );
+68 -36
src/account_manager/helpers/account.rs
··· 2 2 //! blacksky-algorithms/rsky is licensed under the Apache License 2.0 3 3 //! 4 4 //! Modified for SQLite backend 5 + use crate::schema::pds::account::dsl as AccountSchema; 6 + use crate::schema::pds::account::table as AccountTable; 7 + use crate::schema::pds::actor::dsl as ActorSchema; 8 + use crate::schema::pds::actor::table as ActorTable; 5 9 use anyhow::Result; 6 10 use chrono::DateTime; 7 11 use chrono::offset::Utc as UtcOffset; 8 - use diesel::dsl::{exists, not}; 9 12 use diesel::result::{DatabaseErrorKind, Error as DieselError}; 10 13 use diesel::*; 11 14 use rsky_common::RFC3339_VARIANT; 12 15 use rsky_lexicon::com::atproto::admin::StatusAttr; 13 16 #[expect(unused_imports)] 14 17 pub(crate) use rsky_pds::account_manager::helpers::account::{ 15 - AccountStatus, ActorAccount, ActorJoinAccount, AvailabilityFlags, FormattedAccountStatus, 18 + AccountStatus, ActorAccount, AvailabilityFlags, FormattedAccountStatus, 16 19 GetAccountAdminStatusOutput, format_account_status, 17 20 }; 18 - use rsky_pds::schema::pds::account::dsl as AccountSchema; 19 - use rsky_pds::schema::pds::actor::dsl as ActorSchema; 20 21 use std::ops::Add; 21 22 use std::time::SystemTime; 22 23 use thiserror::Error; 23 24 25 + use diesel::dsl::{LeftJoinOn, exists, not}; 26 + use diesel::helper_types::Eq; 27 + 24 28 #[derive(Error, Debug)] 25 29 pub enum AccountHelperError { 26 30 #[error("UserAlreadyExistsError")] ··· 28 32 #[error("DatabaseError: `{0}`")] 29 33 DieselError(String), 30 34 } 31 - 35 + pub type ActorJoinAccount = 36 + LeftJoinOn<ActorTable, AccountTable, Eq<ActorSchema::did, AccountSchema::did>>; 32 37 pub type BoxedQuery<'life> = dsl::IntoBoxed<'life, ActorJoinAccount, sqlite::Sqlite>; 33 38 pub fn select_account_qb(flags: Option<AvailabilityFlags>) -> BoxedQuery<'static> { 34 39 let AvailabilityFlags { 35 40 include_taken_down, 36 41 include_deactivated, 37 - } = flags.unwrap_or_else(|| AvailabilityFlags { 42 + } = flags.unwrap_or(AvailabilityFlags { 38 43 include_taken_down: Some(false), 39 44 include_deactivated: Some(false), 40 45 }); ··· 252 257 deadpool_diesel::Manager<SqliteConnection>, 253 258 deadpool_diesel::sqlite::Object, 254 259 >, 260 + actor_db: &deadpool_diesel::Pool< 261 + deadpool_diesel::Manager<SqliteConnection>, 262 + deadpool_diesel::sqlite::Object, 263 + >, 255 264 ) -> Result<()> { 256 - use rsky_pds::schema::pds::email_token::dsl as EmailTokenSchema; 257 - use rsky_pds::schema::pds::refresh_token::dsl as RefreshTokenSchema; 258 - use rsky_pds::schema::pds::repo_root::dsl as RepoRootSchema; 265 + use crate::schema::actor_store::repo_root::dsl as RepoRootSchema; 266 + use crate::schema::pds::email_token::dsl as EmailTokenSchema; 267 + use crate::schema::pds::refresh_token::dsl as RefreshTokenSchema; 259 268 260 - let did = did.to_owned(); 261 - db.get() 269 + let did_clone = did.to_owned(); 270 + _ = actor_db 271 + .get() 262 272 .await? 263 273 .interact(move |conn| { 264 274 delete(RepoRootSchema::repo_root) 265 - .filter(RepoRootSchema::did.eq(&did)) 266 - .execute(conn)?; 267 - delete(EmailTokenSchema::email_token) 268 - .filter(EmailTokenSchema::did.eq(&did)) 275 + .filter(RepoRootSchema::did.eq(&did_clone)) 276 + .execute(conn) 277 + }) 278 + .await 279 + .expect("Failed to delete actor")?; 280 + let did_clone = did.to_owned(); 281 + _ = db 282 + .get() 283 + .await? 284 + .interact(move |conn| { 285 + _ = delete(EmailTokenSchema::email_token) 286 + .filter(EmailTokenSchema::did.eq(&did_clone)) 269 287 .execute(conn)?; 270 - delete(RefreshTokenSchema::refresh_token) 271 - .filter(RefreshTokenSchema::did.eq(&did)) 288 + _ = delete(RefreshTokenSchema::refresh_token) 289 + .filter(RefreshTokenSchema::did.eq(&did_clone)) 272 290 .execute(conn)?; 273 - delete(AccountSchema::account) 274 - .filter(AccountSchema::did.eq(&did)) 291 + _ = delete(AccountSchema::account) 292 + .filter(AccountSchema::did.eq(&did_clone)) 275 293 .execute(conn)?; 276 294 delete(ActorSchema::actor) 277 - .filter(ActorSchema::did.eq(&did)) 295 + .filter(ActorSchema::did.eq(&did_clone)) 278 296 .execute(conn) 279 297 }) 280 298 .await 281 299 .expect("Failed to delete account")?; 300 + 301 + let data_repo_file = format!("data/repo/{}.db", did.to_owned()); 302 + let data_blob_path = format!("data/blob/{}", did); 303 + let data_blob_path = std::path::Path::new(&data_blob_path); 304 + let data_repo_file = std::path::Path::new(&data_repo_file); 305 + if data_repo_file.exists() { 306 + std::fs::remove_file(data_repo_file)?; 307 + }; 308 + if data_blob_path.exists() { 309 + std::fs::remove_dir_all(data_blob_path)?; 310 + }; 282 311 Ok(()) 283 312 } 284 313 ··· 291 320 >, 292 321 ) -> Result<()> { 293 322 let takedown_ref: Option<String> = match takedown.applied { 294 - true => match takedown.r#ref { 295 - Some(takedown_ref) => Some(takedown_ref), 296 - None => Some(rsky_common::now()), 297 - }, 323 + true => takedown 324 + .r#ref 325 + .map_or_else(|| Some(rsky_common::now()), Some), 298 326 false => None, 299 327 }; 300 328 let did = did.to_owned(); 301 - db.get() 329 + _ = db 330 + .get() 302 331 .await? 303 332 .interact(move |conn| { 304 333 update(ActorSchema::actor) ··· 320 349 >, 321 350 ) -> Result<()> { 322 351 let did = did.to_owned(); 323 - db.get() 352 + _ = db 353 + .get() 324 354 .await? 325 355 .interact(move |conn| { 326 356 update(ActorSchema::actor) ··· 344 374 >, 345 375 ) -> Result<()> { 346 376 let did = did.to_owned(); 347 - db.get() 377 + _ = db 378 + .get() 348 379 .await? 349 380 .interact(move |conn| { 350 381 update(ActorSchema::actor) ··· 407 438 deadpool_diesel::sqlite::Object, 408 439 >, 409 440 ) -> Result<()> { 410 - use rsky_pds::schema::pds::actor; 441 + use crate::schema::pds::actor; 411 442 412 443 let actor2 = diesel::alias!(actor as actor2); 413 444 ··· 443 474 >, 444 475 ) -> Result<()> { 445 476 let did = did.to_owned(); 446 - db.get() 477 + _ = db 478 + .get() 447 479 .await? 448 480 .interact(move |conn| { 449 481 update(AccountSchema::account) ··· 479 511 match res { 480 512 None => Ok(None), 481 513 Some(res) => { 482 - let takedown = match res.0 { 483 - Some(takedown_ref) => StatusAttr { 514 + let takedown = res.0.map_or( 515 + StatusAttr { 516 + applied: false, 517 + r#ref: None, 518 + }, 519 + |takedown_ref| StatusAttr { 484 520 applied: true, 485 521 r#ref: Some(takedown_ref), 486 522 }, 487 - None => StatusAttr { 488 - applied: false, 489 - r#ref: None, 490 - }, 491 - }; 523 + ); 492 524 let deactivated = match res.1 { 493 525 Some(_) => StatusAttr { 494 526 applied: true,
+29 -26
src/account_manager/helpers/auth.rs
··· 2 2 //! blacksky-algorithms/rsky is licensed under the Apache License 2.0 3 3 //! 4 4 //! Modified for SQLite backend 5 + use crate::models::pds as models; 5 6 use anyhow::Result; 6 7 use diesel::*; 7 8 use rsky_common::time::from_micros_to_utc; ··· 12 13 RefreshToken, ServiceJwtHeader, ServiceJwtParams, ServiceJwtPayload, create_access_token, 13 14 create_refresh_token, create_service_jwt, create_tokens, decode_refresh_token, 14 15 }; 15 - use rsky_pds::models; 16 16 17 17 pub async fn store_refresh_token( 18 18 payload: RefreshToken, ··· 22 22 deadpool_diesel::sqlite::Object, 23 23 >, 24 24 ) -> Result<()> { 25 - use rsky_pds::schema::pds::refresh_token::dsl as RefreshTokenSchema; 25 + use crate::schema::pds::refresh_token::dsl as RefreshTokenSchema; 26 26 27 27 let exp = from_micros_to_utc((payload.exp.as_millis() / 1000) as i64); 28 28 29 - db.get() 29 + _ = db 30 + .get() 30 31 .await? 31 32 .interact(move |conn| { 32 33 insert_into(RefreshTokenSchema::refresh_token) ··· 52 53 deadpool_diesel::sqlite::Object, 53 54 >, 54 55 ) -> Result<bool> { 55 - use rsky_pds::schema::pds::refresh_token::dsl as RefreshTokenSchema; 56 + use crate::schema::pds::refresh_token::dsl as RefreshTokenSchema; 56 57 db.get() 57 58 .await? 58 59 .interact(move |conn| { ··· 73 74 deadpool_diesel::sqlite::Object, 74 75 >, 75 76 ) -> Result<bool> { 76 - use rsky_pds::schema::pds::refresh_token::dsl as RefreshTokenSchema; 77 + use crate::schema::pds::refresh_token::dsl as RefreshTokenSchema; 77 78 let did = did.to_owned(); 78 79 db.get() 79 80 .await? ··· 96 97 deadpool_diesel::sqlite::Object, 97 98 >, 98 99 ) -> Result<bool> { 99 - use rsky_pds::schema::pds::refresh_token::dsl as RefreshTokenSchema; 100 + use crate::schema::pds::refresh_token::dsl as RefreshTokenSchema; 100 101 101 102 let did = did.to_owned(); 102 103 let app_pass_name = app_pass_name.to_owned(); ··· 121 122 deadpool_diesel::sqlite::Object, 122 123 >, 123 124 ) -> Result<Option<models::RefreshToken>> { 124 - use rsky_pds::schema::pds::refresh_token::dsl as RefreshTokenSchema; 125 + use crate::schema::pds::refresh_token::dsl as RefreshTokenSchema; 125 126 let id = id.to_owned(); 126 127 db.get() 127 128 .await? ··· 143 144 deadpool_diesel::sqlite::Object, 144 145 >, 145 146 ) -> Result<()> { 146 - use rsky_pds::schema::pds::refresh_token::dsl as RefreshTokenSchema; 147 + use crate::schema::pds::refresh_token::dsl as RefreshTokenSchema; 147 148 let did = did.to_owned(); 148 149 149 150 db.get() 150 151 .await? 151 152 .interact(move |conn| { 152 - delete(RefreshTokenSchema::refresh_token) 153 + _ = delete(RefreshTokenSchema::refresh_token) 153 154 .filter(RefreshTokenSchema::did.eq(did)) 154 155 .filter(RefreshTokenSchema::expiresAt.le(now)) 155 156 .execute(conn)?; ··· 174 175 expires_at, 175 176 next_id, 176 177 } = opts; 177 - use rsky_pds::schema::pds::refresh_token::dsl as RefreshTokenSchema; 178 + use crate::schema::pds::refresh_token::dsl as RefreshTokenSchema; 178 179 179 - update(RefreshTokenSchema::refresh_token) 180 - .filter(RefreshTokenSchema::id.eq(id)) 181 - .filter( 182 - RefreshTokenSchema::nextId 183 - .is_null() 184 - .or(RefreshTokenSchema::nextId.eq(&next_id)), 185 - ) 186 - .set(( 187 - RefreshTokenSchema::expiresAt.eq(expires_at), 188 - RefreshTokenSchema::nextId.eq(&next_id), 189 - )) 190 - .returning(models::RefreshToken::as_select()) 191 - .get_results(conn) 192 - .map_err(|error| { 193 - anyhow::Error::new(AuthHelperError::ConcurrentRefresh).context(error) 194 - })?; 180 + drop( 181 + update(RefreshTokenSchema::refresh_token) 182 + .filter(RefreshTokenSchema::id.eq(id)) 183 + .filter( 184 + RefreshTokenSchema::nextId 185 + .is_null() 186 + .or(RefreshTokenSchema::nextId.eq(&next_id)), 187 + ) 188 + .set(( 189 + RefreshTokenSchema::expiresAt.eq(expires_at), 190 + RefreshTokenSchema::nextId.eq(&next_id), 191 + )) 192 + .returning(models::RefreshToken::as_select()) 193 + .get_results(conn) 194 + .map_err(|error| { 195 + anyhow::Error::new(AuthHelperError::ConcurrentRefresh).context(error) 196 + })?, 197 + ); 195 198 Ok(()) 196 199 }) 197 200 .await
+12 -87
src/account_manager/helpers/email_token.rs
··· 2 2 //! blacksky-algorithms/rsky is licensed under the Apache License 2.0 3 3 //! 4 4 //! Modified for SQLite backend 5 + use crate::models::pds::EmailToken; 6 + use crate::models::pds::EmailTokenPurpose; 5 7 use anyhow::{Result, bail}; 6 8 use diesel::*; 7 9 use rsky_common::time::{MINUTE, from_str_to_utc, less_than_ago_s}; 8 10 use rsky_pds::apis::com::atproto::server::get_random_token; 9 - use rsky_pds::models::EmailToken; 10 11 11 12 pub async fn create_email_token( 12 13 did: &str, ··· 16 17 deadpool_diesel::sqlite::Object, 17 18 >, 18 19 ) -> Result<String> { 19 - use rsky_pds::schema::pds::email_token::dsl as EmailTokenSchema; 20 + use crate::schema::pds::email_token::dsl as EmailTokenSchema; 20 21 let token = get_random_token().to_uppercase(); 21 22 let now = rsky_common::now(); 22 23 ··· 24 25 db.get() 25 26 .await? 26 27 .interact(move |conn| { 27 - insert_into(EmailTokenSchema::email_token) 28 + _ = insert_into(EmailTokenSchema::email_token) 28 29 .values(( 29 30 EmailTokenSchema::purpose.eq(purpose), 30 31 EmailTokenSchema::did.eq(did), ··· 55 56 >, 56 57 ) -> Result<()> { 57 58 let expiration_len = expiration_len.unwrap_or(MINUTE * 15); 58 - use rsky_pds::schema::pds::email_token::dsl as EmailTokenSchema; 59 + use crate::schema::pds::email_token::dsl as EmailTokenSchema; 59 60 60 61 let did = did.to_owned(); 61 62 let token = token.to_owned(); ··· 95 96 >, 96 97 ) -> Result<String> { 97 98 let expiration_len = expiration_len.unwrap_or(MINUTE * 15); 98 - use rsky_pds::schema::pds::email_token::dsl as EmailTokenSchema; 99 + use crate::schema::pds::email_token::dsl as EmailTokenSchema; 99 100 100 101 let token = token.to_owned(); 101 102 let res = db ··· 123 124 } 124 125 } 125 126 126 - #[derive( 127 - Clone, 128 - Copy, 129 - Debug, 130 - PartialEq, 131 - Eq, 132 - Hash, 133 - Default, 134 - serde::Serialize, 135 - serde::Deserialize, 136 - AsExpression, 137 - )] 138 - #[diesel(sql_type = sql_types::Text)] 139 - pub enum EmailTokenPurpose { 140 - #[default] 141 - ConfirmEmail, 142 - UpdateEmail, 143 - ResetPassword, 144 - DeleteAccount, 145 - PlcOperation, 146 - } 147 - 148 - impl EmailTokenPurpose { 149 - pub fn as_str(&self) -> &'static str { 150 - match self { 151 - EmailTokenPurpose::ConfirmEmail => "confirm_email", 152 - EmailTokenPurpose::UpdateEmail => "update_email", 153 - EmailTokenPurpose::ResetPassword => "reset_password", 154 - EmailTokenPurpose::DeleteAccount => "delete_account", 155 - EmailTokenPurpose::PlcOperation => "plc_operation", 156 - } 157 - } 158 - 159 - pub fn from_str(s: &str) -> Result<Self> { 160 - match s { 161 - "confirm_email" => Ok(EmailTokenPurpose::ConfirmEmail), 162 - "update_email" => Ok(EmailTokenPurpose::UpdateEmail), 163 - "reset_password" => Ok(EmailTokenPurpose::ResetPassword), 164 - "delete_account" => Ok(EmailTokenPurpose::DeleteAccount), 165 - "plc_operation" => Ok(EmailTokenPurpose::PlcOperation), 166 - _ => bail!("Unable to parse as EmailTokenPurpose: `{s:?}`"), 167 - } 168 - } 169 - } 170 - 171 - impl<DB> Queryable<sql_types::Text, DB> for EmailTokenPurpose 172 - where 173 - DB: backend::Backend, 174 - String: deserialize::FromSql<sql_types::Text, DB>, 175 - { 176 - type Row = String; 177 - 178 - fn build(s: String) -> deserialize::Result<Self> { 179 - Ok(EmailTokenPurpose::from_str(&s)?) 180 - } 181 - } 182 - 183 - impl serialize::ToSql<sql_types::Text, sqlite::Sqlite> for EmailTokenPurpose 184 - where 185 - String: serialize::ToSql<sql_types::Text, sqlite::Sqlite>, 186 - { 187 - fn to_sql<'lifetime>( 188 - &'lifetime self, 189 - out: &mut serialize::Output<'lifetime, '_, sqlite::Sqlite>, 190 - ) -> serialize::Result { 191 - serialize::ToSql::<sql_types::Text, sqlite::Sqlite>::to_sql( 192 - match self { 193 - EmailTokenPurpose::ConfirmEmail => "confirm_email", 194 - EmailTokenPurpose::UpdateEmail => "update_email", 195 - EmailTokenPurpose::ResetPassword => "reset_password", 196 - EmailTokenPurpose::DeleteAccount => "delete_account", 197 - EmailTokenPurpose::PlcOperation => "plc_operation", 198 - }, 199 - out, 200 - ) 201 - } 202 - } 203 - 204 127 pub async fn delete_email_token( 205 128 did: &str, 206 129 purpose: EmailTokenPurpose, ··· 209 132 deadpool_diesel::sqlite::Object, 210 133 >, 211 134 ) -> Result<()> { 212 - use rsky_pds::schema::pds::email_token::dsl as EmailTokenSchema; 135 + use crate::schema::pds::email_token::dsl as EmailTokenSchema; 213 136 let did = did.to_owned(); 214 - db.get() 137 + _ = db 138 + .get() 215 139 .await? 216 140 .interact(move |conn| { 217 141 delete(EmailTokenSchema::email_token) ··· 231 155 deadpool_diesel::sqlite::Object, 232 156 >, 233 157 ) -> Result<()> { 234 - use rsky_pds::schema::pds::email_token::dsl as EmailTokenSchema; 158 + use crate::schema::pds::email_token::dsl as EmailTokenSchema; 235 159 236 160 let did = did.to_owned(); 237 - db.get() 161 + _ = db 162 + .get() 238 163 .await? 239 164 .interact(move |conn| { 240 165 delete(EmailTokenSchema::email_token)
+40 -30
src/account_manager/helpers/invite.rs
··· 2 2 //! blacksky-algorithms/rsky is licensed under the Apache License 2.0 3 3 //! 4 4 //! Modified for SQLite backend 5 + use crate::models::pds as models; 5 6 use anyhow::{Result, bail}; 6 7 use diesel::*; 7 8 use rsky_lexicon::com::atproto::server::AccountCodes; ··· 9 10 InviteCode as LexiconInviteCode, InviteCodeUse as LexiconInviteCodeUse, 10 11 }; 11 12 use rsky_pds::account_manager::DisableInviteCodesOpts; 12 - use rsky_pds::models::models; 13 13 use std::collections::BTreeMap; 14 14 use std::mem; 15 15 ··· 23 23 deadpool_diesel::sqlite::Object, 24 24 >, 25 25 ) -> Result<()> { 26 - use rsky_pds::schema::pds::actor::dsl as ActorSchema; 27 - use rsky_pds::schema::pds::invite_code::dsl as InviteCodeSchema; 28 - use rsky_pds::schema::pds::invite_code_use::dsl as InviteCodeUseSchema; 26 + use crate::schema::pds::actor::dsl as ActorSchema; 27 + use crate::schema::pds::invite_code::dsl as InviteCodeSchema; 28 + use crate::schema::pds::invite_code_use::dsl as InviteCodeUseSchema; 29 29 30 30 db.get().await?.interact(move |conn| { 31 31 let invite: Option<models::InviteCode> = InviteCodeSchema::invite_code ··· 39 39 .first(conn) 40 40 .optional()?; 41 41 42 - if invite.is_none() || invite.clone().unwrap().disabled > 0 { 43 - bail!("InvalidInviteCode: None or disabled. Provided invite code not available `{invite_code:?}`") 44 - } 42 + if let Some(invite) = invite { 43 + if invite.disabled > 0 { 44 + bail!("InvalidInviteCode: Disabled. Provided invite code not available `{invite_code:?}`"); 45 + } 45 46 46 - let uses: i64 = InviteCodeUseSchema::invite_code_use 47 - .count() 48 - .filter(InviteCodeUseSchema::code.eq(&invite_code)) 49 - .first(conn)?; 47 + let uses: i64 = InviteCodeUseSchema::invite_code_use 48 + .count() 49 + .filter(InviteCodeUseSchema::code.eq(&invite_code)) 50 + .first(conn)?; 50 51 51 - if invite.unwrap().available_uses as i64 <= uses { 52 - bail!("InvalidInviteCode: Not enough uses. Provided invite code not available `{invite_code:?}`") 52 + if invite.available_uses as i64 <= uses { 53 + bail!("InvalidInviteCode: Not enough uses. Provided invite code not available `{invite_code:?}`"); 54 + } 55 + } else { 56 + bail!("InvalidInviteCode: None. Provided invite code not available `{invite_code:?}`"); 53 57 } 58 + 54 59 Ok(()) 55 60 }).await.expect("Failed to check invite code availability")?; 56 61 ··· 67 72 >, 68 73 ) -> Result<()> { 69 74 if let Some(invite_code) = invite_code { 70 - use rsky_pds::schema::pds::invite_code_use::dsl as InviteCodeUseSchema; 75 + use crate::schema::pds::invite_code_use::dsl as InviteCodeUseSchema; 71 76 72 - db.get() 77 + _ = db 78 + .get() 73 79 .await? 74 80 .interact(move |conn| { 75 81 insert_into(InviteCodeUseSchema::invite_code_use) ··· 94 100 deadpool_diesel::sqlite::Object, 95 101 >, 96 102 ) -> Result<()> { 97 - use rsky_pds::schema::pds::invite_code::dsl as InviteCodeSchema; 103 + use crate::schema::pds::invite_code::dsl as InviteCodeSchema; 98 104 let created_at = rsky_common::now(); 99 105 100 - db.get() 106 + _ = db 107 + .get() 101 108 .await? 102 109 .interact(move |conn| { 103 110 let rows: Vec<models::InviteCode> = to_create ··· 137 144 deadpool_diesel::sqlite::Object, 138 145 >, 139 146 ) -> Result<Vec<CodeDetail>> { 140 - use rsky_pds::schema::pds::invite_code::dsl as InviteCodeSchema; 147 + use crate::schema::pds::invite_code::dsl as InviteCodeSchema; 141 148 142 149 let for_account = for_account.to_owned(); 143 150 let rows = db ··· 158 165 }) 159 166 .collect(); 160 167 161 - insert_into(InviteCodeSchema::invite_code) 168 + _ = insert_into(InviteCodeSchema::invite_code) 162 169 .values(&rows) 163 170 .execute(conn)?; 164 171 ··· 194 201 deadpool_diesel::sqlite::Object, 195 202 >, 196 203 ) -> Result<Vec<CodeDetail>> { 197 - use rsky_pds::schema::pds::invite_code::dsl as InviteCodeSchema; 204 + use crate::schema::pds::invite_code::dsl as InviteCodeSchema; 198 205 199 206 let did = did.to_owned(); 200 207 let res: Vec<models::InviteCode> = db ··· 232 239 deadpool_diesel::sqlite::Object, 233 240 >, 234 241 ) -> Result<BTreeMap<String, Vec<CodeUse>>> { 235 - use rsky_pds::schema::pds::invite_code_use::dsl as InviteCodeUseSchema; 242 + use crate::schema::pds::invite_code_use::dsl as InviteCodeUseSchema; 236 243 237 244 let mut uses: BTreeMap<String, Vec<CodeUse>> = BTreeMap::new(); 238 245 if !codes.is_empty() { ··· 256 263 } = invite_code_use; 257 264 match uses.get_mut(&code) { 258 265 None => { 259 - uses.insert(code, vec![CodeUse { used_by, used_at }]); 266 + drop(uses.insert(code, vec![CodeUse { used_by, used_at }])); 260 267 } 261 268 Some(matched_uses) => matched_uses.push(CodeUse { used_by, used_at }), 262 269 }; ··· 275 282 if dids.is_empty() { 276 283 return Ok(BTreeMap::new()); 277 284 } 278 - use rsky_pds::schema::pds::invite_code::dsl as InviteCodeSchema; 279 - use rsky_pds::schema::pds::invite_code_use::dsl as InviteCodeUseSchema; 285 + use crate::schema::pds::invite_code::dsl as InviteCodeSchema; 286 + use crate::schema::pds::invite_code_use::dsl as InviteCodeUseSchema; 280 287 281 288 let dids = dids.clone(); 282 289 let res: Vec<models::InviteCode> = db ··· 317 324 BTreeMap::new(), 318 325 |mut acc: BTreeMap<String, CodeDetail>, cur| { 319 326 for code_use in &cur.uses { 320 - acc.insert(code_use.used_by.clone(), cur.clone()); 327 + drop(acc.insert(code_use.used_by.clone(), cur.clone())); 321 328 } 322 329 acc 323 330 }, ··· 332 339 deadpool_diesel::sqlite::Object, 333 340 >, 334 341 ) -> Result<()> { 335 - use rsky_pds::schema::pds::account::dsl as AccountSchema; 342 + use crate::schema::pds::account::dsl as AccountSchema; 336 343 337 344 let disabled: i16 = if disabled { 1 } else { 0 }; 338 345 let did = did.to_owned(); 339 - db.get() 346 + _ = db 347 + .get() 340 348 .await? 341 349 .interact(move |conn| { 342 350 update(AccountSchema::account) ··· 356 364 deadpool_diesel::sqlite::Object, 357 365 >, 358 366 ) -> Result<()> { 359 - use rsky_pds::schema::pds::invite_code::dsl as InviteCodeSchema; 367 + use crate::schema::pds::invite_code::dsl as InviteCodeSchema; 360 368 361 369 let DisableInviteCodesOpts { codes, accounts } = opts; 362 370 if !codes.is_empty() { 363 - db.get() 371 + _ = db 372 + .get() 364 373 .await? 365 374 .interact(move |conn| { 366 375 update(InviteCodeSchema::invite_code) ··· 372 381 .expect("Failed to disable invite codes")?; 373 382 } 374 383 if !accounts.is_empty() { 375 - db.get() 384 + _ = db 385 + .get() 376 386 .await? 377 387 .interact(move |conn| { 378 388 update(InviteCodeSchema::invite_code)
+10 -10
src/account_manager/helpers/password.rs
··· 2 2 //! blacksky-algorithms/rsky is licensed under the Apache License 2.0 3 3 //! 4 4 //! Modified for SQLite backend 5 + use crate::models::pds as models; 6 + use crate::models::pds::AppPassword; 5 7 use anyhow::{Result, bail}; 6 8 use diesel::*; 7 9 use rsky_common::{get_random_str, now}; ··· 10 12 pub(crate) use rsky_pds::account_manager::helpers::password::{ 11 13 UpdateUserPasswordOpts, gen_salt_and_hash, hash_app_password, hash_with_salt, verify, 12 14 }; 13 - use rsky_pds::models; 14 - use rsky_pds::models::AppPassword; 15 15 16 16 pub async fn verify_account_password( 17 17 did: &str, ··· 21 21 deadpool_diesel::sqlite::Object, 22 22 >, 23 23 ) -> Result<bool> { 24 - use rsky_pds::schema::pds::account::dsl as AccountSchema; 24 + use crate::schema::pds::account::dsl as AccountSchema; 25 25 26 26 let did = did.to_owned(); 27 27 let found = db ··· 51 51 deadpool_diesel::sqlite::Object, 52 52 >, 53 53 ) -> Result<Option<String>> { 54 - use rsky_pds::schema::pds::app_password::dsl as AppPasswordSchema; 54 + use crate::schema::pds::app_password::dsl as AppPasswordSchema; 55 55 56 56 let did = did.to_owned(); 57 57 let password = password.to_owned(); ··· 91 91 let password = chunks.join("-"); 92 92 let password_encrypted = hash_app_password(&did, &password).await?; 93 93 94 - use rsky_pds::schema::pds::app_password::dsl as AppPasswordSchema; 94 + use crate::schema::pds::app_password::dsl as AppPasswordSchema; 95 95 96 96 let created_at = now(); 97 97 ··· 129 129 deadpool_diesel::sqlite::Object, 130 130 >, 131 131 ) -> Result<Vec<(String, String)>> { 132 - use rsky_pds::schema::pds::app_password::dsl as AppPasswordSchema; 132 + use crate::schema::pds::app_password::dsl as AppPasswordSchema; 133 133 134 134 let did = did.to_owned(); 135 135 db.get() ··· 151 151 deadpool_diesel::sqlite::Object, 152 152 >, 153 153 ) -> Result<()> { 154 - use rsky_pds::schema::pds::account::dsl as AccountSchema; 154 + use crate::schema::pds::account::dsl as AccountSchema; 155 155 156 156 db.get() 157 157 .await? 158 158 .interact(move |conn| { 159 - update(AccountSchema::account) 159 + _ = update(AccountSchema::account) 160 160 .filter(AccountSchema::did.eq(opts.did)) 161 161 .set(AccountSchema::password.eq(opts.password_encrypted)) 162 162 .execute(conn)?; ··· 174 174 deadpool_diesel::sqlite::Object, 175 175 >, 176 176 ) -> Result<()> { 177 - use rsky_pds::schema::pds::app_password::dsl as AppPasswordSchema; 177 + use crate::schema::pds::app_password::dsl as AppPasswordSchema; 178 178 179 179 let did = did.to_owned(); 180 180 let name = name.to_owned(); 181 181 db.get() 182 182 .await? 183 183 .interact(move |conn| { 184 - delete(AppPasswordSchema::app_password) 184 + _ = delete(AppPasswordSchema::app_password) 185 185 .filter(AppPasswordSchema::did.eq(did)) 186 186 .filter(AppPasswordSchema::name.eq(name)) 187 187 .execute(conn)?;
+5 -6
src/account_manager/helpers/repo.rs
··· 4 4 //! Modified for SQLite backend 5 5 use anyhow::Result; 6 6 use cidv10::Cid; 7 + use deadpool_diesel::{Manager, Pool, sqlite::Object}; 7 8 use diesel::*; 8 9 9 10 pub async fn update_root( 10 11 did: String, 11 12 cid: Cid, 12 13 rev: String, 13 - db: &deadpool_diesel::Pool< 14 - deadpool_diesel::Manager<SqliteConnection>, 15 - deadpool_diesel::sqlite::Object, 16 - >, 14 + db: &Pool<Manager<SqliteConnection>, Object>, 17 15 ) -> Result<()> { 18 16 // @TODO balance risk of a race in the case of a long retry 19 - use rsky_pds::schema::pds::repo_root::dsl as RepoRootSchema; 17 + use crate::schema::actor_store::repo_root::dsl as RepoRootSchema; 20 18 21 19 let now = rsky_common::now(); 22 20 23 - db.get() 21 + _ = db 22 + .get() 24 23 .await? 25 24 .interact(move |conn| { 26 25 insert_into(RepoRootSchema::repo_root)
+71 -14
src/account_manager/mod.rs
··· 10 10 }; 11 11 use crate::account_manager::helpers::invite::CodeDetail; 12 12 use crate::account_manager::helpers::password::UpdateUserPasswordOpts; 13 + use crate::models::pds::EmailTokenPurpose; 14 + use crate::serve::ActorStorage; 13 15 use anyhow::Result; 14 16 use chrono::DateTime; 15 17 use chrono::offset::Utc as UtcOffset; 16 18 use cidv10::Cid; 17 19 use diesel::*; 18 20 use futures::try_join; 19 - use helpers::email_token::EmailTokenPurpose; 20 21 use helpers::{account, auth, email_token, invite, password, repo}; 21 22 use rsky_common::RFC3339_VARIANT; 22 23 use rsky_common::time::{HOUR, from_micros_to_str, from_str_to_micros}; ··· 31 32 use std::collections::BTreeMap; 32 33 use std::env; 33 34 use std::time::SystemTime; 35 + use tokio::sync::RwLock; 34 36 35 37 pub(crate) mod helpers { 36 38 pub mod account; ··· 66 68 >; 67 69 68 70 impl AccountManager { 69 - pub fn new( 71 + pub const fn new( 70 72 db: deadpool_diesel::Pool< 71 73 deadpool_diesel::Manager<SqliteConnection>, 72 74 deadpool_diesel::sqlite::Object, ··· 81 83 deadpool_diesel::Manager<SqliteConnection>, 82 84 deadpool_diesel::sqlite::Object, 83 85 >| 84 - -> AccountManager { AccountManager::new(db) }, 86 + -> Self { Self::new(db) }, 85 87 ) 86 88 } 87 89 ··· 129 131 } 130 132 } 131 133 132 - pub async fn create_account(&self, opts: CreateAccountOpts) -> Result<(String, String)> { 134 + pub async fn create_account( 135 + &self, 136 + opts: CreateAccountOpts, 137 + actor_pools: &mut std::collections::HashMap<String, ActorStorage>, 138 + ) -> Result<(String, String)> { 133 139 let CreateAccountOpts { 134 140 did, 135 141 handle, ··· 153 159 let (access_jwt, refresh_jwt) = auth::create_tokens(CreateTokensOpts { 154 160 did: did.clone(), 155 161 jwt_key, 156 - service_did: env::var("PDS_SERVICE_DID").unwrap(), 162 + service_did: env::var("PDS_SERVICE_DID").expect("PDS_SERVICE_DID not set"), 157 163 scope: Some(AuthScope::Access), 158 164 jti: None, 159 165 expires_in: None, ··· 170 176 } 171 177 invite::record_invite_use(did.clone(), invite_code, now, &self.db).await?; 172 178 auth::store_refresh_token(refresh_payload, None, &self.db).await?; 173 - repo::update_root(did, repo_cid, repo_rev, &self.db).await?; 179 + 180 + let did_path = did 181 + .strip_prefix("did:plc:") 182 + .ok_or_else(|| anyhow::anyhow!("Invalid DID"))?; 183 + let repo_path = format!("sqlite://data/repo/{}.db", did_path); 184 + let actor_repo_pool = 185 + crate::db::establish_pool(repo_path.as_str()).expect("Failed to establish pool"); 186 + let blob_path = std::path::Path::new("data/blob").to_path_buf(); 187 + let actor_pool = ActorStorage { 188 + repo: actor_repo_pool, 189 + blob: blob_path.clone(), 190 + }; 191 + let blob_path = blob_path.join(did_path); 192 + tokio::fs::create_dir_all(&blob_path) 193 + .await 194 + .map_err(|_| anyhow::anyhow!("Failed to create blob path"))?; 195 + drop( 196 + actor_pools 197 + .insert(did.clone(), actor_pool) 198 + .expect("Failed to insert actor pools"), 199 + ); 200 + let db = actor_pools 201 + .get(&did) 202 + .ok_or_else(|| anyhow::anyhow!("Actor not found"))? 203 + .repo 204 + .clone(); 205 + repo::update_root(did, repo_cid, repo_rev, &db).await?; 174 206 Ok((access_jwt, refresh_jwt)) 175 207 } 176 208 ··· 181 213 account::get_account_admin_status(did, &self.db).await 182 214 } 183 215 184 - pub async fn update_repo_root(&self, did: String, cid: Cid, rev: String) -> Result<()> { 185 - repo::update_root(did, cid, rev, &self.db).await 216 + pub async fn update_repo_root( 217 + &self, 218 + did: String, 219 + cid: Cid, 220 + rev: String, 221 + actor_pools: &std::collections::HashMap<String, ActorStorage>, 222 + ) -> Result<()> { 223 + let db = actor_pools 224 + .get(&did) 225 + .ok_or_else(|| anyhow::anyhow!("Actor not found"))? 226 + .repo 227 + .clone(); 228 + repo::update_root(did, cid, rev, &db).await 186 229 } 187 230 188 - pub async fn delete_account(&self, did: &str) -> Result<()> { 189 - account::delete_account(did, &self.db).await 231 + pub async fn delete_account( 232 + &self, 233 + did: &str, 234 + actor_pools: &std::collections::HashMap<String, ActorStorage>, 235 + ) -> Result<()> { 236 + let db = actor_pools 237 + .get(did) 238 + .ok_or_else(|| anyhow::anyhow!("Actor not found"))? 239 + .repo 240 + .clone(); 241 + account::delete_account(did, &self.db, &db).await 190 242 } 191 243 192 244 pub async fn takedown_account(&self, did: &str, takedown: StatusAttr) -> Result<()> { ··· 246 298 let (access_jwt, refresh_jwt) = auth::create_tokens(CreateTokensOpts { 247 299 did, 248 300 jwt_key, 249 - service_did: env::var("PDS_SERVICE_DID").unwrap(), 301 + service_did: env::var("PDS_SERVICE_DID").expect("PDS_SERVICE_DID not set"), 250 302 scope: Some(scope), 251 303 jti: None, 252 304 expires_in: None, ··· 289 341 let next_id = token.next_id.unwrap_or_else(auth::get_refresh_token_id); 290 342 291 343 let secp = Secp256k1::new(); 292 - let private_key = env::var("PDS_JWT_KEY_K256_PRIVATE_KEY_HEX").unwrap(); 344 + let private_key = env::var("PDS_JWT_KEY_K256_PRIVATE_KEY_HEX") 345 + .expect("PDS_JWT_KEY_K256_PRIVATE_KEY_HEX not set"); 293 346 let secret_key = 294 - SecretKey::from_slice(&hex::decode(private_key.as_bytes()).unwrap()).unwrap(); 347 + SecretKey::from_slice(&hex::decode(private_key.as_bytes()).expect("Invalid key"))?; 295 348 let jwt_key = Keypair::from_secret_key(&secp, &secret_key); 296 349 297 350 let (access_jwt, refresh_jwt) = auth::create_tokens(CreateTokensOpts { 298 351 did: token.did, 299 352 jwt_key, 300 - service_did: env::var("PDS_SERVICE_DID").unwrap(), 353 + service_did: env::var("PDS_SERVICE_DID").expect("PDS_SERVICE_DID not set"), 301 354 scope: Some(if token.app_password_name.is_none() { 302 355 AuthScope::Access 303 356 } else { ··· 499 552 email_token::create_email_token(did, purpose, &self.db).await 500 553 } 501 554 } 555 + 556 + pub struct SharedAccountManager { 557 + pub account_manager: RwLock<AccountManager>, 558 + }
+24 -8
src/actor_endpoints.rs
··· 1 + /// HACK: store private user preferences in the PDS. 2 + /// 3 + /// We shouldn't have to know about any bsky endpoints to store private user data. 4 + /// This will _very likely_ be changed in the future. 1 5 use atrium_api::app::bsky::actor; 2 - use axum::{Json, routing::post}; 6 + use axum::{ 7 + Json, Router, 8 + extract::State, 9 + routing::{get, post}, 10 + }; 3 11 use constcat::concat; 4 - use diesel::prelude::*; 5 12 6 - use crate::actor_store::ActorStore; 13 + use crate::auth::AuthenticatedUser; 7 14 8 - use super::*; 15 + use super::serve::*; 9 16 10 17 async fn put_preferences( 11 18 user: AuthenticatedUser, 12 - State(actor_pools): State<std::collections::HashMap<String, ActorPools>>, 19 + State(actor_pools): State<std::collections::HashMap<String, ActorStorage>>, 13 20 Json(input): Json<actor::put_preferences::Input>, 14 21 ) -> Result<()> { 15 22 let did = user.did(); 16 - let json_string = 17 - serde_json::to_string(&input.preferences).context("failed to serialize preferences")?; 23 + // let json_string = 24 + // serde_json::to_string(&input.preferences).context("failed to serialize preferences")?; 18 25 19 26 // let conn = &mut actor_pools 20 27 // .get(&did) ··· 31 38 // .context("failed to update user preferences") 32 39 // }); 33 40 todo!("Use actor_store's preferences writer instead"); 41 + // let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await; 42 + // let values = actor::defs::Preferences { 43 + // private_prefs: Some(json_string), 44 + // ..Default::default() 45 + // }; 46 + // let namespace = actor::defs::PreferencesNamespace::Private; 47 + // let scope = actor::defs::PreferencesScope::User; 48 + // actor_store.pref.put_preferences(values, namespace, scope); 49 + 34 50 Ok(()) 35 51 } 36 52 37 53 async fn get_preferences( 38 54 user: AuthenticatedUser, 39 - State(actor_pools): State<std::collections::HashMap<String, ActorPools>>, 55 + State(actor_pools): State<std::collections::HashMap<String, ActorStorage>>, 40 56 ) -> Result<Json<actor::get_preferences::Output>> { 41 57 let did = user.did(); 42 58 // let conn = &mut actor_pools
+117 -81
src/actor_store/blob.rs
··· 4 4 //! 5 5 //! Modified for SQLite backend 6 6 7 + use crate::models::actor_store as models; 7 8 use anyhow::{Result, bail}; 9 + use axum::body::Bytes; 8 10 use cidv10::Cid; 9 11 use diesel::dsl::{count_distinct, exists, not}; 10 12 use diesel::sql_types::{Integer, Nullable, Text}; ··· 19 21 use rsky_lexicon::com::atproto::admin::StatusAttr; 20 22 use rsky_lexicon::com::atproto::repo::ListMissingBlobsRefRecordBlob; 21 23 use rsky_pds::actor_store::blob::{ 22 - BlobMetadata, GetBlobMetadataOutput, ListBlobsOpts, ListMissingBlobsOpts, sha256_stream, 23 - verify_blob, 24 + BlobMetadata, GetBlobMetadataOutput, ListBlobsOpts, ListMissingBlobsOpts, accepted_mime, 25 + sha256_stream, 24 26 }; 25 27 use rsky_pds::image; 26 - use rsky_pds::models::models; 27 28 use rsky_repo::error::BlobError; 28 29 use rsky_repo::types::{PreparedBlobRef, PreparedWrite}; 29 30 use std::str::FromStr as _; 30 31 31 - use super::sql_blob::{BlobStoreSql, ByteStream}; 32 + use super::blob_fs::{BlobStoreFs, ByteStream}; 32 33 33 34 pub struct GetBlobOutput { 34 35 pub size: i32, ··· 39 40 /// Handles blob operations for an actor store 40 41 pub struct BlobReader { 41 42 /// SQL-based blob storage 42 - pub blobstore: BlobStoreSql, 43 + pub blobstore: BlobStoreFs, 43 44 /// DID of the actor 44 45 pub did: String, 45 46 /// Database connection ··· 52 53 impl BlobReader { 53 54 /// Create a new blob reader 54 55 pub fn new( 55 - blobstore: BlobStoreSql, 56 + blobstore: BlobStoreFs, 56 57 db: deadpool_diesel::Pool< 57 58 deadpool_diesel::Manager<SqliteConnection>, 58 59 deadpool_diesel::sqlite::Object, 59 60 >, 60 61 ) -> Self { 61 - BlobReader { 62 + Self { 62 63 did: blobstore.did.clone(), 63 64 blobstore, 64 65 db, ··· 67 68 68 69 /// Get metadata for a blob by CID 69 70 pub async fn get_blob_metadata(&self, cid: Cid) -> Result<GetBlobMetadataOutput> { 70 - use rsky_pds::schema::pds::blob::dsl as BlobSchema; 71 + use crate::schema::actor_store::blob::dsl as BlobSchema; 71 72 72 73 let did = self.did.clone(); 73 74 let found = self ··· 112 113 113 114 /// Get all records that reference a specific blob 114 115 pub async fn get_records_for_blob(&self, cid: Cid) -> Result<Vec<String>> { 115 - use rsky_pds::schema::pds::record_blob::dsl as RecordBlobSchema; 116 + use crate::schema::actor_store::record_blob::dsl as RecordBlobSchema; 116 117 117 118 let did = self.did.clone(); 118 119 let res = self ··· 138 139 pub async fn upload_blob_and_get_metadata( 139 140 &self, 140 141 user_suggested_mime: String, 141 - blob: Vec<u8>, 142 + blob: Bytes, 142 143 ) -> Result<BlobMetadata> { 143 144 let bytes = blob; 144 145 let size = bytes.len() as i64; 145 146 146 147 let (temp_key, sha256, img_info, sniffed_mime) = try_join!( 147 148 self.blobstore.put_temp(bytes.clone()), 148 - sha256_stream(bytes.clone()), 149 - image::maybe_get_info(bytes.clone()), 150 - image::mime_type_from_bytes(bytes.clone()) 149 + // TODO: reimpl funcs to use Bytes instead of Vec<u8> 150 + sha256_stream(bytes.to_vec()), 151 + image::maybe_get_info(bytes.to_vec()), 152 + image::mime_type_from_bytes(bytes.to_vec()) 151 153 )?; 152 154 153 155 let cid = sha256_raw_to_cid(sha256); ··· 158 160 size, 159 161 cid, 160 162 mime_type, 161 - width: if let Some(ref info) = img_info { 162 - Some(info.width as i32) 163 - } else { 164 - None 165 - }, 163 + width: img_info.as_ref().map(|info| info.width as i32), 166 164 height: if let Some(info) = img_info { 167 165 Some(info.height as i32) 168 166 } else { ··· 173 171 174 172 /// Track a blob that hasn't been associated with any records yet 175 173 pub async fn track_untethered_blob(&self, metadata: BlobMetadata) -> Result<BlobRef> { 176 - use rsky_pds::schema::pds::blob::dsl as BlobSchema; 174 + use crate::schema::actor_store::blob::dsl as BlobSchema; 177 175 178 176 let did = self.did.clone(); 179 177 self.db.get().await?.interact(move |conn| { ··· 207 205 SET \"tempKey\" = EXCLUDED.\"tempKey\" \ 208 206 WHERE pds.blob.\"tempKey\" is not null;"); 209 207 #[expect(trivial_casts)] 210 - upsert 208 + let _ = upsert 211 209 .bind::<Text, _>(&cid.to_string()) 212 210 .bind::<Text, _>(&did) 213 211 .bind::<Text, _>(&mime_type) 214 212 .bind::<Integer, _>(size as i32) 215 - .bind::<Nullable<Text>, _>(Some(temp_key.clone())) 213 + .bind::<Nullable<Text>, _>(Some(temp_key)) 216 214 .bind::<Nullable<Integer>, _>(width) 217 215 .bind::<Nullable<Integer>, _>(height) 218 216 .bind::<Text, _>(created_at) ··· 227 225 pub async fn process_write_blobs(&self, writes: Vec<PreparedWrite>) -> Result<()> { 228 226 self.delete_dereferenced_blobs(writes.clone()).await?; 229 227 230 - let _ = stream::iter(writes) 231 - .then(|write| async move { 232 - Ok::<(), anyhow::Error>(match write { 233 - PreparedWrite::Create(w) => { 234 - for blob in w.blobs { 235 - self.verify_blob_and_make_permanent(blob.clone()).await?; 236 - self.associate_blob(blob, w.uri.clone()).await?; 228 + drop( 229 + stream::iter(writes) 230 + .then(async move |write| { 231 + match write { 232 + PreparedWrite::Create(w) => { 233 + for blob in w.blobs { 234 + self.verify_blob_and_make_permanent(blob.clone()).await?; 235 + self.associate_blob(blob, w.uri.clone()).await?; 236 + } 237 237 } 238 - } 239 - PreparedWrite::Update(w) => { 240 - for blob in w.blobs { 241 - self.verify_blob_and_make_permanent(blob.clone()).await?; 242 - self.associate_blob(blob, w.uri.clone()).await?; 238 + PreparedWrite::Update(w) => { 239 + for blob in w.blobs { 240 + self.verify_blob_and_make_permanent(blob.clone()).await?; 241 + self.associate_blob(blob, w.uri.clone()).await?; 242 + } 243 243 } 244 - } 245 - _ => (), 244 + _ => (), 245 + }; 246 + Ok::<(), anyhow::Error>(()) 246 247 }) 247 - }) 248 - .collect::<Vec<_>>() 249 - .await 250 - .into_iter() 251 - .collect::<Result<Vec<_>, _>>()?; 248 + .collect::<Vec<_>>() 249 + .await 250 + .into_iter() 251 + .collect::<Result<Vec<_>, _>>()?, 252 + ); 252 253 253 254 Ok(()) 254 255 } 255 256 256 257 /// Delete blobs that are no longer referenced by any records 257 258 pub async fn delete_dereferenced_blobs(&self, writes: Vec<PreparedWrite>) -> Result<()> { 258 - use rsky_pds::schema::pds::blob::dsl as BlobSchema; 259 - use rsky_pds::schema::pds::record_blob::dsl as RecordBlobSchema; 259 + use crate::schema::actor_store::blob::dsl as BlobSchema; 260 + use crate::schema::actor_store::record_blob::dsl as RecordBlobSchema; 260 261 261 262 // Extract URIs 262 263 let uris: Vec<String> = writes ··· 295 296 296 297 // Now perform the delete 297 298 let uris_clone = uris.clone(); 298 - self.db 299 + _ = self 300 + .db 299 301 .get() 300 302 .await? 301 303 .interact(move |conn| { ··· 354 356 // Delete from the blob table 355 357 let cids = cids_to_delete.clone(); 356 358 let did_clone = self.did.clone(); 357 - self.db 359 + _ = self 360 + .db 358 361 .get() 359 362 .await? 360 363 .interact(move |conn| { ··· 368 371 369 372 // Delete from blob storage 370 373 // Ideally we'd use a background queue here, but for now: 371 - let _ = stream::iter(cids_to_delete) 372 - .then(|cid| async move { 373 - match Cid::from_str(&cid) { 374 + drop( 375 + stream::iter(cids_to_delete) 376 + .then(async move |cid| match Cid::from_str(&cid) { 374 377 Ok(cid) => self.blobstore.delete(cid.to_string()).await, 375 378 Err(e) => Err(anyhow::Error::new(e)), 376 - } 377 - }) 378 - .collect::<Vec<_>>() 379 - .await 380 - .into_iter() 381 - .collect::<Result<Vec<_>, _>>()?; 379 + }) 380 + .collect::<Vec<_>>() 381 + .await 382 + .into_iter() 383 + .collect::<Result<Vec<_>, _>>()?, 384 + ); 382 385 383 386 Ok(()) 384 387 } 385 388 386 389 /// Verify a blob and make it permanent 387 390 pub async fn verify_blob_and_make_permanent(&self, blob: PreparedBlobRef) -> Result<()> { 388 - use rsky_pds::schema::pds::blob::dsl as BlobSchema; 391 + use crate::schema::actor_store::blob::dsl as BlobSchema; 389 392 390 393 let found = self 391 394 .db ··· 412 415 .make_permanent(temp_key.clone(), blob.cid) 413 416 .await?; 414 417 } 415 - self.db 418 + _ = self 419 + .db 416 420 .get() 417 421 .await? 418 422 .interact(move |conn| { ··· 431 435 432 436 /// Associate a blob with a record 433 437 pub async fn associate_blob(&self, blob: PreparedBlobRef, record_uri: String) -> Result<()> { 434 - use rsky_pds::schema::pds::record_blob::dsl as RecordBlobSchema; 438 + use crate::schema::actor_store::record_blob::dsl as RecordBlobSchema; 435 439 436 440 let cid = blob.cid.to_string(); 437 441 let did = self.did.clone(); 438 442 439 - self.db 443 + _ = self 444 + .db 440 445 .get() 441 446 .await? 442 447 .interact(move |conn| { ··· 457 462 458 463 /// Count all blobs for this actor 459 464 pub async fn blob_count(&self) -> Result<i64> { 460 - use rsky_pds::schema::pds::blob::dsl as BlobSchema; 465 + use crate::schema::actor_store::blob::dsl as BlobSchema; 461 466 462 467 let did = self.did.clone(); 463 468 self.db ··· 476 481 477 482 /// Count blobs associated with records 478 483 pub async fn record_blob_count(&self) -> Result<i64> { 479 - use rsky_pds::schema::pds::record_blob::dsl as RecordBlobSchema; 484 + use crate::schema::actor_store::record_blob::dsl as RecordBlobSchema; 480 485 481 486 let did = self.did.clone(); 482 487 self.db ··· 498 503 &self, 499 504 opts: ListMissingBlobsOpts, 500 505 ) -> Result<Vec<ListMissingBlobsRefRecordBlob>> { 501 - use rsky_pds::schema::pds::blob::dsl as BlobSchema; 502 - use rsky_pds::schema::pds::record_blob::dsl as RecordBlobSchema; 506 + use crate::schema::actor_store::blob::dsl as BlobSchema; 507 + use crate::schema::actor_store::record_blob::dsl as RecordBlobSchema; 503 508 504 509 let did = self.did.clone(); 505 510 self.db ··· 560 565 561 566 /// List all blobs with optional filtering 562 567 pub async fn list_blobs(&self, opts: ListBlobsOpts) -> Result<Vec<String>> { 563 - use rsky_pds::schema::pds::record::dsl as RecordSchema; 564 - use rsky_pds::schema::pds::record_blob::dsl as RecordBlobSchema; 568 + use crate::schema::actor_store::record::dsl as RecordSchema; 569 + use crate::schema::actor_store::record_blob::dsl as RecordBlobSchema; 565 570 566 571 let ListBlobsOpts { 567 572 since, ··· 614 619 615 620 /// Get the takedown status of a blob 616 621 pub async fn get_blob_takedown_status(&self, cid: Cid) -> Result<Option<StatusAttr>> { 617 - use rsky_pds::schema::pds::blob::dsl as BlobSchema; 622 + use crate::schema::actor_store::blob::dsl as BlobSchema; 618 623 619 624 self.db 620 625 .get() ··· 628 633 629 634 match res { 630 635 None => Ok(None), 631 - Some(res) => match res.takedown_ref { 632 - None => Ok(Some(StatusAttr { 633 - applied: false, 634 - r#ref: None, 635 - })), 636 - Some(takedown_ref) => Ok(Some(StatusAttr { 637 - applied: true, 638 - r#ref: Some(takedown_ref), 639 - })), 640 - }, 636 + Some(res) => res.takedown_ref.map_or_else( 637 + || { 638 + Ok(Some(StatusAttr { 639 + applied: false, 640 + r#ref: None, 641 + })) 642 + }, 643 + |takedown_ref| { 644 + Ok(Some(StatusAttr { 645 + applied: true, 646 + r#ref: Some(takedown_ref), 647 + })) 648 + }, 649 + ), 641 650 } 642 651 }) 643 652 .await ··· 646 655 647 656 /// Update the takedown status of a blob 648 657 pub async fn update_blob_takedown_status(&self, blob: Cid, takedown: StatusAttr) -> Result<()> { 649 - use rsky_pds::schema::pds::blob::dsl as BlobSchema; 658 + use crate::schema::actor_store::blob::dsl as BlobSchema; 650 659 651 660 let takedown_ref: Option<String> = match takedown.applied { 652 - true => match takedown.r#ref { 653 - Some(takedown_ref) => Some(takedown_ref), 654 - None => Some(now()), 655 - }, 661 + true => takedown.r#ref.map_or_else(|| Some(now()), Some), 656 662 false => None, 657 663 }; 658 664 659 665 let blob_cid = blob.to_string(); 660 666 let did_clone = self.did.clone(); 661 667 662 - self.db 668 + _ = self 669 + .db 663 670 .get() 664 671 .await? 665 672 .interact(move |conn| { 666 - update(BlobSchema::blob) 673 + _ = update(BlobSchema::blob) 667 674 .filter(BlobSchema::cid.eq(blob_cid)) 668 675 .filter(BlobSchema::did.eq(did_clone)) 669 676 .set(BlobSchema::takedownRef.eq(takedown_ref)) ··· 687 694 } 688 695 } 689 696 } 697 + 698 + pub async fn verify_blob(blob: &PreparedBlobRef, found: &models::Blob) -> Result<()> { 699 + if let Some(max_size) = blob.constraints.max_size { 700 + if found.size as usize > max_size { 701 + bail!( 702 + "BlobTooLarge: This file is too large. It is {:?} but the maximum size is {:?}", 703 + found.size, 704 + max_size 705 + ) 706 + } 707 + } 708 + if blob.mime_type != found.mime_type { 709 + bail!( 710 + "InvalidMimeType: Referenced MimeType does not match stored blob. Expected: {:?}, Got: {:?}", 711 + found.mime_type, 712 + blob.mime_type 713 + ) 714 + } 715 + if let Some(ref accept) = blob.constraints.accept { 716 + if !accepted_mime(blob.mime_type.clone(), accept.clone()).await { 717 + bail!( 718 + "Wrong type of file. It is {:?} but it must match {:?}.", 719 + blob.mime_type, 720 + accept 721 + ) 722 + } 723 + } 724 + Ok(()) 725 + }
+287
src/actor_store/blob_fs.rs
··· 1 + //! File system implementation of blob storage 2 + //! Based on the S3 implementation but using local file system instead 3 + use anyhow::Result; 4 + use axum::body::Bytes; 5 + use cidv10::Cid; 6 + use rsky_common::get_random_str; 7 + use rsky_repo::error::BlobError; 8 + use std::path::PathBuf; 9 + use std::str::FromStr; 10 + use tokio::fs as async_fs; 11 + use tokio::io::AsyncWriteExt; 12 + use tracing::{debug, error, warn}; 13 + 14 + /// ByteStream implementation for blob data 15 + pub struct ByteStream { 16 + pub bytes: Bytes, 17 + } 18 + 19 + impl ByteStream { 20 + /// Create a new ByteStream with the given bytes 21 + pub const fn new(bytes: Bytes) -> Self { 22 + Self { bytes } 23 + } 24 + 25 + /// Collect the bytes from the stream 26 + pub async fn collect(self) -> Result<Bytes> { 27 + Ok(self.bytes) 28 + } 29 + } 30 + 31 + /// Path information for moving a blob 32 + struct MoveObject { 33 + from: PathBuf, 34 + to: PathBuf, 35 + } 36 + 37 + /// File system implementation of blob storage 38 + pub struct BlobStoreFs { 39 + /// Base directory for storing blobs 40 + pub base_dir: PathBuf, 41 + /// DID of the actor 42 + pub did: String, 43 + } 44 + 45 + impl BlobStoreFs { 46 + /// Create a new file system blob store for the given DID and base directory 47 + pub const fn new(did: String, base_dir: PathBuf) -> Self { 48 + Self { base_dir, did } 49 + } 50 + 51 + /// Create a factory function for blob stores 52 + pub fn creator(base_dir: PathBuf) -> Box<dyn Fn(String) -> Self> { 53 + let base_dir_clone = base_dir; 54 + Box::new(move |did: String| Self::new(did, base_dir_clone.clone())) 55 + } 56 + 57 + /// Generate a random key for temporary storage 58 + fn gen_key(&self) -> String { 59 + get_random_str() 60 + } 61 + 62 + /// Get path to the temporary blob storage 63 + fn get_tmp_path(&self, key: &str) -> PathBuf { 64 + self.base_dir.join("tmp").join(&self.did).join(key) 65 + } 66 + 67 + /// Get path to the stored blob with appropriate sharding 68 + fn get_stored_path(&self, cid: Cid) -> PathBuf { 69 + let cid_str = cid.to_string(); 70 + 71 + // Create two-level sharded structure based on CID 72 + // First 10 chars for level 1, next 10 chars for level 2 73 + let first_level = if cid_str.len() >= 10 { 74 + &cid_str[0..10] 75 + } else { 76 + "short" 77 + }; 78 + 79 + let second_level = if cid_str.len() >= 20 { 80 + &cid_str[10..20] 81 + } else { 82 + "short" 83 + }; 84 + 85 + self.base_dir 86 + .join("blocks") 87 + .join(&self.did) 88 + .join(first_level) 89 + .join(second_level) 90 + .join(&cid_str) 91 + } 92 + 93 + /// Get path to the quarantined blob 94 + fn get_quarantined_path(&self, cid: Cid) -> PathBuf { 95 + let cid_str = cid.to_string(); 96 + self.base_dir 97 + .join("quarantine") 98 + .join(&self.did) 99 + .join(&cid_str) 100 + } 101 + 102 + /// Store a blob temporarily 103 + pub async fn put_temp(&self, bytes: Bytes) -> Result<String> { 104 + let key = self.gen_key(); 105 + let temp_path = self.get_tmp_path(&key); 106 + 107 + // Ensure the directory exists 108 + if let Some(parent) = temp_path.parent() { 109 + async_fs::create_dir_all(parent).await?; 110 + } 111 + 112 + // Write the temporary blob 113 + let mut file = async_fs::File::create(&temp_path).await?; 114 + file.write_all(&bytes).await?; 115 + file.flush().await?; 116 + 117 + debug!("Stored temp blob at: {:?}", temp_path); 118 + Ok(key) 119 + } 120 + 121 + /// Make a temporary blob permanent by moving it to the blob store 122 + pub async fn make_permanent(&self, key: String, cid: Cid) -> Result<()> { 123 + let already_has = self.has_stored(cid).await?; 124 + 125 + if !already_has { 126 + // Move the temporary blob to permanent storage 127 + self.move_object(MoveObject { 128 + from: self.get_tmp_path(&key), 129 + to: self.get_stored_path(cid), 130 + }) 131 + .await?; 132 + debug!("Moved temp blob to permanent: {} -> {}", key, cid); 133 + } else { 134 + // Already saved, so just delete the temp 135 + let temp_path = self.get_tmp_path(&key); 136 + if temp_path.exists() { 137 + async_fs::remove_file(temp_path).await?; 138 + debug!("Deleted temp blob as permanent already exists: {}", key); 139 + } 140 + } 141 + 142 + Ok(()) 143 + } 144 + 145 + /// Store a blob directly as permanent 146 + pub async fn put_permanent(&self, cid: Cid, bytes: Bytes) -> Result<()> { 147 + let target_path = self.get_stored_path(cid); 148 + 149 + // Ensure the directory exists 150 + if let Some(parent) = target_path.parent() { 151 + async_fs::create_dir_all(parent).await?; 152 + } 153 + 154 + // Write the blob 155 + let mut file = async_fs::File::create(&target_path).await?; 156 + file.write_all(&bytes).await?; 157 + file.flush().await?; 158 + 159 + debug!("Stored permanent blob: {}", cid); 160 + Ok(()) 161 + } 162 + 163 + /// Quarantine a blob by moving it to the quarantine area 164 + pub async fn quarantine(&self, cid: Cid) -> Result<()> { 165 + self.move_object(MoveObject { 166 + from: self.get_stored_path(cid), 167 + to: self.get_quarantined_path(cid), 168 + }) 169 + .await?; 170 + 171 + debug!("Quarantined blob: {}", cid); 172 + Ok(()) 173 + } 174 + 175 + /// Unquarantine a blob by moving it back to regular storage 176 + pub async fn unquarantine(&self, cid: Cid) -> Result<()> { 177 + self.move_object(MoveObject { 178 + from: self.get_quarantined_path(cid), 179 + to: self.get_stored_path(cid), 180 + }) 181 + .await?; 182 + 183 + debug!("Unquarantined blob: {}", cid); 184 + Ok(()) 185 + } 186 + 187 + /// Get a blob as a stream 188 + async fn get_object(&self, cid: Cid) -> Result<ByteStream> { 189 + let blob_path = self.get_stored_path(cid); 190 + 191 + match async_fs::read(&blob_path).await { 192 + Ok(bytes) => Ok(ByteStream::new(Bytes::from(bytes))), 193 + Err(e) => { 194 + error!("Failed to read blob at path {:?}: {}", blob_path, e); 195 + Err(anyhow::Error::new(BlobError::BlobNotFoundError)) 196 + } 197 + } 198 + } 199 + 200 + /// Get blob bytes 201 + pub async fn get_bytes(&self, cid: Cid) -> Result<Bytes> { 202 + let stream = self.get_object(cid).await?; 203 + stream.collect().await 204 + } 205 + 206 + /// Get a blob as a stream 207 + pub async fn get_stream(&self, cid: Cid) -> Result<ByteStream> { 208 + self.get_object(cid).await 209 + } 210 + 211 + /// Delete a blob by CID string 212 + pub async fn delete(&self, cid_str: String) -> Result<()> { 213 + match Cid::from_str(&cid_str) { 214 + Ok(cid) => self.delete_path(self.get_stored_path(cid)).await, 215 + Err(e) => { 216 + warn!("Invalid CID: {} - {}", cid_str, e); 217 + Err(anyhow::anyhow!("Invalid CID: {}", e)) 218 + } 219 + } 220 + } 221 + 222 + /// Delete multiple blobs by CID 223 + pub async fn delete_many(&self, cids: Vec<Cid>) -> Result<()> { 224 + let mut futures = Vec::with_capacity(cids.len()); 225 + 226 + for cid in cids { 227 + futures.push(self.delete_path(self.get_stored_path(cid))); 228 + } 229 + 230 + // Execute all delete operations concurrently 231 + let results = futures::future::join_all(futures).await; 232 + 233 + // Count errors but don't fail the operation 234 + let error_count = results.iter().filter(|r| r.is_err()).count(); 235 + if error_count > 0 { 236 + warn!( 237 + "{} errors occurred while deleting {} blobs", 238 + error_count, 239 + results.len() 240 + ); 241 + } 242 + 243 + Ok(()) 244 + } 245 + 246 + /// Check if a blob is stored in the regular storage 247 + pub async fn has_stored(&self, cid: Cid) -> Result<bool> { 248 + let blob_path = self.get_stored_path(cid); 249 + Ok(blob_path.exists()) 250 + } 251 + 252 + /// Check if a temporary blob exists 253 + pub async fn has_temp(&self, key: String) -> Result<bool> { 254 + let temp_path = self.get_tmp_path(&key); 255 + Ok(temp_path.exists()) 256 + } 257 + 258 + /// Helper function to delete a file at the given path 259 + async fn delete_path(&self, path: PathBuf) -> Result<()> { 260 + if path.exists() { 261 + async_fs::remove_file(&path).await?; 262 + debug!("Deleted file at: {:?}", path); 263 + Ok(()) 264 + } else { 265 + Err(anyhow::Error::new(BlobError::BlobNotFoundError)) 266 + } 267 + } 268 + 269 + /// Move a blob from one path to another 270 + async fn move_object(&self, mov: MoveObject) -> Result<()> { 271 + // Ensure the source exists 272 + if !mov.from.exists() { 273 + return Err(anyhow::Error::new(BlobError::BlobNotFoundError)); 274 + } 275 + 276 + // Ensure the target directory exists 277 + if let Some(parent) = mov.to.parent() { 278 + async_fs::create_dir_all(parent).await?; 279 + } 280 + 281 + // Move the file 282 + async_fs::rename(&mov.from, &mov.to).await?; 283 + 284 + debug!("Moved blob: {:?} -> {:?}", mov.from, mov.to); 285 + Ok(()) 286 + } 287 + }
+124 -73
src/actor_store/mod.rs
··· 7 7 //! Modified for SQLite backend 8 8 9 9 mod blob; 10 + pub(crate) mod blob_fs; 10 11 mod preference; 11 12 mod record; 12 13 pub(crate) mod sql_blob; ··· 33 34 use tokio::sync::RwLock; 34 35 35 36 use blob::BlobReader; 37 + use blob_fs::BlobStoreFs; 36 38 use preference::PreferenceReader; 37 39 use record::RecordReader; 38 - use sql_blob::BlobStoreSql; 39 40 use sql_repo::SqlRepoReader; 41 + 42 + use crate::serve::ActorStorage; 40 43 41 44 #[derive(Debug)] 42 45 enum FormatCommitError { ··· 71 74 72 75 // Combination of RepoReader/Transactor, BlobReader/Transactor, SqlRepoReader/Transactor 73 76 impl ActorStore { 74 - /// Concrete reader of an individual repo (hence BlobStoreSql which takes `did` param) 77 + /// Concrete reader of an individual repo (hence BlobStoreFs which takes `did` param) 75 78 pub fn new( 76 79 did: String, 77 - blobstore: BlobStoreSql, 80 + blobstore: BlobStoreFs, 78 81 db: deadpool_diesel::Pool< 79 82 deadpool_diesel::Manager<SqliteConnection>, 80 83 deadpool_diesel::sqlite::Object, 81 84 >, 82 85 conn: deadpool_diesel::sqlite::Object, 83 86 ) -> Self { 84 - ActorStore { 87 + Self { 85 88 storage: Arc::new(RwLock::new(SqlRepoReader::new(did.clone(), None, conn))), 86 89 record: RecordReader::new(did.clone(), db.clone()), 87 90 pref: PreferenceReader::new(did.clone(), db.clone()), 88 91 did, 89 - blob: BlobReader::new(blobstore, db.clone()), 92 + blob: BlobReader::new(blobstore, db), 90 93 } 91 94 } 92 95 96 + /// Create a new ActorStore taking ActorPools HashMap as input 97 + pub async fn from_actor_pools( 98 + did: &String, 99 + hashmap_actor_pools: &std::collections::HashMap<String, ActorStorage>, 100 + ) -> Self { 101 + let actor_pool = hashmap_actor_pools 102 + .get(did) 103 + .expect("Actor pool not found") 104 + .clone(); 105 + let blobstore = BlobStoreFs::new(did.clone(), actor_pool.blob); 106 + let conn = actor_pool 107 + .repo 108 + .clone() 109 + .get() 110 + .await 111 + .expect("Failed to get connection"); 112 + Self::new(did.clone(), blobstore, actor_pool.repo, conn) 113 + } 114 + 93 115 pub async fn get_repo_root(&self) -> Option<Cid> { 94 116 let storage_guard = self.storage.read().await; 95 117 storage_guard.get_root().await ··· 124 146 Some(write_ops), 125 147 ) 126 148 .await?; 127 - let storage_guard = self.storage.read().await; 128 - storage_guard.apply_commit(commit.clone(), None).await?; 149 + self.storage 150 + .read() 151 + .await 152 + .apply_commit(commit.clone(), None) 153 + .await?; 129 154 let writes = writes 130 155 .into_iter() 131 156 .map(PreparedWrite::Create) ··· 159 184 Some(write_ops), 160 185 ) 161 186 .await?; 162 - let storage_guard = self.storage.read().await; 163 - storage_guard.apply_commit(commit.clone(), None).await?; 187 + self.storage 188 + .read() 189 + .await 190 + .apply_commit(commit.clone(), None) 191 + .await?; 164 192 let write_commit_ops = writes.iter().try_fold( 165 193 Vec::with_capacity(writes.len()), 166 194 |mut acc, w| -> Result<Vec<CommitOp>> { ··· 168 196 acc.push(CommitOp { 169 197 action: CommitAction::Create, 170 198 path: format_data_key(aturi.get_collection(), aturi.get_rkey()), 171 - cid: Some(w.cid.clone()), 199 + cid: Some(w.cid), 172 200 prev: None, 173 201 }); 174 202 Ok(acc) ··· 199 227 .await?; 200 228 } 201 229 // persist the commit to repo storage 202 - let storage_guard = self.storage.read().await; 203 - storage_guard.apply_commit(commit.clone(), None).await?; 230 + self.storage 231 + .read() 232 + .await 233 + .apply_commit(commit.clone(), None) 234 + .await?; 204 235 // process blobs 205 236 self.blob.process_write_blobs(writes).await?; 206 237 Ok(()) ··· 226 257 .await?; 227 258 } 228 259 // persist the commit to repo storage 229 - let storage_guard = self.storage.read().await; 230 - storage_guard 260 + self.storage 261 + .read() 262 + .await 231 263 .apply_commit(commit.commit_data.clone(), None) 232 264 .await?; 233 265 // process blobs ··· 236 268 } 237 269 238 270 pub async fn get_sync_event_data(&mut self) -> Result<SyncEvtData> { 239 - let storage_guard = self.storage.read().await; 240 - let current_root = storage_guard.get_root_detailed().await?; 241 - let blocks_and_missing = storage_guard.get_blocks(vec![current_root.cid]).await?; 271 + let current_root = self.storage.read().await.get_root_detailed().await?; 272 + let blocks_and_missing = self 273 + .storage 274 + .read() 275 + .await 276 + .get_blocks(vec![current_root.cid]) 277 + .await?; 242 278 Ok(SyncEvtData { 243 279 cid: current_root.cid, 244 280 rev: current_root.rev, ··· 264 300 } 265 301 } 266 302 { 267 - let mut storage_guard = self.storage.write().await; 268 - storage_guard.cache_rev(current_root.rev).await?; 303 + self.storage 304 + .write() 305 + .await 306 + .cache_rev(current_root.rev) 307 + .await?; 269 308 } 270 309 let mut new_record_cids: Vec<Cid> = vec![]; 271 310 let mut delete_and_update_uris = vec![]; ··· 306 345 cid, 307 346 prev: None, 308 347 }; 309 - if let Some(_) = current_record { 348 + if current_record.is_some() { 310 349 op.prev = current_record; 311 350 }; 312 351 commit_ops.push(op); ··· 352 391 .collect::<Result<Vec<RecordWriteOp>>>()?; 353 392 // @TODO: Use repo signing key global config 354 393 let secp = Secp256k1::new(); 355 - let repo_private_key = env::var("PDS_REPO_SIGNING_KEY_K256_PRIVATE_KEY_HEX").unwrap(); 356 - let repo_secret_key = 357 - SecretKey::from_slice(&hex::decode(repo_private_key.as_bytes()).unwrap()).unwrap(); 394 + let repo_private_key = env::var("PDS_REPO_SIGNING_KEY_K256_PRIVATE_KEY_HEX") 395 + .expect("PDS_REPO_SIGNING_KEY_K256_PRIVATE_KEY_HEX not set"); 396 + let repo_secret_key = SecretKey::from_slice( 397 + &hex::decode(repo_private_key.as_bytes()).expect("Failed to decode hex"), 398 + ) 399 + .expect("Failed to create secret key from hex"); 358 400 let repo_signing_key = Keypair::from_secret_key(&secp, &repo_secret_key); 359 401 360 402 let mut commit = repo ··· 393 435 pub async fn index_writes(&self, writes: Vec<PreparedWrite>, rev: &str) -> Result<()> { 394 436 let now: &str = &rsky_common::now(); 395 437 396 - let _ = stream::iter(writes) 397 - .then(|write| async move { 398 - Ok::<(), anyhow::Error>(match write { 399 - PreparedWrite::Create(write) => { 400 - let write_at_uri: AtUri = write.uri.try_into()?; 401 - self.record 402 - .index_record( 403 - write_at_uri.clone(), 404 - write.cid, 405 - Some(write.record), 406 - Some(write.action), 407 - rev.to_owned(), 408 - Some(now.to_string()), 409 - ) 410 - .await? 411 - } 412 - PreparedWrite::Update(write) => { 413 - let write_at_uri: AtUri = write.uri.try_into()?; 414 - self.record 415 - .index_record( 416 - write_at_uri.clone(), 417 - write.cid, 418 - Some(write.record), 419 - Some(write.action), 420 - rev.to_owned(), 421 - Some(now.to_string()), 422 - ) 423 - .await? 424 - } 425 - PreparedWrite::Delete(write) => { 426 - let write_at_uri: AtUri = write.uri.try_into()?; 427 - self.record.delete_record(&write_at_uri).await? 438 + drop( 439 + stream::iter(writes) 440 + .then(async move |write| { 441 + match write { 442 + PreparedWrite::Create(write) => { 443 + let write_at_uri: AtUri = write.uri.try_into()?; 444 + self.record 445 + .index_record( 446 + write_at_uri.clone(), 447 + write.cid, 448 + Some(write.record), 449 + Some(write.action), 450 + rev.to_owned(), 451 + Some(now.to_owned()), 452 + ) 453 + .await?; 454 + } 455 + PreparedWrite::Update(write) => { 456 + let write_at_uri: AtUri = write.uri.try_into()?; 457 + self.record 458 + .index_record( 459 + write_at_uri.clone(), 460 + write.cid, 461 + Some(write.record), 462 + Some(write.action), 463 + rev.to_owned(), 464 + Some(now.to_owned()), 465 + ) 466 + .await?; 467 + } 468 + PreparedWrite::Delete(write) => { 469 + let write_at_uri: AtUri = write.uri.try_into()?; 470 + self.record.delete_record(&write_at_uri).await?; 471 + } 428 472 } 473 + Ok::<(), anyhow::Error>(()) 429 474 }) 430 - }) 431 - .collect::<Vec<_>>() 432 - .await 433 - .into_iter() 434 - .collect::<Result<Vec<_>, _>>()?; 475 + .collect::<Vec<_>>() 476 + .await 477 + .into_iter() 478 + .collect::<Result<Vec<_>, _>>()?, 479 + ); 435 480 Ok(()) 436 481 } 437 482 438 483 pub async fn destroy(&mut self) -> Result<()> { 439 484 let did: String = self.did.clone(); 440 - let storage_guard = self.storage.read().await; 441 - use rsky_pds::schema::pds::blob::dsl as BlobSchema; 485 + use crate::schema::actor_store::blob::dsl as BlobSchema; 442 486 443 - let blob_rows: Vec<String> = storage_guard 487 + let blob_rows: Vec<String> = self 488 + .storage 489 + .read() 490 + .await 444 491 .db 445 492 .interact(move |conn| { 446 493 BlobSchema::blob ··· 454 501 .into_iter() 455 502 .map(|row| Ok(Cid::from_str(&row)?)) 456 503 .collect::<Result<Vec<Cid>>>()?; 457 - let _ = stream::iter(cids.chunks(500)) 458 - .then(|chunk| async { self.blob.blobstore.delete_many(chunk.to_vec()).await }) 459 - .collect::<Vec<_>>() 460 - .await 461 - .into_iter() 462 - .collect::<Result<Vec<_>, _>>()?; 504 + drop( 505 + stream::iter(cids.chunks(500)) 506 + .then(|chunk| async { self.blob.blobstore.delete_many(chunk.to_vec()).await }) 507 + .collect::<Vec<_>>() 508 + .await 509 + .into_iter() 510 + .collect::<Result<Vec<_>, _>>()?, 511 + ); 463 512 Ok(()) 464 513 } 465 514 ··· 472 521 return Ok(vec![]); 473 522 } 474 523 let did: String = self.did.clone(); 475 - let storage_guard = self.storage.read().await; 476 - use rsky_pds::schema::pds::record::dsl as RecordSchema; 524 + use crate::schema::actor_store::record::dsl as RecordSchema; 477 525 478 526 let cid_strs: Vec<String> = cids.into_iter().map(|c| c.to_string()).collect(); 479 527 let touched_uri_strs: Vec<String> = touched_uris.iter().map(|t| t.to_string()).collect(); 480 - let res: Vec<String> = storage_guard 528 + let res: Vec<String> = self 529 + .storage 530 + .read() 531 + .await 481 532 .db 482 533 .interact(move |conn| { 483 534 RecordSchema::record ··· 490 541 .await 491 542 .expect("Failed to get duplicate record cids")?; 492 543 res.into_iter() 493 - .map(|row| Cid::from_str(&row).map_err(|error| anyhow::Error::new(error))) 544 + .map(|row| Cid::from_str(&row).map_err(anyhow::Error::new)) 494 545 .collect::<Result<Vec<Cid>>>() 495 546 } 496 547 }
+13 -15
src/actor_store/preference.rs
··· 4 4 //! 5 5 //! Modified for SQLite backend 6 6 7 + use crate::models::actor_store::AccountPref; 7 8 use anyhow::{Result, bail}; 8 9 use diesel::*; 9 10 use rsky_lexicon::app::bsky::actor::RefPreferences; 10 11 use rsky_pds::actor_store::preference::pref_match_namespace; 11 12 use rsky_pds::actor_store::preference::util::pref_in_scope; 12 13 use rsky_pds::auth_verifier::AuthScope; 13 - use rsky_pds::models::AccountPref; 14 14 15 15 pub struct PreferenceReader { 16 16 pub did: String, ··· 21 21 } 22 22 23 23 impl PreferenceReader { 24 - pub fn new( 24 + pub const fn new( 25 25 did: String, 26 26 db: deadpool_diesel::Pool< 27 27 deadpool_diesel::Manager<SqliteConnection>, 28 28 deadpool_diesel::sqlite::Object, 29 29 >, 30 30 ) -> Self { 31 - PreferenceReader { did, db } 31 + Self { did, db } 32 32 } 33 33 34 34 pub async fn get_preferences( ··· 36 36 namespace: Option<String>, 37 37 scope: AuthScope, 38 38 ) -> Result<Vec<RefPreferences>> { 39 - use rsky_pds::schema::pds::account_pref::dsl as AccountPrefSchema; 39 + use crate::schema::actor_store::account_pref::dsl as AccountPrefSchema; 40 40 41 41 let did = self.did.clone(); 42 42 self.db ··· 50 50 .load(conn)?; 51 51 let account_prefs = prefs_res 52 52 .into_iter() 53 - .filter(|pref| match &namespace { 54 - None => true, 55 - Some(namespace) => pref_match_namespace(namespace, &pref.name), 53 + .filter(|pref| { 54 + namespace 55 + .as_ref() 56 + .is_none_or(|namespace| pref_match_namespace(namespace, &pref.name)) 56 57 }) 57 58 .filter(|pref| pref_in_scope(scope.clone(), pref.name.clone())) 58 59 .map(|pref| { ··· 88 89 { 89 90 false => bail!("Some preferences are not in the {namespace} namespace"), 90 91 true => { 91 - let not_in_scope = values 92 - .iter() 93 - .filter(|value| !pref_in_scope(scope.clone(), value.get_type())) 94 - .collect::<Vec<&RefPreferences>>(); 95 - if !not_in_scope.is_empty() { 92 + if values 93 + .iter().any(|value| !pref_in_scope(scope.clone(), value.get_type())) { 96 94 tracing::info!( 97 95 "@LOG: PreferenceReader::put_preferences() debug scope: {:?}, values: {:?}", 98 96 scope, ··· 101 99 bail!("Do not have authorization to set preferences."); 102 100 } 103 101 // get all current prefs for user and prep new pref rows 104 - use rsky_pds::schema::pds::account_pref::dsl as AccountPrefSchema; 102 + use crate::schema::actor_store::account_pref::dsl as AccountPrefSchema; 105 103 let all_prefs = AccountPrefSchema::account_pref 106 104 .filter(AccountPrefSchema::did.eq(&did)) 107 105 .select(AccountPref::as_select()) ··· 125 123 .collect::<Vec<i32>>(); 126 124 // replace all prefs in given namespace 127 125 if !all_pref_ids_in_namespace.is_empty() { 128 - delete(AccountPrefSchema::account_pref) 126 + _ = delete(AccountPrefSchema::account_pref) 129 127 .filter(AccountPrefSchema::id.eq_any(all_pref_ids_in_namespace)) 130 128 .execute(conn)?; 131 129 } 132 130 if !put_prefs.is_empty() { 133 - insert_into(AccountPrefSchema::account_pref) 131 + _ = insert_into(AccountPrefSchema::account_pref) 134 132 .values( 135 133 put_prefs 136 134 .into_iter()
+103 -65
src/actor_store/record.rs
··· 4 4 //! 5 5 //! Modified for SQLite backend 6 6 7 + use crate::models::actor_store::{Backlink, Record, RepoBlock}; 7 8 use anyhow::{Result, bail}; 8 9 use cidv10::Cid; 9 10 use diesel::result::Error; 10 11 use diesel::*; 11 12 use futures::stream::{self, StreamExt}; 12 13 use rsky_lexicon::com::atproto::admin::StatusAttr; 13 - use rsky_pds::actor_store::record::{GetRecord, RecordsForCollection, get_backlinks}; 14 - use rsky_pds::models::{Backlink, Record, RepoBlock}; 15 - use rsky_repo::types::{RepoRecord, WriteOpAction}; 14 + use rsky_pds::actor_store::record::{GetRecord, RecordsForCollection}; 15 + use rsky_repo::storage::Ipld; 16 + use rsky_repo::types::{Ids, Lex, RepoRecord, WriteOpAction}; 16 17 use rsky_repo::util::cbor_to_lex_record; 17 18 use rsky_syntax::aturi::AtUri; 19 + use rsky_syntax::aturi_validation::ensure_valid_at_uri; 20 + use rsky_syntax::did::ensure_valid_did; 21 + use serde_json::Value as JsonValue; 18 22 use std::env; 19 23 use std::str::FromStr; 20 24 25 + // @NOTE in the future this can be replaced with a more generic routine that pulls backlinks based on lex docs. 26 + // For now, we just want to ensure we're tracking links from follows, blocks, likes, and reposts. 27 + pub fn get_backlinks(uri: &AtUri, record: &RepoRecord) -> Result<Vec<Backlink>> { 28 + if let Some(Lex::Ipld(Ipld::Json(JsonValue::String(record_type)))) = record.get("$type") { 29 + if record_type == Ids::AppBskyGraphFollow.as_str() 30 + || record_type == Ids::AppBskyGraphBlock.as_str() 31 + { 32 + if let Some(Lex::Ipld(Ipld::Json(JsonValue::String(subject)))) = record.get("subject") { 33 + match ensure_valid_did(uri) { 34 + Ok(_) => { 35 + return Ok(vec![Backlink { 36 + uri: uri.to_string(), 37 + path: "subject".to_owned(), 38 + link_to: subject.clone(), 39 + }]); 40 + } 41 + Err(e) => bail!("get_backlinks Error: invalid did {}", e), 42 + }; 43 + } 44 + } else if record_type == Ids::AppBskyFeedLike.as_str() 45 + || record_type == Ids::AppBskyFeedRepost.as_str() 46 + { 47 + if let Some(Lex::Map(ref_object)) = record.get("subject") { 48 + if let Some(Lex::Ipld(Ipld::Json(JsonValue::String(subject_uri)))) = 49 + ref_object.get("uri") 50 + { 51 + match ensure_valid_at_uri(uri) { 52 + Ok(_) => { 53 + return Ok(vec![Backlink { 54 + uri: uri.to_string(), 55 + path: "subject.uri".to_owned(), 56 + link_to: subject_uri.clone(), 57 + }]); 58 + } 59 + Err(e) => bail!("get_backlinks Error: invalid AtUri {}", e), 60 + }; 61 + } 62 + } 63 + } 64 + } 65 + Ok(Vec::new()) 66 + } 67 + 21 68 /// Combined handler for record operations with both read and write capabilities. 22 69 pub(crate) struct RecordReader { 23 70 /// Database connection. ··· 31 78 32 79 impl RecordReader { 33 80 /// Create a new record handler. 34 - pub(crate) fn new( 81 + pub(crate) const fn new( 35 82 did: String, 36 83 db: deadpool_diesel::Pool< 37 84 deadpool_diesel::Manager<SqliteConnection>, ··· 43 90 44 91 /// Count the total number of records. 45 92 pub(crate) async fn record_count(&mut self) -> Result<i64> { 46 - use rsky_pds::schema::pds::record::dsl::*; 93 + use crate::schema::actor_store::record::dsl::*; 47 94 48 95 let other_did = self.did.clone(); 49 96 self.db ··· 59 106 60 107 /// List all collections in the repository. 61 108 pub(crate) async fn list_collections(&self) -> Result<Vec<String>> { 62 - use rsky_pds::schema::pds::record::dsl::*; 109 + use crate::schema::actor_store::record::dsl::*; 63 110 64 111 let other_did = self.did.clone(); 65 112 self.db ··· 90 137 rkey_end: Option<String>, 91 138 include_soft_deleted: Option<bool>, 92 139 ) -> Result<Vec<RecordsForCollection>> { 93 - use rsky_pds::schema::pds::record::dsl as RecordSchema; 94 - use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 140 + use crate::schema::actor_store::record::dsl as RecordSchema; 141 + use crate::schema::actor_store::repo_block::dsl as RepoBlockSchema; 95 142 96 - let include_soft_deleted: bool = if let Some(include_soft_deleted) = include_soft_deleted { 97 - include_soft_deleted 98 - } else { 99 - false 100 - }; 143 + let include_soft_deleted: bool = include_soft_deleted.unwrap_or(false); 101 144 let mut builder = RecordSchema::record 102 145 .inner_join(RepoBlockSchema::repo_block.on(RepoBlockSchema::cid.eq(RecordSchema::cid))) 103 146 .limit(limit) ··· 153 196 cid: Option<String>, 154 197 include_soft_deleted: Option<bool>, 155 198 ) -> Result<Option<GetRecord>> { 156 - use rsky_pds::schema::pds::record::dsl as RecordSchema; 157 - use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 199 + use crate::schema::actor_store::record::dsl as RecordSchema; 200 + use crate::schema::actor_store::repo_block::dsl as RepoBlockSchema; 158 201 159 - let include_soft_deleted: bool = if let Some(include_soft_deleted) = include_soft_deleted { 160 - include_soft_deleted 161 - } else { 162 - false 163 - }; 202 + let include_soft_deleted: bool = include_soft_deleted.unwrap_or(false); 164 203 let mut builder = RecordSchema::record 165 204 .inner_join(RepoBlockSchema::repo_block.on(RepoBlockSchema::cid.eq(RecordSchema::cid))) 166 205 .select((Record::as_select(), RepoBlock::as_select())) ··· 199 238 cid: Option<String>, 200 239 include_soft_deleted: Option<bool>, 201 240 ) -> Result<bool> { 202 - use rsky_pds::schema::pds::record::dsl as RecordSchema; 241 + use crate::schema::actor_store::record::dsl as RecordSchema; 203 242 204 - let include_soft_deleted: bool = if let Some(include_soft_deleted) = include_soft_deleted { 205 - include_soft_deleted 206 - } else { 207 - false 208 - }; 243 + let include_soft_deleted: bool = include_soft_deleted.unwrap_or(false); 209 244 let mut builder = RecordSchema::record 210 245 .select(RecordSchema::uri) 211 246 .filter(RecordSchema::uri.eq(uri)) ··· 223 258 .interact(move |conn| builder.first::<String>(conn).optional()) 224 259 .await 225 260 .expect("Failed to check record")?; 226 - Ok(!!record_uri.is_some()) 261 + Ok(record_uri.is_some()) 227 262 } 228 263 229 264 /// Get the takedown status of a record. ··· 231 266 &self, 232 267 uri: String, 233 268 ) -> Result<Option<StatusAttr>> { 234 - use rsky_pds::schema::pds::record::dsl as RecordSchema; 269 + use crate::schema::actor_store::record::dsl as RecordSchema; 235 270 236 271 let res = self 237 272 .db ··· 246 281 }) 247 282 .await 248 283 .expect("Failed to get takedown status")?; 249 - if let Some(res) = res { 250 - if let Some(takedown_ref) = res { 251 - Ok(Some(StatusAttr { 252 - applied: true, 253 - r#ref: Some(takedown_ref), 254 - })) 255 - } else { 256 - Ok(Some(StatusAttr { 257 - applied: false, 258 - r#ref: None, 259 - })) 260 - } 261 - } else { 262 - Ok(None) 263 - } 284 + res.map_or_else( 285 + || Ok(None), 286 + |res| { 287 + res.map_or_else( 288 + || { 289 + Ok(Some(StatusAttr { 290 + applied: false, 291 + r#ref: None, 292 + })) 293 + }, 294 + |takedown_ref| { 295 + Ok(Some(StatusAttr { 296 + applied: true, 297 + r#ref: Some(takedown_ref), 298 + })) 299 + }, 300 + ) 301 + }, 302 + ) 264 303 } 265 304 266 305 /// Get the current CID for a record URI. 267 306 pub(crate) async fn get_current_record_cid(&self, uri: String) -> Result<Option<Cid>> { 268 - use rsky_pds::schema::pds::record::dsl as RecordSchema; 307 + use crate::schema::actor_store::record::dsl as RecordSchema; 269 308 270 309 let res = self 271 310 .db ··· 294 333 path: String, 295 334 link_to: String, 296 335 ) -> Result<Vec<Record>> { 297 - use rsky_pds::schema::pds::backlink::dsl as BacklinkSchema; 298 - use rsky_pds::schema::pds::record::dsl as RecordSchema; 336 + use crate::schema::actor_store::backlink::dsl as BacklinkSchema; 337 + use crate::schema::actor_store::record::dsl as RecordSchema; 299 338 300 339 let res = self 301 340 .db ··· 373 412 let rkey = uri.get_rkey(); 374 413 let hostname = uri.get_hostname().to_string(); 375 414 let action = action.unwrap_or(WriteOpAction::Create); 376 - let indexed_at = timestamp.unwrap_or_else(|| rsky_common::now()); 415 + let indexed_at = timestamp.unwrap_or_else(rsky_common::now); 377 416 let row = Record { 378 417 did: self.did.clone(), 379 418 uri: uri.to_string(), ··· 393 432 bail!("Expected indexed URI to contain a record key") 394 433 } 395 434 396 - use rsky_pds::schema::pds::record::dsl as RecordSchema; 435 + use crate::schema::actor_store::record::dsl as RecordSchema; 397 436 398 437 // Track current version of record 399 438 let (record, uri) = self ··· 401 440 .get() 402 441 .await? 403 442 .interact(move |conn| { 404 - insert_into(RecordSchema::record) 443 + _ = insert_into(RecordSchema::record) 405 444 .values(row) 406 445 .on_conflict(RecordSchema::uri) 407 446 .do_update() ··· 419 458 if let Some(record) = record { 420 459 // Maintain backlinks 421 460 let backlinks = get_backlinks(&uri, &record)?; 422 - if let WriteOpAction::Update = action { 461 + if action == WriteOpAction::Update { 423 462 // On update just recreate backlinks from scratch for the record, so we can clear out 424 463 // the old ones. E.g. for weird cases like updating a follow to be for a different did. 425 464 self.remove_backlinks_by_uri(&uri).await?; ··· 434 473 #[tracing::instrument(skip_all)] 435 474 pub(crate) async fn delete_record(&self, uri: &AtUri) -> Result<()> { 436 475 tracing::debug!("@LOG DEBUG RecordReader::delete_record, deleting indexed record {uri}"); 437 - use rsky_pds::schema::pds::backlink::dsl as BacklinkSchema; 438 - use rsky_pds::schema::pds::record::dsl as RecordSchema; 476 + use crate::schema::actor_store::backlink::dsl as BacklinkSchema; 477 + use crate::schema::actor_store::record::dsl as RecordSchema; 439 478 let uri = uri.to_string(); 440 479 self.db 441 480 .get() 442 481 .await? 443 482 .interact(move |conn| { 444 - delete(RecordSchema::record) 483 + _ = delete(RecordSchema::record) 445 484 .filter(RecordSchema::uri.eq(&uri)) 446 485 .execute(conn)?; 447 - delete(BacklinkSchema::backlink) 486 + _ = delete(BacklinkSchema::backlink) 448 487 .filter(BacklinkSchema::uri.eq(&uri)) 449 488 .execute(conn)?; 450 489 tracing::debug!( ··· 458 497 459 498 /// Remove backlinks for a URI. 460 499 pub(crate) async fn remove_backlinks_by_uri(&self, uri: &AtUri) -> Result<()> { 461 - use rsky_pds::schema::pds::backlink::dsl as BacklinkSchema; 500 + use crate::schema::actor_store::backlink::dsl as BacklinkSchema; 462 501 let uri = uri.to_string(); 463 502 self.db 464 503 .get() 465 504 .await? 466 505 .interact(move |conn| { 467 - delete(BacklinkSchema::backlink) 506 + _ = delete(BacklinkSchema::backlink) 468 507 .filter(BacklinkSchema::uri.eq(uri)) 469 508 .execute(conn)?; 470 509 Ok(()) ··· 475 514 476 515 /// Add backlinks to the database. 477 516 pub(crate) async fn add_backlinks(&self, backlinks: Vec<Backlink>) -> Result<()> { 478 - if backlinks.len() == 0 { 517 + if backlinks.is_empty() { 479 518 Ok(()) 480 519 } else { 481 - use rsky_pds::schema::pds::backlink::dsl as BacklinkSchema; 520 + use crate::schema::actor_store::backlink::dsl as BacklinkSchema; 482 521 self.db 483 522 .get() 484 523 .await? 485 524 .interact(move |conn| { 486 - insert_or_ignore_into(BacklinkSchema::backlink) 525 + _ = insert_or_ignore_into(BacklinkSchema::backlink) 487 526 .values(&backlinks) 488 527 .execute(conn)?; 489 528 Ok(()) ··· 499 538 uri: &AtUri, 500 539 takedown: StatusAttr, 501 540 ) -> Result<()> { 502 - use rsky_pds::schema::pds::record::dsl as RecordSchema; 541 + use crate::schema::actor_store::record::dsl as RecordSchema; 503 542 504 543 let takedown_ref: Option<String> = match takedown.applied { 505 - true => match takedown.r#ref { 506 - Some(takedown_ref) => Some(takedown_ref), 507 - None => Some(rsky_common::now()), 508 - }, 544 + true => takedown 545 + .r#ref 546 + .map_or_else(|| Some(rsky_common::now()), Some), 509 547 false => None, 510 548 }; 511 549 let uri_string = uri.to_string(); ··· 514 552 .get() 515 553 .await? 516 554 .interact(move |conn| { 517 - update(RecordSchema::record) 555 + _ = update(RecordSchema::record) 518 556 .filter(RecordSchema::uri.eq(uri_string)) 519 557 .set(RecordSchema::takedownRef.eq(takedown_ref)) 520 558 .execute(conn)?;
+30 -23
src/actor_store/sql_blob.rs
··· 2 2 #![expect( 3 3 clippy::pub_use, 4 4 clippy::single_char_lifetime_names, 5 - unused_qualifications 5 + unused_qualifications, 6 + unnameable_types 6 7 )] 7 8 use anyhow::{Context, Result}; 8 9 use cidv10::Cid; ··· 14 15 } 15 16 16 17 impl ByteStream { 17 - pub fn new(bytes: Vec<u8>) -> Self { 18 + pub const fn new(bytes: Vec<u8>) -> Self { 18 19 Self { bytes } 19 20 } 20 21 ··· 60 61 61 62 impl BlobStoreSql { 62 63 /// Create a new SQL-based blob store for the given DID 63 - pub fn new( 64 + pub const fn new( 64 65 did: String, 65 66 db: deadpool_diesel::Pool< 66 67 deadpool_diesel::Manager<SqliteConnection>, 67 68 deadpool_diesel::sqlite::Object, 68 69 >, 69 70 ) -> Self { 70 - BlobStoreSql { db, did } 71 + Self { db, did } 71 72 } 72 73 73 74 // /// Create a factory function for blob stores 74 - // pub fn creator( 75 - // db: deadpool_diesel::Pool< 76 - // deadpool_diesel::Manager<SqliteConnection>, 77 - // deadpool_diesel::sqlite::Object, 78 - // >, 79 - // ) -> Box<dyn Fn(String) -> BlobStoreSql> { 80 - // let db_clone = db.clone(); 81 - // Box::new(move |did: String| BlobStoreSql::new(did, db_clone.clone())) 82 - // } 75 + pub fn creator( 76 + db: deadpool_diesel::Pool< 77 + deadpool_diesel::Manager<SqliteConnection>, 78 + deadpool_diesel::sqlite::Object, 79 + >, 80 + ) -> Box<dyn Fn(String) -> BlobStoreSql> { 81 + let db_clone = db.clone(); 82 + Box::new(move |did: String| BlobStoreSql::new(did, db_clone.clone())) 83 + } 83 84 84 85 /// Store a blob temporarily - now just stores permanently with a key returned for API compatibility 85 86 pub async fn put_temp(&self, bytes: Vec<u8>) -> Result<String> { 86 87 // Generate a unique key as a CID based on the data 87 - use sha2::{Digest, Sha256}; 88 - let digest = Sha256::digest(&bytes); 89 - let key = hex::encode(digest); 88 + // use sha2::{Digest, Sha256}; 89 + // let digest = Sha256::digest(&bytes); 90 + // let key = hex::encode(digest); 91 + let key = rsky_common::get_random_str(); 90 92 91 93 // Just store the blob directly 92 94 self.put_permanent_with_mime( 93 95 Cid::try_from(format!("bafy{}", key)).unwrap_or_else(|_| Cid::default()), 94 96 bytes, 95 - "application/octet-stream".to_string(), 97 + "application/octet-stream".to_owned(), 96 98 ) 97 99 .await?; 98 100 ··· 118 120 let bytes_len = bytes.len() as i32; 119 121 120 122 // Store directly in the database 121 - self.db 123 + _ = self 124 + .db 122 125 .get() 123 126 .await? 124 127 .interact(move |conn| { ··· 148 151 149 152 /// Store a blob directly as permanent 150 153 pub async fn put_permanent(&self, cid: Cid, bytes: Vec<u8>) -> Result<()> { 151 - self.put_permanent_with_mime(cid, bytes, "application/octet-stream".to_string()) 154 + self.put_permanent_with_mime(cid, bytes, "application/octet-stream".to_owned()) 152 155 .await 153 156 } 154 157 ··· 158 161 let did_clone = self.did.clone(); 159 162 160 163 // Update the quarantine flag in the database 161 - self.db 164 + _ = self 165 + .db 162 166 .get() 163 167 .await? 164 168 .interact(move |conn| { ··· 181 185 let did_clone = self.did.clone(); 182 186 183 187 // Update the quarantine flag in the database 184 - self.db 188 + _ = self 189 + .db 185 190 .get() 186 191 .await? 187 192 .interact(move |conn| { ··· 248 253 let did_clone = self.did.clone(); 249 254 250 255 // Delete from database 251 - self.db 256 + _ = self 257 + .db 252 258 .get() 253 259 .await? 254 260 .interact(move |conn| { ··· 272 278 let did_clone = self.did.clone(); 273 279 274 280 // Delete all blobs in one operation 275 - self.db 281 + _ = self 282 + .db 276 283 .get() 277 284 .await? 278 285 .interact(move |conn| {
+28 -22
src/actor_store/sql_repo.rs
··· 3 3 //! 4 4 //! Modified for SQLite backend 5 5 6 + use crate::models::actor_store as models; 7 + use crate::models::actor_store::RepoBlock; 6 8 use anyhow::Result; 7 9 use cidv10::Cid; 8 10 use diesel::dsl::sql; ··· 10 12 use diesel::sql_types::{Bool, Text}; 11 13 use diesel::*; 12 14 use futures::{StreamExt, TryStreamExt, stream}; 13 - use rsky_pds::models; 14 - use rsky_pds::models::RepoBlock; 15 15 use rsky_repo::block_map::{BlockMap, BlocksAndMissing}; 16 16 use rsky_repo::car::blocks_to_car_file; 17 17 use rsky_repo::cid_set::CidSet; ··· 50 50 cid: &'life Cid, 51 51 ) -> Pin<Box<dyn Future<Output = Result<Option<Vec<u8>>>> + Send + Sync + 'life>> { 52 52 let did: String = self.did.clone(); 53 - let cid = cid.clone(); 53 + let cid = *cid; 54 54 55 55 Box::pin(async move { 56 - use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 56 + use crate::schema::actor_store::repo_block::dsl as RepoBlockSchema; 57 57 let cached = { 58 58 let cache_guard = self.cache.read().await; 59 - cache_guard.get(cid).map(|v| v.clone()) 59 + cache_guard.get(cid).cloned() 60 60 }; 61 61 if let Some(cached_result) = cached { 62 - return Ok(Some(cached_result.clone())); 62 + return Ok(Some(cached_result)); 63 63 } 64 64 65 65 let found: Option<Vec<u8>> = self ··· 104 104 let did: String = self.did.clone(); 105 105 106 106 Box::pin(async move { 107 - use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 107 + use crate::schema::actor_store::repo_block::dsl as RepoBlockSchema; 108 108 let cached = { 109 109 let mut cache_guard = self.cache.write().await; 110 110 cache_guard.get_many(cids)? ··· 120 120 let blocks = Arc::new(tokio::sync::Mutex::new(BlockMap::new())); 121 121 let missing_set = Arc::new(tokio::sync::Mutex::new(missing)); 122 122 123 - let _: Vec<_> = stream::iter(missing_strings.chunks(500)) 123 + let stream: Vec<_> = stream::iter(missing_strings.chunks(500)) 124 124 .then(|batch| { 125 125 let this_did = did.clone(); 126 126 let blocks = Arc::clone(&blocks); ··· 156 156 }) 157 157 .try_collect() 158 158 .await?; 159 + drop(stream); 159 160 160 161 // Extract values from synchronization primitives 161 162 let mut blocks = Arc::try_unwrap(blocks) ··· 201 202 let did: String = self.did.clone(); 202 203 let bytes_cloned = bytes.clone(); 203 204 Box::pin(async move { 204 - use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 205 + use crate::schema::actor_store::repo_block::dsl as RepoBlockSchema; 205 206 206 - self.db 207 + _ = self 208 + .db 207 209 .interact(move |conn| { 208 210 insert_into(RepoBlockSchema::repo_block) 209 211 .values(( ··· 233 235 let did: String = self.did.clone(); 234 236 235 237 Box::pin(async move { 236 - use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 238 + use crate::schema::actor_store::repo_block::dsl as RepoBlockSchema; 237 239 238 240 let blocks: Vec<RepoBlock> = to_put 239 241 .map ··· 251 253 blocks.chunks(50).map(|chunk| chunk.to_vec()).collect(); 252 254 253 255 for batch in chunks { 254 - self.db 256 + _ = self 257 + .db 255 258 .interact(move |conn| { 256 259 insert_or_ignore_into(RepoBlockSchema::repo_block) 257 260 .values(&batch) ··· 274 277 let now: String = self.now.clone(); 275 278 276 279 Box::pin(async move { 277 - use rsky_pds::schema::pds::repo_root::dsl as RepoRootSchema; 280 + use crate::schema::actor_store::repo_root::dsl as RepoRootSchema; 278 281 279 282 let is_create = is_create.unwrap_or(false); 280 283 if is_create { 281 - self.db 284 + _ = self 285 + .db 282 286 .interact(move |conn| { 283 287 insert_into(RepoRootSchema::repo_root) 284 288 .values(( ··· 292 296 .await 293 297 .expect("Failed to create root")?; 294 298 } else { 295 - self.db 299 + _ = self 300 + .db 296 301 .interact(move |conn| { 297 302 update(RepoRootSchema::repo_root) 298 303 .filter(RepoRootSchema::did.eq(did)) ··· 329 334 impl SqlRepoReader { 330 335 pub fn new(did: String, now: Option<String>, db: deadpool_diesel::sqlite::Object) -> Self { 331 336 let now = now.unwrap_or_else(rsky_common::now); 332 - SqlRepoReader { 337 + Self { 333 338 cache: Arc::new(RwLock::new(BlockMap::new())), 334 339 root: None, 335 340 rev: None, ··· 376 381 let did: String = self.did.clone(); 377 382 let since = since.clone(); 378 383 let cursor = cursor.clone(); 379 - use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 384 + use crate::schema::actor_store::repo_block::dsl as RepoBlockSchema; 380 385 381 386 Ok(self 382 387 .db ··· 413 418 414 419 pub async fn count_blocks(&self) -> Result<i64> { 415 420 let did: String = self.did.clone(); 416 - use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 421 + use crate::schema::actor_store::repo_block::dsl as RepoBlockSchema; 417 422 418 423 let res = self 419 424 .db ··· 434 439 /// Proactively cache all blocks from a particular commit (to prevent multiple roundtrips) 435 440 pub async fn cache_rev(&mut self, rev: String) -> Result<()> { 436 441 let did: String = self.did.clone(); 437 - use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 442 + use crate::schema::actor_store::repo_block::dsl as RepoBlockSchema; 438 443 439 444 let result: Vec<(String, Vec<u8>)> = self 440 445 .db ··· 460 465 return Ok(()); 461 466 } 462 467 let did: String = self.did.clone(); 463 - use rsky_pds::schema::pds::repo_block::dsl as RepoBlockSchema; 468 + use crate::schema::actor_store::repo_block::dsl as RepoBlockSchema; 464 469 465 470 let cid_strings: Vec<String> = cids.into_iter().map(|c| c.to_string()).collect(); 466 - self.db 471 + _ = self 472 + .db 467 473 .interact(move |conn| { 468 474 delete(RepoBlockSchema::repo_block) 469 475 .filter(RepoBlockSchema::did.eq(did)) ··· 477 483 478 484 pub async fn get_root_detailed(&self) -> Result<CidAndRev> { 479 485 let did: String = self.did.clone(); 480 - use rsky_pds::schema::pds::repo_root::dsl as RepoRootSchema; 486 + use crate::schema::actor_store::repo_root::dsl as RepoRootSchema; 481 487 482 488 let res = self 483 489 .db
+245
src/apis/com/atproto/identity/identity.rs
··· 1 + //! Identity endpoints (/xrpc/com.atproto.identity.*) 2 + use std::collections::HashMap; 3 + 4 + use anyhow::{Context as _, anyhow}; 5 + use atrium_api::{ 6 + com::atproto::identity, 7 + types::string::{Datetime, Handle}, 8 + }; 9 + use atrium_crypto::keypair::Did as _; 10 + use atrium_repo::blockstore::{AsyncBlockStoreWrite as _, CarStore, DAG_CBOR, SHA2_256}; 11 + use axum::{ 12 + Json, Router, 13 + extract::{Query, State}, 14 + http::StatusCode, 15 + routing::{get, post}, 16 + }; 17 + use constcat::concat; 18 + 19 + use crate::{ 20 + AppState, Client, Db, Error, Result, RotationKey, SigningKey, 21 + auth::AuthenticatedUser, 22 + config::AppConfig, 23 + did, 24 + firehose::FirehoseProducer, 25 + plc::{self, PlcOperation, PlcService}, 26 + }; 27 + 28 + /// (GET) Resolves an atproto handle (hostname) to a DID. Does not necessarily bi-directionally verify against the the DID document. 29 + /// ### Query Parameters 30 + /// - handle: The handle to resolve. 31 + /// ### Responses 32 + /// - 200 OK: {did: did} 33 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `HandleNotFound`]} 34 + /// - 401 Unauthorized 35 + async fn resolve_handle( 36 + State(db): State<Db>, 37 + State(client): State<Client>, 38 + Query(input): Query<identity::resolve_handle::ParametersData>, 39 + ) -> Result<Json<identity::resolve_handle::Output>> { 40 + let handle = input.handle.as_str(); 41 + if let Ok(did) = sqlx::query_scalar!(r#"SELECT did FROM handles WHERE handle = ?"#, handle) 42 + .fetch_one(&db) 43 + .await 44 + { 45 + return Ok(Json( 46 + identity::resolve_handle::OutputData { 47 + did: atrium_api::types::string::Did::new(did).expect("should be valid DID format"), 48 + } 49 + .into(), 50 + )); 51 + } 52 + 53 + // HACK: Query bsky to see if they have this handle cached. 54 + let response = client 55 + .get(format!( 56 + "https://api.bsky.app/xrpc/com.atproto.identity.resolveHandle?handle={handle}" 57 + )) 58 + .send() 59 + .await 60 + .context("failed to query upstream server")? 61 + .json() 62 + .await 63 + .context("failed to decode response as JSON")?; 64 + 65 + Ok(Json(response)) 66 + } 67 + 68 + #[expect(unused_variables, clippy::todo, reason = "Not yet implemented")] 69 + /// Request an email with a code to in order to request a signed PLC operation. Requires Auth. 70 + /// - POST /xrpc/com.atproto.identity.requestPlcOperationSignature 71 + /// ### Responses 72 + /// - 200 OK 73 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 74 + /// - 401 Unauthorized 75 + async fn request_plc_operation_signature(user: AuthenticatedUser) -> Result<()> { 76 + todo!() 77 + } 78 + 79 + #[expect(unused_variables, clippy::todo, reason = "Not yet implemented")] 80 + /// Signs a PLC operation to update some value(s) in the requesting DID's document. 81 + /// - POST /xrpc/com.atproto.identity.signPlcOperation 82 + /// ### Request Body 83 + /// - token: string // A token received through com.atproto.identity.requestPlcOperationSignature 84 + /// - rotationKeys: string[] 85 + /// - alsoKnownAs: string[] 86 + /// - verificationMethods: services 87 + /// ### Responses 88 + /// - 200 OK: {operation: string} 89 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 90 + /// - 401 Unauthorized 91 + async fn sign_plc_operation( 92 + user: AuthenticatedUser, 93 + State(skey): State<SigningKey>, 94 + State(rkey): State<RotationKey>, 95 + State(config): State<AppConfig>, 96 + Json(input): Json<identity::sign_plc_operation::Input>, 97 + ) -> Result<Json<identity::sign_plc_operation::Output>> { 98 + todo!() 99 + } 100 + 101 + #[expect( 102 + clippy::too_many_arguments, 103 + reason = "Many parameters are required for this endpoint" 104 + )] 105 + /// Updates the current account's handle. Verifies handle validity, and updates did:plc document if necessary. Implemented by PDS, and requires auth. 106 + /// - POST /xrpc/com.atproto.identity.updateHandle 107 + /// ### Query Parameters 108 + /// - handle: handle // The new handle. 109 + /// ### Responses 110 + /// - 200 OK 111 + /// ## Errors 112 + /// - If the handle is already in use. 113 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 114 + /// - 401 Unauthorized 115 + /// ## Panics 116 + /// - If the handle is not valid. 117 + async fn update_handle( 118 + user: AuthenticatedUser, 119 + State(skey): State<SigningKey>, 120 + State(rkey): State<RotationKey>, 121 + State(client): State<Client>, 122 + State(config): State<AppConfig>, 123 + State(db): State<Db>, 124 + State(fhp): State<FirehoseProducer>, 125 + Json(input): Json<identity::update_handle::Input>, 126 + ) -> Result<()> { 127 + let handle = input.handle.as_str(); 128 + let did_str = user.did(); 129 + let did = atrium_api::types::string::Did::new(user.did()).expect("should be valid DID format"); 130 + 131 + if let Some(existing_did) = 132 + sqlx::query_scalar!(r#"SELECT did FROM handles WHERE handle = ?"#, handle) 133 + .fetch_optional(&db) 134 + .await 135 + .context("failed to query did count")? 136 + { 137 + if existing_did != did_str { 138 + return Err(Error::with_status( 139 + StatusCode::BAD_REQUEST, 140 + anyhow!("attempted to update handle to one that is already in use"), 141 + )); 142 + } 143 + } 144 + 145 + // Ensure the existing DID is resolvable. 146 + // If not, we need to register the original handle. 147 + let _did = did::resolve(&client, did.clone()) 148 + .await 149 + .with_context(|| format!("failed to resolve DID for {did_str}")) 150 + .context("should be able to resolve DID")?; 151 + 152 + let op = plc::sign_op( 153 + &rkey, 154 + PlcOperation { 155 + typ: "plc_operation".to_owned(), 156 + rotation_keys: vec![rkey.did()], 157 + verification_methods: HashMap::from([("atproto".to_owned(), skey.did())]), 158 + also_known_as: vec![input.handle.as_str().to_owned()], 159 + services: HashMap::from([( 160 + "atproto_pds".to_owned(), 161 + PlcService::Pds { 162 + endpoint: config.host_name.clone(), 163 + }, 164 + )]), 165 + prev: Some( 166 + sqlx::query_scalar!(r#"SELECT plc_root FROM accounts WHERE did = ?"#, did_str) 167 + .fetch_one(&db) 168 + .await 169 + .context("failed to fetch user PLC root")?, 170 + ), 171 + }, 172 + ) 173 + .context("failed to sign plc op")?; 174 + 175 + if !config.test { 176 + plc::submit(&client, did.as_str(), &op) 177 + .await 178 + .context("failed to submit PLC operation")?; 179 + } 180 + 181 + // FIXME: Properly abstract these implementation details. 182 + let did_hash = did_str 183 + .strip_prefix("did:plc:") 184 + .context("should be valid DID format")?; 185 + let doc = tokio::fs::File::options() 186 + .read(true) 187 + .write(true) 188 + .open(config.plc.path.join(format!("{did_hash}.car"))) 189 + .await 190 + .context("failed to open did doc")?; 191 + 192 + let op_bytes = serde_ipld_dagcbor::to_vec(&op).context("failed to encode plc op")?; 193 + 194 + let plc_cid = CarStore::open(doc) 195 + .await 196 + .context("failed to open did carstore")? 197 + .write_block(DAG_CBOR, SHA2_256, &op_bytes) 198 + .await 199 + .context("failed to write genesis commit")?; 200 + 201 + let cid_str = plc_cid.to_string(); 202 + 203 + _ = sqlx::query!( 204 + r#"UPDATE accounts SET plc_root = ? WHERE did = ?"#, 205 + cid_str, 206 + did_str 207 + ) 208 + .execute(&db) 209 + .await 210 + .context("failed to update account PLC root")?; 211 + 212 + // Broadcast the identity event now that the new identity is resolvable on the public directory. 213 + fhp.identity( 214 + atrium_api::com::atproto::sync::subscribe_repos::IdentityData { 215 + did: did.clone(), 216 + handle: Some(Handle::new(handle.to_owned()).expect("should be valid handle")), 217 + seq: 0, // Filled by firehose later. 218 + time: Datetime::now(), 219 + }, 220 + ) 221 + .await; 222 + 223 + Ok(()) 224 + } 225 + 226 + async fn todo() -> Result<()> { 227 + Err(Error::unimplemented(anyhow!("not implemented"))) 228 + } 229 + 230 + #[rustfmt::skip] 231 + /// Identity endpoints (/xrpc/com.atproto.identity.*) 232 + /// ### Routes 233 + /// - AP /xrpc/com.atproto.identity.updateHandle -> [`update_handle`] 234 + /// - AP /xrpc/com.atproto.identity.requestPlcOperationSignature -> [`request_plc_operation_signature`] 235 + /// - AP /xrpc/com.atproto.identity.signPlcOperation -> [`sign_plc_operation`] 236 + /// - UG /xrpc/com.atproto.identity.resolveHandle -> [`resolve_handle`] 237 + pub(super) fn routes() -> Router<AppState> { 238 + Router::new() 239 + .route(concat!("/", identity::get_recommended_did_credentials::NSID), get(todo)) 240 + .route(concat!("/", identity::request_plc_operation_signature::NSID), post(request_plc_operation_signature)) 241 + .route(concat!("/", identity::resolve_handle::NSID), get(resolve_handle)) 242 + .route(concat!("/", identity::sign_plc_operation::NSID), post(sign_plc_operation)) 243 + .route(concat!("/", identity::submit_plc_operation::NSID), post(todo)) 244 + .route(concat!("/", identity::update_handle::NSID), post(update_handle)) 245 + }
+5
src/apis/com/atproto/mod.rs
··· 1 + // pub mod admin; 2 + // pub mod identity; 3 + pub mod repo; 4 + // pub mod server; 5 + // pub mod sync;
+142
src/apis/com/atproto/repo/apply_writes.rs
··· 1 + //! Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS. 2 + 3 + use super::*; 4 + 5 + async fn inner_apply_writes( 6 + body: ApplyWritesInput, 7 + auth: AuthenticatedUser, 8 + sequencer: Arc<RwLock<Sequencer>>, 9 + actor_pools: HashMap<String, ActorStorage>, 10 + account_manager: Arc<RwLock<AccountManager>>, 11 + ) -> Result<()> { 12 + let tx: ApplyWritesInput = body; 13 + let ApplyWritesInput { 14 + repo, 15 + validate, 16 + swap_commit, 17 + .. 18 + } = tx; 19 + let account = account_manager 20 + .read() 21 + .await 22 + .get_account( 23 + &repo, 24 + Some(AvailabilityFlags { 25 + include_deactivated: Some(true), 26 + include_taken_down: None, 27 + }), 28 + ) 29 + .await?; 30 + 31 + if let Some(account) = account { 32 + if account.deactivated_at.is_some() { 33 + bail!("Account is deactivated") 34 + } 35 + let did = account.did; 36 + if did != auth.did() { 37 + bail!("AuthRequiredError") 38 + } 39 + let did: &String = &did; 40 + if tx.writes.len() > 200 { 41 + bail!("Too many writes. Max: 200") 42 + } 43 + 44 + let writes: Vec<PreparedWrite> = stream::iter(tx.writes) 45 + .then(async |write| { 46 + Ok::<PreparedWrite, anyhow::Error>(match write { 47 + ApplyWritesInputRefWrite::Create(write) => PreparedWrite::Create( 48 + prepare_create(PrepareCreateOpts { 49 + did: did.clone(), 50 + collection: write.collection, 51 + rkey: write.rkey, 52 + swap_cid: None, 53 + record: serde_json::from_value(write.value)?, 54 + validate, 55 + }) 56 + .await?, 57 + ), 58 + ApplyWritesInputRefWrite::Update(write) => PreparedWrite::Update( 59 + prepare_update(PrepareUpdateOpts { 60 + did: did.clone(), 61 + collection: write.collection, 62 + rkey: write.rkey, 63 + swap_cid: None, 64 + record: serde_json::from_value(write.value)?, 65 + validate, 66 + }) 67 + .await?, 68 + ), 69 + ApplyWritesInputRefWrite::Delete(write) => { 70 + PreparedWrite::Delete(prepare_delete(PrepareDeleteOpts { 71 + did: did.clone(), 72 + collection: write.collection, 73 + rkey: write.rkey, 74 + swap_cid: None, 75 + })?) 76 + } 77 + }) 78 + }) 79 + .collect::<Vec<_>>() 80 + .await 81 + .into_iter() 82 + .collect::<Result<Vec<PreparedWrite>, _>>()?; 83 + 84 + let swap_commit_cid = match swap_commit { 85 + Some(swap_commit) => Some(Cid::from_str(&swap_commit)?), 86 + None => None, 87 + }; 88 + 89 + let mut actor_store = ActorStore::from_actor_pools(did, &actor_pools).await; 90 + 91 + let commit = actor_store 92 + .process_writes(writes.clone(), swap_commit_cid) 93 + .await?; 94 + 95 + _ = sequencer 96 + .write() 97 + .await 98 + .sequence_commit(did.clone(), commit.clone()) 99 + .await?; 100 + account_manager 101 + .write() 102 + .await 103 + .update_repo_root( 104 + did.to_string(), 105 + commit.commit_data.cid, 106 + commit.commit_data.rev, 107 + &actor_pools, 108 + ) 109 + .await?; 110 + Ok(()) 111 + } else { 112 + bail!("Could not find repo: `{repo}`") 113 + } 114 + } 115 + 116 + /// Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS. 117 + /// - POST /xrpc/com.atproto.repo.applyWrites 118 + /// ### Request Body 119 + /// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account). 120 + /// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data across all operations, 'true' to require it, or leave unset to validate only for known Lexicons. 121 + /// - `writes`: `object[]` // One of: 122 + /// - - com.atproto.repo.applyWrites.create 123 + /// - - com.atproto.repo.applyWrites.update 124 + /// - - com.atproto.repo.applyWrites.delete 125 + /// - `swap_commit`: `cid` // If provided, the entire operation will fail if the current repo commit CID does not match this value. Used to prevent conflicting repo mutations. 126 + #[axum::debug_handler(state = AppState)] 127 + pub(crate) async fn apply_writes( 128 + auth: AuthenticatedUser, 129 + State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>, 130 + State(account_manager): State<Arc<RwLock<AccountManager>>>, 131 + State(sequencer): State<Arc<RwLock<Sequencer>>>, 132 + Json(body): Json<ApplyWritesInput>, 133 + ) -> Result<(), ApiError> { 134 + tracing::debug!("@LOG: debug apply_writes {body:#?}"); 135 + match inner_apply_writes(body, auth, sequencer, actor_pools, account_manager).await { 136 + Ok(()) => Ok(()), 137 + Err(error) => { 138 + tracing::error!("@LOG: ERROR: {error}"); 139 + Err(ApiError::RuntimeError) 140 + } 141 + } 142 + }
+140
src/apis/com/atproto/repo/create_record.rs
··· 1 + //! Create a single new repository record. Requires auth, implemented by PDS. 2 + 3 + use super::*; 4 + 5 + async fn inner_create_record( 6 + body: CreateRecordInput, 7 + user: AuthenticatedUser, 8 + sequencer: Arc<RwLock<Sequencer>>, 9 + actor_pools: HashMap<String, ActorStorage>, 10 + account_manager: Arc<RwLock<AccountManager>>, 11 + ) -> Result<CreateRecordOutput> { 12 + let CreateRecordInput { 13 + repo, 14 + collection, 15 + record, 16 + rkey, 17 + validate, 18 + swap_commit, 19 + } = body; 20 + let account = account_manager 21 + .read() 22 + .await 23 + .get_account( 24 + &repo, 25 + Some(AvailabilityFlags { 26 + include_deactivated: Some(true), 27 + include_taken_down: None, 28 + }), 29 + ) 30 + .await?; 31 + if let Some(account) = account { 32 + if account.deactivated_at.is_some() { 33 + bail!("Account is deactivated") 34 + } 35 + let did = account.did; 36 + // if did != auth.access.credentials.unwrap().did.unwrap() { 37 + if did != user.did() { 38 + bail!("AuthRequiredError") 39 + } 40 + let swap_commit_cid = match swap_commit { 41 + Some(swap_commit) => Some(Cid::from_str(&swap_commit)?), 42 + None => None, 43 + }; 44 + let write = prepare_create(PrepareCreateOpts { 45 + did: did.clone(), 46 + collection: collection.clone(), 47 + record: serde_json::from_value(record)?, 48 + rkey, 49 + validate, 50 + swap_cid: None, 51 + }) 52 + .await?; 53 + 54 + let did: &String = &did; 55 + let mut actor_store = ActorStore::from_actor_pools(did, &actor_pools).await; 56 + let backlink_conflicts: Vec<AtUri> = match validate { 57 + Some(true) => { 58 + let write_at_uri: AtUri = write.uri.clone().try_into()?; 59 + actor_store 60 + .record 61 + .get_backlink_conflicts(&write_at_uri, &write.record) 62 + .await? 63 + } 64 + _ => Vec::new(), 65 + }; 66 + 67 + let backlink_deletions: Vec<PreparedDelete> = backlink_conflicts 68 + .iter() 69 + .map(|at_uri| { 70 + prepare_delete(PrepareDeleteOpts { 71 + did: at_uri.get_hostname().to_string(), 72 + collection: at_uri.get_collection(), 73 + rkey: at_uri.get_rkey(), 74 + swap_cid: None, 75 + }) 76 + }) 77 + .collect::<Result<Vec<PreparedDelete>>>()?; 78 + let mut writes: Vec<PreparedWrite> = vec![PreparedWrite::Create(write.clone())]; 79 + for delete in backlink_deletions { 80 + writes.push(PreparedWrite::Delete(delete)); 81 + } 82 + let commit = actor_store 83 + .process_writes(writes.clone(), swap_commit_cid) 84 + .await?; 85 + 86 + _ = sequencer 87 + .write() 88 + .await 89 + .sequence_commit(did.clone(), commit.clone()) 90 + .await?; 91 + account_manager 92 + .write() 93 + .await 94 + .update_repo_root( 95 + did.to_string(), 96 + commit.commit_data.cid, 97 + commit.commit_data.rev, 98 + &actor_pools, 99 + ) 100 + .await?; 101 + 102 + Ok(CreateRecordOutput { 103 + uri: write.uri.clone(), 104 + cid: write.cid.to_string(), 105 + }) 106 + } else { 107 + bail!("Could not find repo: `{repo}`") 108 + } 109 + } 110 + 111 + /// Create a single new repository record. Requires auth, implemented by PDS. 112 + /// - POST /xrpc/com.atproto.repo.createRecord 113 + /// ### Request Body 114 + /// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account). 115 + /// - `collection`: `nsid` // The NSID of the record collection. 116 + /// - `rkey`: `string` // The record key. <= 512 characters. 117 + /// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data, 'true' to require it, or leave unset to validate only for known Lexicons. 118 + /// - `record` 119 + /// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID. 120 + /// ### Responses 121 + /// - 200 OK: {`cid`: `cid`, `uri`: `at-uri`, `commit`: {`cid`: `cid`, `rev`: `tid`}, `validation_status`: [`valid`, `unknown`]} 122 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidSwap`]} 123 + /// - 401 Unauthorized 124 + #[axum::debug_handler(state = AppState)] 125 + pub async fn create_record( 126 + user: AuthenticatedUser, 127 + State(db_actors): State<HashMap<String, ActorStorage, RandomState>>, 128 + State(account_manager): State<Arc<RwLock<AccountManager>>>, 129 + State(sequencer): State<Arc<RwLock<Sequencer>>>, 130 + Json(body): Json<CreateRecordInput>, 131 + ) -> Result<Json<CreateRecordOutput>, ApiError> { 132 + tracing::debug!("@LOG: debug create_record {body:#?}"); 133 + match inner_create_record(body, user, sequencer, db_actors, account_manager).await { 134 + Ok(res) => Ok(Json(res)), 135 + Err(error) => { 136 + tracing::error!("@LOG: ERROR: {error}"); 137 + Err(ApiError::RuntimeError) 138 + } 139 + } 140 + }
+117
src/apis/com/atproto/repo/delete_record.rs
··· 1 + //! Delete a repository record, or ensure it doesn't exist. Requires auth, implemented by PDS. 2 + use super::*; 3 + 4 + async fn inner_delete_record( 5 + body: DeleteRecordInput, 6 + user: AuthenticatedUser, 7 + sequencer: Arc<RwLock<Sequencer>>, 8 + actor_pools: HashMap<String, ActorStorage>, 9 + account_manager: Arc<RwLock<AccountManager>>, 10 + ) -> Result<()> { 11 + let DeleteRecordInput { 12 + repo, 13 + collection, 14 + rkey, 15 + swap_record, 16 + swap_commit, 17 + } = body; 18 + let account = account_manager 19 + .read() 20 + .await 21 + .get_account( 22 + &repo, 23 + Some(AvailabilityFlags { 24 + include_deactivated: Some(true), 25 + include_taken_down: None, 26 + }), 27 + ) 28 + .await?; 29 + match account { 30 + None => bail!("Could not find repo: `{repo}`"), 31 + Some(account) if account.deactivated_at.is_some() => bail!("Account is deactivated"), 32 + Some(account) => { 33 + let did = account.did; 34 + // if did != auth.access.credentials.unwrap().did.unwrap() { 35 + if did != user.did() { 36 + bail!("AuthRequiredError") 37 + } 38 + 39 + let swap_commit_cid = match swap_commit { 40 + Some(swap_commit) => Some(Cid::from_str(&swap_commit)?), 41 + None => None, 42 + }; 43 + let swap_record_cid = match swap_record { 44 + Some(swap_record) => Some(Cid::from_str(&swap_record)?), 45 + None => None, 46 + }; 47 + 48 + let write = prepare_delete(PrepareDeleteOpts { 49 + did: did.clone(), 50 + collection, 51 + rkey, 52 + swap_cid: swap_record_cid, 53 + })?; 54 + let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await; 55 + let write_at_uri: AtUri = write.uri.clone().try_into()?; 56 + let record = actor_store 57 + .record 58 + .get_record(&write_at_uri, None, Some(true)) 59 + .await?; 60 + let commit = match record { 61 + None => return Ok(()), // No-op if record already doesn't exist 62 + Some(_) => { 63 + actor_store 64 + .process_writes(vec![PreparedWrite::Delete(write.clone())], swap_commit_cid) 65 + .await? 66 + } 67 + }; 68 + 69 + _ = sequencer 70 + .write() 71 + .await 72 + .sequence_commit(did.clone(), commit.clone()) 73 + .await?; 74 + account_manager 75 + .write() 76 + .await 77 + .update_repo_root( 78 + did, 79 + commit.commit_data.cid, 80 + commit.commit_data.rev, 81 + &actor_pools, 82 + ) 83 + .await?; 84 + 85 + Ok(()) 86 + } 87 + } 88 + } 89 + 90 + /// Delete a repository record, or ensure it doesn't exist. Requires auth, implemented by PDS. 91 + /// - POST /xrpc/com.atproto.repo.deleteRecord 92 + /// ### Request Body 93 + /// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account). 94 + /// - `collection`: `nsid` // The NSID of the record collection. 95 + /// - `rkey`: `string` // The record key. <= 512 characters. 96 + /// - `swap_record`: `boolean` // Compare and swap with the previous record by CID. 97 + /// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID. 98 + /// ### Responses 99 + /// - 200 OK: {"commit": {"cid": "string","rev": "string"}} 100 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidSwap`]} 101 + /// - 401 Unauthorized 102 + #[axum::debug_handler(state = AppState)] 103 + pub async fn delete_record( 104 + user: AuthenticatedUser, 105 + State(db_actors): State<HashMap<String, ActorStorage, RandomState>>, 106 + State(account_manager): State<Arc<RwLock<AccountManager>>>, 107 + State(sequencer): State<Arc<RwLock<Sequencer>>>, 108 + Json(body): Json<DeleteRecordInput>, 109 + ) -> Result<(), ApiError> { 110 + match inner_delete_record(body, user, sequencer, db_actors, account_manager).await { 111 + Ok(()) => Ok(()), 112 + Err(error) => { 113 + tracing::error!("@LOG: ERROR: {error}"); 114 + Err(ApiError::RuntimeError) 115 + } 116 + } 117 + }
+70
src/apis/com/atproto/repo/describe_repo.rs
··· 1 + //! Get information about an account and repository, including the list of collections. Does not require auth. 2 + use super::*; 3 + 4 + async fn inner_describe_repo( 5 + repo: String, 6 + id_resolver: Arc<RwLock<IdResolver>>, 7 + actor_pools: HashMap<String, ActorStorage>, 8 + account_manager: Arc<RwLock<AccountManager>>, 9 + ) -> Result<DescribeRepoOutput> { 10 + let account = account_manager 11 + .read() 12 + .await 13 + .get_account(&repo, None) 14 + .await?; 15 + match account { 16 + None => bail!("Cound not find user: `{repo}`"), 17 + Some(account) => { 18 + let did_doc: DidDocument = match id_resolver 19 + .write() 20 + .await 21 + .did 22 + .ensure_resolve(&account.did, None) 23 + .await 24 + { 25 + Err(err) => bail!("Could not resolve DID: `{err}`"), 26 + Ok(res) => res, 27 + }; 28 + let handle = rsky_common::get_handle(&did_doc); 29 + let handle_is_correct = handle == account.handle; 30 + 31 + let actor_store = 32 + ActorStore::from_actor_pools(&account.did.clone(), &actor_pools).await; 33 + let collections = actor_store.record.list_collections().await?; 34 + 35 + Ok(DescribeRepoOutput { 36 + handle: account.handle.unwrap_or_else(|| INVALID_HANDLE.to_owned()), 37 + did: account.did, 38 + did_doc: serde_json::to_value(did_doc)?, 39 + collections, 40 + handle_is_correct, 41 + }) 42 + } 43 + } 44 + } 45 + 46 + /// Get information about an account and repository, including the list of collections. Does not require auth. 47 + /// - GET /xrpc/com.atproto.repo.describeRepo 48 + /// ### Query Parameters 49 + /// - `repo`: `at-identifier` // The handle or DID of the repo. 50 + /// ### Responses 51 + /// - 200 OK: {"handle": "string","did": "string","didDoc": {},"collections": [string],"handleIsCorrect": true} \ 52 + /// handeIsCorrect - boolean - Indicates if handle is currently valid (resolves bi-directionally) 53 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 54 + /// - 401 Unauthorized 55 + #[tracing::instrument(skip_all)] 56 + #[axum::debug_handler(state = AppState)] 57 + pub async fn describe_repo( 58 + Query(input): Query<atrium_repo::describe_repo::ParametersData>, 59 + State(db_actors): State<HashMap<String, ActorStorage, RandomState>>, 60 + State(account_manager): State<Arc<RwLock<AccountManager>>>, 61 + State(id_resolver): State<Arc<RwLock<IdResolver>>>, 62 + ) -> Result<Json<DescribeRepoOutput>, ApiError> { 63 + match inner_describe_repo(input.repo.into(), id_resolver, db_actors, account_manager).await { 64 + Ok(res) => Ok(Json(res)), 65 + Err(error) => { 66 + tracing::error!("{error:?}"); 67 + Err(ApiError::RuntimeError) 68 + } 69 + } 70 + }
+37
src/apis/com/atproto/repo/ex.rs
··· 1 + //! 2 + use crate::account_manager::AccountManager; 3 + use crate::serve::ActorStorage; 4 + use crate::{actor_store::ActorStore, error::ApiError, serve::AppState}; 5 + use anyhow::{Result, bail}; 6 + use axum::extract::Query; 7 + use axum::{Json, extract::State}; 8 + use rsky_identity::IdResolver; 9 + use rsky_pds::sequencer::Sequencer; 10 + use std::collections::HashMap; 11 + use std::hash::RandomState; 12 + use std::sync::Arc; 13 + use tokio::sync::RwLock; 14 + 15 + async fn fun( 16 + actor_pools: HashMap<String, ActorStorage>, 17 + account_manager: Arc<RwLock<AccountManager>>, 18 + id_resolver: Arc<RwLock<IdResolver>>, 19 + sequencer: Arc<RwLock<Sequencer>>, 20 + ) -> Result<_> { 21 + todo!(); 22 + } 23 + 24 + /// 25 + #[tracing::instrument(skip_all)] 26 + #[axum::debug_handler(state = AppState)] 27 + pub async fn fun( 28 + auth: AuthenticatedUser, 29 + Query(input): Query<atrium_api::com::atproto::repo::describe_repo::ParametersData>, 30 + State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>, 31 + State(account_manager): State<Arc<RwLock<AccountManager>>>, 32 + State(id_resolver): State<Arc<RwLock<IdResolver>>>, 33 + State(sequencer): State<Arc<RwLock<Sequencer>>>, 34 + Json(body): Json<ApplyWritesInput>, 35 + ) -> Result<Json<_>, ApiError> { 36 + todo!(); 37 + }
+102
src/apis/com/atproto/repo/get_record.rs
··· 1 + //! Get a single record from a repository. Does not require auth. 2 + 3 + use crate::pipethrough::{ProxyRequest, pipethrough}; 4 + 5 + use super::*; 6 + 7 + use rsky_pds::pipethrough::OverrideOpts; 8 + 9 + async fn inner_get_record( 10 + repo: String, 11 + collection: String, 12 + rkey: String, 13 + cid: Option<String>, 14 + req: ProxyRequest, 15 + actor_pools: HashMap<String, ActorStorage>, 16 + account_manager: Arc<RwLock<AccountManager>>, 17 + ) -> Result<GetRecordOutput> { 18 + let did = account_manager 19 + .read() 20 + .await 21 + .get_did_for_actor(&repo, None) 22 + .await?; 23 + 24 + // fetch from pds if available, if not then fetch from appview 25 + if let Some(did) = did { 26 + let uri = AtUri::make(did.clone(), Some(collection), Some(rkey))?; 27 + 28 + let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await; 29 + 30 + match actor_store.record.get_record(&uri, cid, None).await { 31 + Ok(Some(record)) if record.takedown_ref.is_none() => Ok(GetRecordOutput { 32 + uri: uri.to_string(), 33 + cid: Some(record.cid), 34 + value: serde_json::to_value(record.value)?, 35 + }), 36 + _ => bail!("Could not locate record: `{uri}`"), 37 + } 38 + } else { 39 + match req.cfg.bsky_app_view { 40 + None => bail!("Could not locate record"), 41 + Some(_) => match pipethrough( 42 + &req, 43 + None, 44 + OverrideOpts { 45 + aud: None, 46 + lxm: None, 47 + }, 48 + ) 49 + .await 50 + { 51 + Err(error) => { 52 + tracing::error!("@LOG: ERROR: {error}"); 53 + bail!("Could not locate record") 54 + } 55 + Ok(res) => { 56 + let output: GetRecordOutput = serde_json::from_slice(res.buffer.as_slice())?; 57 + Ok(output) 58 + } 59 + }, 60 + } 61 + } 62 + } 63 + 64 + /// Get a single record from a repository. Does not require auth. 65 + /// - GET /xrpc/com.atproto.repo.getRecord 66 + /// ### Query Parameters 67 + /// - `repo`: `at-identifier` // The handle or DID of the repo. 68 + /// - `collection`: `nsid` // The NSID of the record collection. 69 + /// - `rkey`: `string` // The record key. <= 512 characters. 70 + /// - `cid`: `cid` // The CID of the version of the record. If not specified, then return the most recent version. 71 + /// ### Responses 72 + /// - 200 OK: {"uri": "string","cid": "string","value": {}} 73 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RecordNotFound`]} 74 + /// - 401 Unauthorized 75 + #[tracing::instrument(skip_all)] 76 + #[axum::debug_handler(state = AppState)] 77 + pub async fn get_record( 78 + Query(input): Query<ParametersData>, 79 + State(db_actors): State<HashMap<String, ActorStorage, RandomState>>, 80 + State(account_manager): State<Arc<RwLock<AccountManager>>>, 81 + req: ProxyRequest, 82 + ) -> Result<Json<GetRecordOutput>, ApiError> { 83 + let repo = input.repo; 84 + let collection = input.collection; 85 + let rkey = input.rkey; 86 + let cid = input.cid; 87 + match inner_get_record(repo, collection, rkey, cid, req, db_actors, account_manager).await { 88 + Ok(res) => Ok(Json(res)), 89 + Err(error) => { 90 + tracing::error!("@LOG: ERROR: {error}"); 91 + Err(ApiError::RecordNotFound) 92 + } 93 + } 94 + } 95 + 96 + #[derive(serde::Deserialize, Debug)] 97 + pub struct ParametersData { 98 + pub cid: Option<String>, 99 + pub collection: String, 100 + pub repo: String, 101 + pub rkey: String, 102 + }
+183
src/apis/com/atproto/repo/import_repo.rs
··· 1 + use axum::{body::Bytes, http::HeaderMap}; 2 + use reqwest::header; 3 + use rsky_common::env::env_int; 4 + use rsky_repo::block_map::BlockMap; 5 + use rsky_repo::car::{CarWithRoot, read_stream_car_with_root}; 6 + use rsky_repo::parse::get_and_parse_record; 7 + use rsky_repo::repo::Repo; 8 + use rsky_repo::sync::consumer::{VerifyRepoInput, verify_diff}; 9 + use rsky_repo::types::{RecordWriteDescript, VerifiedDiff}; 10 + use ubyte::ToByteUnit; 11 + 12 + use super::*; 13 + 14 + async fn from_data(bytes: Bytes) -> Result<CarWithRoot, ApiError> { 15 + let max_import_size = env_int("IMPORT_REPO_LIMIT").unwrap_or(100).megabytes(); 16 + if bytes.len() > max_import_size { 17 + return Err(ApiError::InvalidRequest(format!( 18 + "Content-Length is greater than maximum of {max_import_size}" 19 + ))); 20 + } 21 + 22 + let mut cursor = std::io::Cursor::new(bytes); 23 + match read_stream_car_with_root(&mut cursor).await { 24 + Ok(car_with_root) => Ok(car_with_root), 25 + Err(error) => { 26 + tracing::error!("Error reading stream car with root\n{error}"); 27 + Err(ApiError::InvalidRequest("Invalid CAR file".to_owned())) 28 + } 29 + } 30 + } 31 + 32 + #[tracing::instrument(skip_all)] 33 + #[axum::debug_handler(state = AppState)] 34 + /// Import a repo in the form of a CAR file. Requires Content-Length HTTP header to be set. 35 + /// Request 36 + /// mime application/vnd.ipld.car 37 + /// Body - required 38 + pub async fn import_repo( 39 + // auth: AccessFullImport, 40 + auth: AuthenticatedUser, 41 + headers: HeaderMap, 42 + State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>, 43 + body: Bytes, 44 + ) -> Result<(), ApiError> { 45 + // let requester = auth.access.credentials.unwrap().did.unwrap(); 46 + let requester = auth.did(); 47 + let mut actor_store = ActorStore::from_actor_pools(&requester, &actor_pools).await; 48 + 49 + // Check headers 50 + let content_length = headers 51 + .get(header::CONTENT_LENGTH) 52 + .expect("no content length provided") 53 + .to_str() 54 + .map_err(anyhow::Error::from) 55 + .and_then(|content_length| content_length.parse::<u64>().map_err(anyhow::Error::from)) 56 + .expect("invalid content-length header"); 57 + if content_length > env_int("IMPORT_REPO_LIMIT").unwrap_or(100).megabytes() { 58 + return Err(ApiError::InvalidRequest(format!( 59 + "Content-Length is greater than maximum of {}", 60 + env_int("IMPORT_REPO_LIMIT").unwrap_or(100).megabytes() 61 + ))); 62 + }; 63 + 64 + // Get current repo if it exists 65 + let curr_root: Option<Cid> = actor_store.get_repo_root().await; 66 + let curr_repo: Option<Repo> = match curr_root { 67 + None => None, 68 + Some(_root) => Some(Repo::load(actor_store.storage.clone(), curr_root).await?), 69 + }; 70 + 71 + // Process imported car 72 + // let car_with_root = import_repo_input.car_with_root; 73 + let car_with_root: CarWithRoot = match from_data(body).await { 74 + Ok(car) => car, 75 + Err(error) => { 76 + tracing::error!("Error importing repo\n{error:?}"); 77 + return Err(ApiError::InvalidRequest("Invalid CAR file".to_owned())); 78 + } 79 + }; 80 + 81 + // Get verified difference from current repo and imported repo 82 + let mut imported_blocks: BlockMap = car_with_root.blocks; 83 + let imported_root: Cid = car_with_root.root; 84 + let opts = VerifyRepoInput { 85 + ensure_leaves: Some(false), 86 + }; 87 + 88 + let diff: VerifiedDiff = match verify_diff( 89 + curr_repo, 90 + &mut imported_blocks, 91 + imported_root, 92 + None, 93 + None, 94 + Some(opts), 95 + ) 96 + .await 97 + { 98 + Ok(res) => res, 99 + Err(error) => { 100 + tracing::error!("{:?}", error); 101 + return Err(ApiError::RuntimeError); 102 + } 103 + }; 104 + 105 + let commit_data = diff.commit; 106 + let prepared_writes: Vec<PreparedWrite> = 107 + prepare_import_repo_writes(requester, diff.writes, &imported_blocks).await?; 108 + match actor_store 109 + .process_import_repo(commit_data, prepared_writes) 110 + .await 111 + { 112 + Ok(_res) => {} 113 + Err(error) => { 114 + tracing::error!("Error importing repo\n{error}"); 115 + return Err(ApiError::RuntimeError); 116 + } 117 + } 118 + 119 + Ok(()) 120 + } 121 + 122 + /// Converts list of RecordWriteDescripts into a list of PreparedWrites 123 + async fn prepare_import_repo_writes( 124 + did: String, 125 + writes: Vec<RecordWriteDescript>, 126 + blocks: &BlockMap, 127 + ) -> Result<Vec<PreparedWrite>, ApiError> { 128 + match stream::iter(writes) 129 + .then(|write| { 130 + let did = did.clone(); 131 + async move { 132 + Ok::<PreparedWrite, anyhow::Error>(match write { 133 + RecordWriteDescript::Create(write) => { 134 + let parsed_record = get_and_parse_record(blocks, write.cid)?; 135 + PreparedWrite::Create( 136 + prepare_create(PrepareCreateOpts { 137 + did: did.clone(), 138 + collection: write.collection, 139 + rkey: Some(write.rkey), 140 + swap_cid: None, 141 + record: parsed_record.record, 142 + validate: Some(true), 143 + }) 144 + .await?, 145 + ) 146 + } 147 + RecordWriteDescript::Update(write) => { 148 + let parsed_record = get_and_parse_record(blocks, write.cid)?; 149 + PreparedWrite::Update( 150 + prepare_update(PrepareUpdateOpts { 151 + did: did.clone(), 152 + collection: write.collection, 153 + rkey: write.rkey, 154 + swap_cid: None, 155 + record: parsed_record.record, 156 + validate: Some(true), 157 + }) 158 + .await?, 159 + ) 160 + } 161 + RecordWriteDescript::Delete(write) => { 162 + PreparedWrite::Delete(prepare_delete(PrepareDeleteOpts { 163 + did: did.clone(), 164 + collection: write.collection, 165 + rkey: write.rkey, 166 + swap_cid: None, 167 + })?) 168 + } 169 + }) 170 + } 171 + }) 172 + .collect::<Vec<_>>() 173 + .await 174 + .into_iter() 175 + .collect::<Result<Vec<PreparedWrite>, _>>() 176 + { 177 + Ok(res) => Ok(res), 178 + Err(error) => { 179 + tracing::error!("Error preparing import repo writes\n{error}"); 180 + Err(ApiError::RuntimeError) 181 + } 182 + } 183 + }
+48
src/apis/com/atproto/repo/list_missing_blobs.rs
··· 1 + //! Returns a list of missing blobs for the requesting account. Intended to be used in the account migration flow. 2 + use rsky_lexicon::com::atproto::repo::ListMissingBlobsOutput; 3 + use rsky_pds::actor_store::blob::ListMissingBlobsOpts; 4 + 5 + use super::*; 6 + 7 + /// Returns a list of missing blobs for the requesting account. Intended to be used in the account migration flow. 8 + /// Request 9 + /// Query Parameters 10 + /// limit integer 11 + /// Possible values: >= 1 and <= 1000 12 + /// Default value: 500 13 + /// cursor string 14 + /// Responses 15 + /// cursor string 16 + /// blobs object[] 17 + #[tracing::instrument(skip_all)] 18 + #[axum::debug_handler(state = AppState)] 19 + pub async fn list_missing_blobs( 20 + user: AuthenticatedUser, 21 + Query(input): Query<atrium_repo::list_missing_blobs::ParametersData>, 22 + State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>, 23 + ) -> Result<Json<ListMissingBlobsOutput>, ApiError> { 24 + let cursor = input.cursor; 25 + let limit = input.limit; 26 + let default_limit: atrium_api::types::LimitedNonZeroU16<1000> = 27 + atrium_api::types::LimitedNonZeroU16::try_from(500).expect("default limit"); 28 + let limit: u16 = limit.unwrap_or(default_limit).into(); 29 + // let did = auth.access.credentials.unwrap().did.unwrap(); 30 + let did = user.did(); 31 + 32 + let actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await; 33 + 34 + match actor_store 35 + .blob 36 + .list_missing_blobs(ListMissingBlobsOpts { cursor, limit }) 37 + .await 38 + { 39 + Ok(blobs) => { 40 + let cursor = blobs.last().map(|last_blob| last_blob.cid.clone()); 41 + Ok(Json(ListMissingBlobsOutput { cursor, blobs })) 42 + } 43 + Err(error) => { 44 + tracing::error!("{error:?}"); 45 + Err(ApiError::RuntimeError) 46 + } 47 + } 48 + }
+146
src/apis/com/atproto/repo/list_records.rs
··· 1 + //! List a range of records in a repository, matching a specific collection. Does not require auth. 2 + use super::*; 3 + 4 + // #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)] 5 + // #[serde(rename_all = "camelCase")] 6 + // /// Parameters for [`list_records`]. 7 + // pub(super) struct ListRecordsParameters { 8 + // ///The NSID of the record type. 9 + // pub collection: Nsid, 10 + // /// The cursor to start from. 11 + // #[serde(skip_serializing_if = "core::option::Option::is_none")] 12 + // pub cursor: Option<String>, 13 + // ///The number of records to return. 14 + // #[serde(skip_serializing_if = "core::option::Option::is_none")] 15 + // pub limit: Option<String>, 16 + // ///The handle or DID of the repo. 17 + // pub repo: AtIdentifier, 18 + // ///Flag to reverse the order of the returned records. 19 + // #[serde(skip_serializing_if = "core::option::Option::is_none")] 20 + // pub reverse: Option<bool>, 21 + // ///DEPRECATED: The highest sort-ordered rkey to stop at (exclusive) 22 + // #[serde(skip_serializing_if = "core::option::Option::is_none")] 23 + // pub rkey_end: Option<String>, 24 + // ///DEPRECATED: The lowest sort-ordered rkey to start from (exclusive) 25 + // #[serde(skip_serializing_if = "core::option::Option::is_none")] 26 + // pub rkey_start: Option<String>, 27 + // } 28 + 29 + #[expect(non_snake_case, clippy::too_many_arguments)] 30 + async fn inner_list_records( 31 + // The handle or DID of the repo. 32 + repo: String, 33 + // The NSID of the record type. 34 + collection: String, 35 + // The number of records to return. 36 + limit: u16, 37 + cursor: Option<String>, 38 + // DEPRECATED: The lowest sort-ordered rkey to start from (exclusive) 39 + rkeyStart: Option<String>, 40 + // DEPRECATED: The highest sort-ordered rkey to stop at (exclusive) 41 + rkeyEnd: Option<String>, 42 + // Flag to reverse the order of the returned records. 43 + reverse: bool, 44 + // The actor pools 45 + actor_pools: HashMap<String, ActorStorage>, 46 + account_manager: Arc<RwLock<AccountManager>>, 47 + ) -> Result<ListRecordsOutput> { 48 + if limit > 100 { 49 + bail!("Error: limit can not be greater than 100") 50 + } 51 + let did = account_manager 52 + .read() 53 + .await 54 + .get_did_for_actor(&repo, None) 55 + .await?; 56 + if let Some(did) = did { 57 + let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await; 58 + 59 + let records: Vec<Record> = actor_store 60 + .record 61 + .list_records_for_collection( 62 + collection, 63 + limit as i64, 64 + reverse, 65 + cursor, 66 + rkeyStart, 67 + rkeyEnd, 68 + None, 69 + ) 70 + .await? 71 + .into_iter() 72 + .map(|record| { 73 + Ok(Record { 74 + uri: record.uri.clone(), 75 + cid: record.cid.clone(), 76 + value: serde_json::to_value(record)?, 77 + }) 78 + }) 79 + .collect::<Result<Vec<Record>>>()?; 80 + 81 + let last_record = records.last(); 82 + let cursor: Option<String>; 83 + if let Some(last_record) = last_record { 84 + let last_at_uri: AtUri = last_record.uri.clone().try_into()?; 85 + cursor = Some(last_at_uri.get_rkey()); 86 + } else { 87 + cursor = None; 88 + } 89 + Ok(ListRecordsOutput { records, cursor }) 90 + } else { 91 + bail!("Could not find repo: {repo}") 92 + } 93 + } 94 + 95 + /// List a range of records in a repository, matching a specific collection. Does not require auth. 96 + /// - GET /xrpc/com.atproto.repo.listRecords 97 + /// ### Query Parameters 98 + /// - `repo`: `at-identifier` // The handle or DID of the repo. 99 + /// - `collection`: `nsid` // The NSID of the record type. 100 + /// - `limit`: `integer` // The maximum number of records to return. Default 50, >=1 and <=100. 101 + /// - `cursor`: `string` 102 + /// - `reverse`: `boolean` // Flag to reverse the order of the returned records. 103 + /// ### Responses 104 + /// - 200 OK: {"cursor": "string","records": [{"uri": "string","cid": "string","value": {}}]} 105 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 106 + /// - 401 Unauthorized 107 + #[tracing::instrument(skip_all)] 108 + #[allow(non_snake_case)] 109 + #[axum::debug_handler(state = AppState)] 110 + pub async fn list_records( 111 + Query(input): Query<atrium_repo::list_records::ParametersData>, 112 + State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>, 113 + State(account_manager): State<Arc<RwLock<AccountManager>>>, 114 + ) -> Result<Json<ListRecordsOutput>, ApiError> { 115 + let repo = input.repo; 116 + let collection = input.collection; 117 + let limit: Option<u8> = input.limit.map(u8::from); 118 + let limit: Option<u16> = limit.map(|x| x.into()); 119 + let cursor = input.cursor; 120 + let reverse = input.reverse; 121 + let rkeyStart = None; 122 + let rkeyEnd = None; 123 + 124 + let limit = limit.unwrap_or(50); 125 + let reverse = reverse.unwrap_or(false); 126 + 127 + match inner_list_records( 128 + repo.into(), 129 + collection.into(), 130 + limit, 131 + cursor, 132 + rkeyStart, 133 + rkeyEnd, 134 + reverse, 135 + actor_pools, 136 + account_manager, 137 + ) 138 + .await 139 + { 140 + Ok(res) => Ok(Json(res)), 141 + Err(error) => { 142 + tracing::error!("@LOG: ERROR: {error}"); 143 + Err(ApiError::RuntimeError) 144 + } 145 + } 146 + }
+111
src/apis/com/atproto/repo/mod.rs
··· 1 + use atrium_api::com::atproto::repo as atrium_repo; 2 + use axum::{ 3 + Router, 4 + routing::{get, post}, 5 + }; 6 + use constcat::concat; 7 + 8 + pub mod apply_writes; 9 + pub mod create_record; 10 + pub mod delete_record; 11 + pub mod describe_repo; 12 + pub mod get_record; 13 + pub mod import_repo; 14 + pub mod list_missing_blobs; 15 + pub mod list_records; 16 + pub mod put_record; 17 + pub mod upload_blob; 18 + 19 + use crate::account_manager::AccountManager; 20 + use crate::account_manager::helpers::account::AvailabilityFlags; 21 + use crate::{ 22 + actor_store::ActorStore, 23 + auth::AuthenticatedUser, 24 + error::ApiError, 25 + serve::{ActorStorage, AppState}, 26 + }; 27 + use anyhow::{Result, bail}; 28 + use axum::extract::Query; 29 + use axum::{Json, extract::State}; 30 + use cidv10::Cid; 31 + use futures::stream::{self, StreamExt}; 32 + use rsky_identity::IdResolver; 33 + use rsky_identity::types::DidDocument; 34 + use rsky_lexicon::com::atproto::repo::DeleteRecordInput; 35 + use rsky_lexicon::com::atproto::repo::DescribeRepoOutput; 36 + use rsky_lexicon::com::atproto::repo::GetRecordOutput; 37 + use rsky_lexicon::com::atproto::repo::{ApplyWritesInput, ApplyWritesInputRefWrite}; 38 + use rsky_lexicon::com::atproto::repo::{CreateRecordInput, CreateRecordOutput}; 39 + use rsky_lexicon::com::atproto::repo::{ListRecordsOutput, Record}; 40 + // use rsky_pds::pipethrough::{OverrideOpts, ProxyRequest, pipethrough}; 41 + use rsky_pds::repo::prepare::{ 42 + PrepareCreateOpts, PrepareDeleteOpts, PrepareUpdateOpts, prepare_create, prepare_delete, 43 + prepare_update, 44 + }; 45 + use rsky_pds::sequencer::Sequencer; 46 + use rsky_repo::types::PreparedDelete; 47 + use rsky_repo::types::PreparedWrite; 48 + use rsky_syntax::aturi::AtUri; 49 + use rsky_syntax::handle::INVALID_HANDLE; 50 + use std::collections::HashMap; 51 + use std::hash::RandomState; 52 + use std::str::FromStr; 53 + use std::sync::Arc; 54 + use tokio::sync::RwLock; 55 + 56 + /// These endpoints are part of the atproto PDS repository management APIs. \ 57 + /// Requests usually require authentication (unlike the com.atproto.sync.* endpoints), and are made directly to the user's own PDS instance. 58 + /// ### Routes 59 + /// - AP /xrpc/com.atproto.repo.applyWrites -> [`apply_writes`] 60 + /// - AP /xrpc/com.atproto.repo.createRecord -> [`create_record`] 61 + /// - AP /xrpc/com.atproto.repo.putRecord -> [`put_record`] 62 + /// - AP /xrpc/com.atproto.repo.deleteRecord -> [`delete_record`] 63 + /// - AP /xrpc/com.atproto.repo.uploadBlob -> [`upload_blob`] 64 + /// - UG /xrpc/com.atproto.repo.describeRepo -> [`describe_repo`] 65 + /// - UG /xrpc/com.atproto.repo.getRecord -> [`get_record`] 66 + /// - UG /xrpc/com.atproto.repo.listRecords -> [`list_records`] 67 + /// - [ ] xx /xrpc/com.atproto.repo.importRepo 68 + // - [ ] xx /xrpc/com.atproto.repo.listMissingBlobs 69 + pub(crate) fn routes() -> Router<AppState> { 70 + Router::new() 71 + .route( 72 + concat!("/", atrium_repo::apply_writes::NSID), 73 + post(apply_writes::apply_writes), 74 + ) 75 + .route( 76 + concat!("/", atrium_repo::create_record::NSID), 77 + post(create_record::create_record), 78 + ) 79 + .route( 80 + concat!("/", atrium_repo::put_record::NSID), 81 + post(put_record::put_record), 82 + ) 83 + .route( 84 + concat!("/", atrium_repo::delete_record::NSID), 85 + post(delete_record::delete_record), 86 + ) 87 + .route( 88 + concat!("/", atrium_repo::upload_blob::NSID), 89 + post(upload_blob::upload_blob), 90 + ) 91 + .route( 92 + concat!("/", atrium_repo::describe_repo::NSID), 93 + get(describe_repo::describe_repo), 94 + ) 95 + .route( 96 + concat!("/", atrium_repo::get_record::NSID), 97 + get(get_record::get_record), 98 + ) 99 + .route( 100 + concat!("/", atrium_repo::import_repo::NSID), 101 + post(import_repo::import_repo), 102 + ) 103 + .route( 104 + concat!("/", atrium_repo::list_missing_blobs::NSID), 105 + get(list_missing_blobs::list_missing_blobs), 106 + ) 107 + .route( 108 + concat!("/", atrium_repo::list_records::NSID), 109 + get(list_records::list_records), 110 + ) 111 + }
+157
src/apis/com/atproto/repo/put_record.rs
··· 1 + //! Write a repository record, creating or updating it as needed. Requires auth, implemented by PDS. 2 + use anyhow::bail; 3 + use rsky_lexicon::com::atproto::repo::{PutRecordInput, PutRecordOutput}; 4 + use rsky_repo::types::CommitDataWithOps; 5 + 6 + use super::*; 7 + 8 + #[tracing::instrument(skip_all)] 9 + async fn inner_put_record( 10 + body: PutRecordInput, 11 + auth: AuthenticatedUser, 12 + sequencer: Arc<RwLock<Sequencer>>, 13 + actor_pools: HashMap<String, ActorStorage>, 14 + account_manager: Arc<RwLock<AccountManager>>, 15 + ) -> Result<PutRecordOutput> { 16 + let PutRecordInput { 17 + repo, 18 + collection, 19 + rkey, 20 + validate, 21 + record, 22 + swap_record, 23 + swap_commit, 24 + } = body; 25 + let account = account_manager 26 + .read() 27 + .await 28 + .get_account( 29 + &repo, 30 + Some(AvailabilityFlags { 31 + include_deactivated: Some(true), 32 + include_taken_down: None, 33 + }), 34 + ) 35 + .await?; 36 + if let Some(account) = account { 37 + if account.deactivated_at.is_some() { 38 + bail!("Account is deactivated") 39 + } 40 + let did = account.did; 41 + // if did != auth.access.credentials.unwrap().did.unwrap() { 42 + if did != auth.did() { 43 + bail!("AuthRequiredError") 44 + } 45 + let uri = AtUri::make(did.clone(), Some(collection.clone()), Some(rkey.clone()))?; 46 + let swap_commit_cid = match swap_commit { 47 + Some(swap_commit) => Some(Cid::from_str(&swap_commit)?), 48 + None => None, 49 + }; 50 + let swap_record_cid = match swap_record { 51 + Some(swap_record) => Some(Cid::from_str(&swap_record)?), 52 + None => None, 53 + }; 54 + let (commit, write): (Option<CommitDataWithOps>, PreparedWrite) = { 55 + let mut actor_store = ActorStore::from_actor_pools(&did, &actor_pools).await; 56 + 57 + let current = actor_store 58 + .record 59 + .get_record(&uri, None, Some(true)) 60 + .await?; 61 + tracing::debug!("@LOG: debug inner_put_record, current: {current:?}"); 62 + let write: PreparedWrite = if current.is_some() { 63 + PreparedWrite::Update( 64 + prepare_update(PrepareUpdateOpts { 65 + did: did.clone(), 66 + collection, 67 + rkey, 68 + swap_cid: swap_record_cid, 69 + record: serde_json::from_value(record)?, 70 + validate, 71 + }) 72 + .await?, 73 + ) 74 + } else { 75 + PreparedWrite::Create( 76 + prepare_create(PrepareCreateOpts { 77 + did: did.clone(), 78 + collection, 79 + rkey: Some(rkey), 80 + swap_cid: swap_record_cid, 81 + record: serde_json::from_value(record)?, 82 + validate, 83 + }) 84 + .await?, 85 + ) 86 + }; 87 + 88 + match current { 89 + Some(current) if current.cid == write.cid().expect("write cid").to_string() => { 90 + (None, write) 91 + } 92 + _ => { 93 + let commit = actor_store 94 + .process_writes(vec![write.clone()], swap_commit_cid) 95 + .await?; 96 + (Some(commit), write) 97 + } 98 + } 99 + }; 100 + 101 + if let Some(commit) = commit { 102 + _ = sequencer 103 + .write() 104 + .await 105 + .sequence_commit(did.clone(), commit.clone()) 106 + .await?; 107 + account_manager 108 + .write() 109 + .await 110 + .update_repo_root( 111 + did, 112 + commit.commit_data.cid, 113 + commit.commit_data.rev, 114 + &actor_pools, 115 + ) 116 + .await?; 117 + } 118 + Ok(PutRecordOutput { 119 + uri: write.uri().to_string(), 120 + cid: write.cid().expect("write cid").to_string(), 121 + }) 122 + } else { 123 + bail!("Could not find repo: `{repo}`") 124 + } 125 + } 126 + 127 + /// Write a repository record, creating or updating it as needed. Requires auth, implemented by PDS. 128 + /// - POST /xrpc/com.atproto.repo.putRecord 129 + /// ### Request Body 130 + /// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account). 131 + /// - `collection`: `nsid` // The NSID of the record collection. 132 + /// - `rkey`: `string` // The record key. <= 512 characters. 133 + /// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data, 'true' to require it, or leave unset to validate only for known Lexicons. 134 + /// - `record` 135 + /// - `swap_record`: `boolean` // Compare and swap with the previous record by CID. WARNING: nullable and optional field; may cause problems with golang implementation 136 + /// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID. 137 + /// ### Responses 138 + /// - 200 OK: {"uri": "string","cid": "string","commit": {"cid": "string","rev": "string"},"validationStatus": "valid | unknown"} 139 + /// - 400 Bad Request: {error:"`InvalidRequest` | `ExpiredToken` | `InvalidToken` | `InvalidSwap`"} 140 + /// - 401 Unauthorized 141 + #[tracing::instrument(skip_all)] 142 + pub async fn put_record( 143 + auth: AuthenticatedUser, 144 + State(sequencer): State<Arc<RwLock<Sequencer>>>, 145 + State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>, 146 + State(account_manager): State<Arc<RwLock<AccountManager>>>, 147 + Json(body): Json<PutRecordInput>, 148 + ) -> Result<Json<PutRecordOutput>, ApiError> { 149 + tracing::debug!("@LOG: debug put_record {body:#?}"); 150 + match inner_put_record(body, auth, sequencer, actor_pools, account_manager).await { 151 + Ok(res) => Ok(Json(res)), 152 + Err(error) => { 153 + tracing::error!("@LOG: ERROR: {error}"); 154 + Err(ApiError::RuntimeError) 155 + } 156 + } 157 + }
+117
src/apis/com/atproto/repo/upload_blob.rs
··· 1 + //! Upload a new blob, to be referenced from a repository record. 2 + use crate::config::AppConfig; 3 + use anyhow::Context as _; 4 + use axum::{ 5 + body::Bytes, 6 + http::{self, HeaderMap}, 7 + }; 8 + use rsky_lexicon::com::atproto::repo::{Blob, BlobOutput}; 9 + use rsky_repo::types::{BlobConstraint, PreparedBlobRef}; 10 + // use rsky_common::BadContentTypeError; 11 + 12 + use super::*; 13 + 14 + async fn inner_upload_blob( 15 + auth: AuthenticatedUser, 16 + blob: Bytes, 17 + content_type: String, 18 + actor_pools: HashMap<String, ActorStorage>, 19 + ) -> Result<BlobOutput> { 20 + // let requester = auth.access.credentials.unwrap().did.unwrap(); 21 + let requester = auth.did(); 22 + 23 + let actor_store = ActorStore::from_actor_pools(&requester, &actor_pools).await; 24 + 25 + let metadata = actor_store 26 + .blob 27 + .upload_blob_and_get_metadata(content_type, blob) 28 + .await?; 29 + let blobref = actor_store.blob.track_untethered_blob(metadata).await?; 30 + 31 + // make the blob permanent if an associated record is already indexed 32 + let records_for_blob = actor_store 33 + .blob 34 + .get_records_for_blob(blobref.get_cid()?) 35 + .await?; 36 + 37 + if !records_for_blob.is_empty() { 38 + actor_store 39 + .blob 40 + .verify_blob_and_make_permanent(PreparedBlobRef { 41 + cid: blobref.get_cid()?, 42 + mime_type: blobref.get_mime_type().to_string(), 43 + constraints: BlobConstraint { 44 + max_size: None, 45 + accept: None, 46 + }, 47 + }) 48 + .await?; 49 + } 50 + 51 + Ok(BlobOutput { 52 + blob: Blob { 53 + r#type: Some("blob".to_owned()), 54 + r#ref: Some(blobref.get_cid()?), 55 + cid: None, 56 + mime_type: blobref.get_mime_type().to_string(), 57 + size: blobref.get_size(), 58 + original: None, 59 + }, 60 + }) 61 + } 62 + 63 + /// Upload a new blob, to be referenced from a repository record. \ 64 + /// The blob will be deleted if it is not referenced within a time window (eg, minutes). \ 65 + /// Blob restrictions (mimetype, size, etc) are enforced when the reference is created. \ 66 + /// Requires auth, implemented by PDS. 67 + /// - POST /xrpc/com.atproto.repo.uploadBlob 68 + /// ### Request Body 69 + /// ### Responses 70 + /// - 200 OK: {"blob": "binary"} 71 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 72 + /// - 401 Unauthorized 73 + #[tracing::instrument(skip_all)] 74 + #[axum::debug_handler(state = AppState)] 75 + pub async fn upload_blob( 76 + auth: AuthenticatedUser, 77 + headers: HeaderMap, 78 + State(config): State<AppConfig>, 79 + State(actor_pools): State<HashMap<String, ActorStorage, RandomState>>, 80 + blob: Bytes, 81 + ) -> Result<Json<BlobOutput>, ApiError> { 82 + let content_length = headers 83 + .get(http::header::CONTENT_LENGTH) 84 + .context("no content length provided")? 85 + .to_str() 86 + .map_err(anyhow::Error::from) 87 + .and_then(|content_length| content_length.parse::<u64>().map_err(anyhow::Error::from)) 88 + .context("invalid content-length header")?; 89 + let content_type = headers 90 + .get(http::header::CONTENT_TYPE) 91 + .context("no content-type provided")? 92 + .to_str() 93 + // .map_err(BadContentTypeError::MissingType) 94 + .context("invalid content-type provided")? 95 + .to_owned(); 96 + 97 + if content_length > config.blob.limit { 98 + return Err(ApiError::InvalidRequest(format!( 99 + "Content-Length is greater than maximum of {}", 100 + config.blob.limit 101 + ))); 102 + }; 103 + if blob.len() as u64 > config.blob.limit { 104 + return Err(ApiError::InvalidRequest(format!( 105 + "Blob size is greater than maximum of {} despite content-length header", 106 + config.blob.limit 107 + ))); 108 + }; 109 + 110 + match inner_upload_blob(auth, blob, content_type, actor_pools).await { 111 + Ok(res) => Ok(Json(res)), 112 + Err(error) => { 113 + tracing::error!("{error:?}"); 114 + Err(ApiError::RuntimeError) 115 + } 116 + } 117 + }
+791
src/apis/com/atproto/server/server.rs
··· 1 + //! Server endpoints. (/xrpc/com.atproto.server.*) 2 + use std::{collections::HashMap, str::FromStr as _}; 3 + 4 + use anyhow::{Context as _, anyhow}; 5 + use argon2::{ 6 + Argon2, PasswordHash, PasswordHasher as _, PasswordVerifier as _, password_hash::SaltString, 7 + }; 8 + use atrium_api::{ 9 + com::atproto::server, 10 + types::string::{Datetime, Did, Handle, Tid}, 11 + }; 12 + use atrium_crypto::keypair::Did as _; 13 + use atrium_repo::{ 14 + Cid, Repository, 15 + blockstore::{AsyncBlockStoreWrite as _, CarStore, DAG_CBOR, SHA2_256}, 16 + }; 17 + use axum::{ 18 + Json, Router, 19 + extract::{Query, Request, State}, 20 + http::StatusCode, 21 + routing::{get, post}, 22 + }; 23 + use constcat::concat; 24 + use metrics::counter; 25 + use rand::Rng as _; 26 + use sha2::Digest as _; 27 + use uuid::Uuid; 28 + 29 + use crate::{ 30 + AppState, Client, Db, Error, Result, RotationKey, SigningKey, 31 + auth::{self, AuthenticatedUser}, 32 + config::AppConfig, 33 + firehose::{Commit, FirehoseProducer}, 34 + metrics::AUTH_FAILED, 35 + plc::{self, PlcOperation, PlcService}, 36 + storage, 37 + }; 38 + 39 + /// This is a dummy password that can be used in absence of a real password. 40 + const DUMMY_PASSWORD: &str = "$argon2id$v=19$m=19456,t=2,p=1$En2LAfHjeO0SZD5IUU1Abg$RpS8nHhhqY4qco2uyd41p9Y/1C+Lvi214MAWukzKQMI"; 41 + 42 + /// Create an invite code. 43 + /// - POST /xrpc/com.atproto.server.createInviteCode 44 + /// ### Request Body 45 + /// - `useCount`: integer 46 + /// - `forAccount`: string (optional) 47 + /// ### Responses 48 + /// - 200 OK: {code: string} 49 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 50 + /// - 401 Unauthorized 51 + async fn create_invite_code( 52 + _user: AuthenticatedUser, 53 + State(db): State<Db>, 54 + Json(input): Json<server::create_invite_code::Input>, 55 + ) -> Result<Json<server::create_invite_code::Output>> { 56 + let uuid = Uuid::new_v4().to_string(); 57 + let did = input.for_account.as_deref(); 58 + let count = std::cmp::min(input.use_count, 100); // Maximum of 100 uses for any code. 59 + 60 + if count <= 0 { 61 + return Err(anyhow!("use_count must be greater than 0").into()); 62 + } 63 + 64 + Ok(Json( 65 + server::create_invite_code::OutputData { 66 + code: sqlx::query_scalar!( 67 + r#" 68 + INSERT INTO invites (id, did, count, created_at) 69 + VALUES (?, ?, ?, datetime('now')) 70 + RETURNING id 71 + "#, 72 + uuid, 73 + did, 74 + count, 75 + ) 76 + .fetch_one(&db) 77 + .await 78 + .context("failed to create new invite code")?, 79 + } 80 + .into(), 81 + )) 82 + } 83 + 84 + #[expect(clippy::too_many_lines, reason = "TODO: refactor")] 85 + /// Create an account. Implemented by PDS. 86 + /// - POST /xrpc/com.atproto.server.createAccount 87 + /// ### Request Body 88 + /// - `email`: string 89 + /// - `handle`: string (required) 90 + /// - `did`: string - Pre-existing atproto DID, being imported to a new account. 91 + /// - `inviteCode`: string 92 + /// - `verificationCode`: string 93 + /// - `verificationPhone`: string 94 + /// - `password`: string - Initial account password. May need to meet instance-specific password strength requirements. 95 + /// - `recoveryKey`: string - DID PLC rotation key (aka, recovery key) to be included in PLC creation operation. 96 + /// - `plcOp`: object 97 + /// ## Responses 98 + /// - 200 OK: {"accessJwt": "string","refreshJwt": "string","handle": "string","did": "string","didDoc": {}} 99 + /// - 400 Bad Request: {error: [`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidHandle`, `InvalidPassword`, \ 100 + /// `InvalidInviteCode`, `HandleNotAvailable`, `UnsupportedDomain`, `UnresolvableDid`, `IncompatibleDidDoc`)} 101 + /// - 401 Unauthorized 102 + async fn create_account( 103 + State(db): State<Db>, 104 + State(skey): State<SigningKey>, 105 + State(rkey): State<RotationKey>, 106 + State(client): State<Client>, 107 + State(config): State<AppConfig>, 108 + State(fhp): State<FirehoseProducer>, 109 + Json(input): Json<server::create_account::Input>, 110 + ) -> Result<Json<server::create_account::Output>> { 111 + let email = input.email.as_deref().context("no email provided")?; 112 + // Hash the user's password. 113 + let pass = Argon2::default() 114 + .hash_password( 115 + input 116 + .password 117 + .as_deref() 118 + .context("no password provided")? 119 + .as_bytes(), 120 + SaltString::generate(&mut rand::thread_rng()).as_salt(), 121 + ) 122 + .context("failed to hash password")? 123 + .to_string(); 124 + let handle = input.handle.as_str().to_owned(); 125 + 126 + // TODO: Handle the account migration flow. 127 + // Users will hit this endpoint with a service-level authentication token. 128 + // 129 + // https://github.com/bluesky-social/pds/blob/main/ACCOUNT_MIGRATION.md 130 + 131 + // TODO: `input.plc_op` 132 + if input.plc_op.is_some() { 133 + return Err(Error::unimplemented(anyhow!("plc_op"))); 134 + } 135 + 136 + let recovery_keys = if let Some(ref key) = input.recovery_key { 137 + // Ensure the provided recovery key is valid. 138 + if let Err(error) = atrium_crypto::did::parse_did_key(key) { 139 + return Err(Error::with_status( 140 + StatusCode::BAD_REQUEST, 141 + anyhow::Error::new(error).context("provided recovery key is in invalid format"), 142 + )); 143 + } 144 + 145 + // Enroll the user-provided recovery key at a higher priority than our own. 146 + vec![key.clone(), rkey.did()] 147 + } else { 148 + vec![rkey.did()] 149 + }; 150 + 151 + // Begin a new transaction to actually create the user's profile. 152 + // Unless committed, the transaction will be automatically rolled back. 153 + let mut tx = db.begin().await.context("failed to begin transaction")?; 154 + 155 + // TODO: Make this its own toggle instead of tied to test mode 156 + if !config.test { 157 + let _invite = match input.invite_code { 158 + Some(ref code) => { 159 + let invite: Option<String> = sqlx::query_scalar!( 160 + r#" 161 + UPDATE invites 162 + SET count = count - 1 163 + WHERE id = ? 164 + AND count > 0 165 + RETURNING id 166 + "#, 167 + code 168 + ) 169 + .fetch_optional(&mut *tx) 170 + .await 171 + .context("failed to check invite code")?; 172 + 173 + invite.context("invalid invite code")? 174 + } 175 + None => { 176 + return Err(anyhow!("invite code required").into()); 177 + } 178 + }; 179 + } 180 + 181 + // Account can be created. Synthesize a new DID for the user. 182 + // https://github.com/did-method-plc/did-method-plc?tab=readme-ov-file#did-creation 183 + let op = plc::sign_op( 184 + &rkey, 185 + PlcOperation { 186 + typ: "plc_operation".to_owned(), 187 + rotation_keys: recovery_keys, 188 + verification_methods: HashMap::from([("atproto".to_owned(), skey.did())]), 189 + also_known_as: vec![format!("at://{}", input.handle.as_str())], 190 + services: HashMap::from([( 191 + "atproto_pds".to_owned(), 192 + PlcService::Pds { 193 + endpoint: format!("https://{}", config.host_name), 194 + }, 195 + )]), 196 + prev: None, 197 + }, 198 + ) 199 + .context("failed to sign genesis op")?; 200 + let op_bytes = serde_ipld_dagcbor::to_vec(&op).context("failed to encode genesis op")?; 201 + 202 + let did_hash = { 203 + let digest = base32::encode( 204 + base32::Alphabet::Rfc4648Lower { padding: false }, 205 + sha2::Sha256::digest(&op_bytes).as_slice(), 206 + ); 207 + if digest.len() < 24 { 208 + return Err(anyhow!("digest too short").into()); 209 + } 210 + #[expect(clippy::string_slice, reason = "digest length confirmed")] 211 + digest[..24].to_owned() 212 + }; 213 + let did = format!("did:plc:{did_hash}"); 214 + 215 + let doc = tokio::fs::File::create(config.plc.path.join(format!("{did_hash}.car"))) 216 + .await 217 + .context("failed to create did doc")?; 218 + 219 + let mut plc_doc = CarStore::create(doc) 220 + .await 221 + .context("failed to create did doc")?; 222 + 223 + let plc_cid = plc_doc 224 + .write_block(DAG_CBOR, SHA2_256, &op_bytes) 225 + .await 226 + .context("failed to write genesis commit")? 227 + .to_string(); 228 + 229 + if !config.test { 230 + // Send the new account's data to the PLC directory. 231 + plc::submit(&client, &did, &op) 232 + .await 233 + .context("failed to submit PLC operation to directory")?; 234 + } 235 + 236 + // Write out an initial commit for the user. 237 + // https://atproto.com/guides/account-lifecycle 238 + let (cid, rev, store) = async { 239 + let store = storage::create_storage_for_did(&config.repo, &did_hash) 240 + .await 241 + .context("failed to create storage")?; 242 + 243 + // Initialize the repository with the storage 244 + let repo_builder = Repository::create( 245 + store, 246 + Did::from_str(&did).expect("should be valid DID format"), 247 + ) 248 + .await 249 + .context("failed to initialize user repo")?; 250 + 251 + // Sign the root commit. 252 + let sig = skey 253 + .sign(&repo_builder.bytes()) 254 + .context("failed to sign root commit")?; 255 + let mut repo = repo_builder 256 + .finalize(sig) 257 + .await 258 + .context("failed to attach signature to root commit")?; 259 + 260 + let root = repo.root(); 261 + let rev = repo.commit().rev(); 262 + 263 + // Create a temporary CAR store for firehose events 264 + let mut mem = Vec::new(); 265 + let mut firehose_store = 266 + CarStore::create_with_roots(std::io::Cursor::new(&mut mem), [repo.root()]) 267 + .await 268 + .context("failed to create temp carstore")?; 269 + 270 + repo.export_into(&mut firehose_store) 271 + .await 272 + .context("failed to export repository")?; 273 + 274 + Ok::<(Cid, Tid, Vec<u8>), anyhow::Error>((root, rev, mem)) 275 + } 276 + .await 277 + .context("failed to create user repo")?; 278 + 279 + let cid_str = cid.to_string(); 280 + let rev_str = rev.as_str(); 281 + 282 + _ = sqlx::query!( 283 + r#" 284 + INSERT INTO accounts (did, email, password, root, plc_root, rev, created_at) 285 + VALUES (?, ?, ?, ?, ?, ?, datetime('now')); 286 + 287 + INSERT INTO handles (did, handle, created_at) 288 + VALUES (?, ?, datetime('now')); 289 + 290 + -- Cleanup stale invite codes 291 + DELETE FROM invites 292 + WHERE count <= 0; 293 + "#, 294 + did, 295 + email, 296 + pass, 297 + cid_str, 298 + plc_cid, 299 + rev_str, 300 + did, 301 + handle 302 + ) 303 + .execute(&mut *tx) 304 + .await 305 + .context("failed to create new account")?; 306 + 307 + // The account is fully created. Commit the SQL transaction to the database. 308 + tx.commit().await.context("failed to commit transaction")?; 309 + 310 + // Broadcast the identity event now that the new identity is resolvable on the public directory. 311 + fhp.identity( 312 + atrium_api::com::atproto::sync::subscribe_repos::IdentityData { 313 + did: Did::from_str(&did).expect("should be valid DID format"), 314 + handle: Some(Handle::new(handle).expect("should be valid handle")), 315 + seq: 0, // Filled by firehose later. 316 + time: Datetime::now(), 317 + }, 318 + ) 319 + .await; 320 + 321 + // The new account is now active on this PDS, so we can broadcast the account firehose event. 322 + fhp.account( 323 + atrium_api::com::atproto::sync::subscribe_repos::AccountData { 324 + active: true, 325 + did: Did::from_str(&did).expect("should be valid DID format"), 326 + seq: 0, // Filled by firehose later. 327 + status: None, // "takedown" / "suspended" / "deactivated" 328 + time: Datetime::now(), 329 + }, 330 + ) 331 + .await; 332 + 333 + let did = Did::from_str(&did).expect("should be valid DID format"); 334 + 335 + fhp.commit(Commit { 336 + car: store, 337 + ops: Vec::new(), 338 + cid, 339 + rev: rev.to_string(), 340 + did: did.clone(), 341 + pcid: None, 342 + blobs: Vec::new(), 343 + }) 344 + .await; 345 + 346 + // Finally, sign some authentication tokens for the new user. 347 + let token = auth::sign( 348 + &skey, 349 + "at+jwt", 350 + &serde_json::json!({ 351 + "scope": "com.atproto.access", 352 + "sub": did, 353 + "iat": chrono::Utc::now().timestamp(), 354 + "exp": chrono::Utc::now().checked_add_signed(chrono::Duration::hours(4)).context("should be valid time")?.timestamp(), 355 + "aud": format!("did:web:{}", config.host_name) 356 + }), 357 + ) 358 + .context("failed to sign jwt")?; 359 + 360 + let refresh_token = auth::sign( 361 + &skey, 362 + "refresh+jwt", 363 + &serde_json::json!({ 364 + "scope": "com.atproto.refresh", 365 + "sub": did, 366 + "iat": chrono::Utc::now().timestamp(), 367 + "exp": chrono::Utc::now().checked_add_days(chrono::Days::new(90)).context("should be valid time")?.timestamp(), 368 + "aud": format!("did:web:{}", config.host_name) 369 + }), 370 + ) 371 + .context("failed to sign refresh jwt")?; 372 + 373 + Ok(Json( 374 + server::create_account::OutputData { 375 + access_jwt: token, 376 + did, 377 + did_doc: None, 378 + handle: input.handle.clone(), 379 + refresh_jwt: refresh_token, 380 + } 381 + .into(), 382 + )) 383 + } 384 + 385 + /// Create an authentication session. 386 + /// - POST /xrpc/com.atproto.server.createSession 387 + /// ### Request Body 388 + /// - `identifier`: string - Handle or other identifier supported by the server for the authenticating user. 389 + /// - `password`: string - Password for the authenticating user. 390 + /// - `authFactorToken` - string (optional) 391 + /// - `allowTakedown` - boolean (optional) - When true, instead of throwing error for takendown accounts, a valid response with a narrow scoped token will be returned 392 + /// ### Responses 393 + /// - 200 OK: {"accessJwt": "string","refreshJwt": "string","handle": "string","did": "string","didDoc": {},"email": "string","emailConfirmed": true,"emailAuthFactor": true,"active": true,"status": "takendown"} 394 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `AccountTakedown`, `AuthFactorTokenRequired`]} 395 + /// - 401 Unauthorized 396 + async fn create_session( 397 + State(db): State<Db>, 398 + State(skey): State<SigningKey>, 399 + State(config): State<AppConfig>, 400 + Json(input): Json<server::create_session::Input>, 401 + ) -> Result<Json<server::create_session::Output>> { 402 + let handle = &input.identifier; 403 + let password = &input.password; 404 + 405 + // TODO: `input.allow_takedown` 406 + // TODO: `input.auth_factor_token` 407 + 408 + let Some(account) = sqlx::query!( 409 + r#" 410 + WITH LatestHandles AS ( 411 + SELECT did, handle 412 + FROM handles 413 + WHERE (did, created_at) IN ( 414 + SELECT did, MAX(created_at) AS max_created_at 415 + FROM handles 416 + GROUP BY did 417 + ) 418 + ) 419 + SELECT a.did, a.password, h.handle 420 + FROM accounts a 421 + LEFT JOIN LatestHandles h ON a.did = h.did 422 + WHERE h.handle = ? 423 + "#, 424 + handle 425 + ) 426 + .fetch_optional(&db) 427 + .await 428 + .context("failed to authenticate")? 429 + else { 430 + counter!(AUTH_FAILED).increment(1); 431 + 432 + // SEC: Call argon2's `verify_password` to simulate password verification and discard the result. 433 + // We do this to avoid exposing a timing attack where attackers can measure the response time to 434 + // determine whether or not an account exists. 435 + _ = Argon2::default().verify_password( 436 + password.as_bytes(), 437 + &PasswordHash::new(DUMMY_PASSWORD).context("should be valid password hash")?, 438 + ); 439 + 440 + return Err(Error::with_status( 441 + StatusCode::UNAUTHORIZED, 442 + anyhow!("failed to validate credentials"), 443 + )); 444 + }; 445 + 446 + match Argon2::default().verify_password( 447 + password.as_bytes(), 448 + &PasswordHash::new(account.password.as_str()).context("invalid password hash in db")?, 449 + ) { 450 + Ok(()) => {} 451 + Err(_e) => { 452 + counter!(AUTH_FAILED).increment(1); 453 + 454 + return Err(Error::with_status( 455 + StatusCode::UNAUTHORIZED, 456 + anyhow!("failed to validate credentials"), 457 + )); 458 + } 459 + } 460 + 461 + let did = account.did; 462 + 463 + let token = auth::sign( 464 + &skey, 465 + "at+jwt", 466 + &serde_json::json!({ 467 + "scope": "com.atproto.access", 468 + "sub": did, 469 + "iat": chrono::Utc::now().timestamp(), 470 + "exp": chrono::Utc::now().checked_add_signed(chrono::Duration::hours(4)).context("should be valid time")?.timestamp(), 471 + "aud": format!("did:web:{}", config.host_name) 472 + }), 473 + ) 474 + .context("failed to sign jwt")?; 475 + 476 + let refresh_token = auth::sign( 477 + &skey, 478 + "refresh+jwt", 479 + &serde_json::json!({ 480 + "scope": "com.atproto.refresh", 481 + "sub": did, 482 + "iat": chrono::Utc::now().timestamp(), 483 + "exp": chrono::Utc::now().checked_add_days(chrono::Days::new(90)).context("should be valid time")?.timestamp(), 484 + "aud": format!("did:web:{}", config.host_name) 485 + }), 486 + ) 487 + .context("failed to sign refresh jwt")?; 488 + 489 + Ok(Json( 490 + server::create_session::OutputData { 491 + access_jwt: token, 492 + refresh_jwt: refresh_token, 493 + 494 + active: Some(true), 495 + did: Did::from_str(&did).expect("should be valid DID format"), 496 + did_doc: None, 497 + email: None, 498 + email_auth_factor: None, 499 + email_confirmed: None, 500 + handle: Handle::new(account.handle).expect("should be valid handle"), 501 + status: None, 502 + } 503 + .into(), 504 + )) 505 + } 506 + 507 + /// Refresh an authentication session. Requires auth using the 'refreshJwt' (not the 'accessJwt'). 508 + /// - POST /xrpc/com.atproto.server.refreshSession 509 + /// ### Responses 510 + /// - 200 OK: {"accessJwt": "string","refreshJwt": "string","handle": "string","did": "string","didDoc": {},"active": true,"status": "takendown"} 511 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `AccountTakedown`]} 512 + /// - 401 Unauthorized 513 + async fn refresh_session( 514 + State(db): State<Db>, 515 + State(skey): State<SigningKey>, 516 + State(config): State<AppConfig>, 517 + req: Request, 518 + ) -> Result<Json<server::refresh_session::Output>> { 519 + // TODO: store hashes of refresh tokens and enforce single-use 520 + let auth_token = req 521 + .headers() 522 + .get(axum::http::header::AUTHORIZATION) 523 + .context("no authorization header provided")? 524 + .to_str() 525 + .ok() 526 + .and_then(|auth| auth.strip_prefix("Bearer ")) 527 + .context("invalid authentication token")?; 528 + 529 + let (typ, claims) = 530 + auth::verify(&skey.did(), auth_token).context("failed to verify refresh token")?; 531 + if typ != "refresh+jwt" { 532 + return Err(Error::with_status( 533 + StatusCode::UNAUTHORIZED, 534 + anyhow!("invalid refresh token"), 535 + )); 536 + } 537 + if claims 538 + .get("exp") 539 + .and_then(serde_json::Value::as_i64) 540 + .context("failed to get `exp`")? 541 + < chrono::Utc::now().timestamp() 542 + { 543 + return Err(Error::with_status( 544 + StatusCode::UNAUTHORIZED, 545 + anyhow!("refresh token expired"), 546 + )); 547 + } 548 + if claims 549 + .get("aud") 550 + .and_then(|audience| audience.as_str()) 551 + .context("invalid jwt")? 552 + != format!("did:web:{}", config.host_name) 553 + { 554 + return Err(Error::with_status( 555 + StatusCode::UNAUTHORIZED, 556 + anyhow!("invalid audience"), 557 + )); 558 + } 559 + 560 + let did = claims 561 + .get("sub") 562 + .and_then(|subject| subject.as_str()) 563 + .context("invalid jwt")?; 564 + 565 + let user = sqlx::query!( 566 + r#" 567 + SELECT a.status, h.handle 568 + FROM accounts a 569 + JOIN handles h ON a.did = h.did 570 + WHERE a.did = ? 571 + ORDER BY h.created_at ASC 572 + LIMIT 1 573 + "#, 574 + did 575 + ) 576 + .fetch_one(&db) 577 + .await 578 + .context("failed to fetch user account")?; 579 + 580 + let token = auth::sign( 581 + &skey, 582 + "at+jwt", 583 + &serde_json::json!({ 584 + "scope": "com.atproto.access", 585 + "sub": did, 586 + "iat": chrono::Utc::now().timestamp(), 587 + "exp": chrono::Utc::now().checked_add_signed(chrono::Duration::hours(4)).context("should be valid time")?.timestamp(), 588 + "aud": format!("did:web:{}", config.host_name) 589 + }), 590 + ) 591 + .context("failed to sign jwt")?; 592 + 593 + let refresh_token = auth::sign( 594 + &skey, 595 + "refresh+jwt", 596 + &serde_json::json!({ 597 + "scope": "com.atproto.refresh", 598 + "sub": did, 599 + "iat": chrono::Utc::now().timestamp(), 600 + "exp": chrono::Utc::now().checked_add_days(chrono::Days::new(90)).context("should be valid time")?.timestamp(), 601 + "aud": format!("did:web:{}", config.host_name) 602 + }), 603 + ) 604 + .context("failed to sign refresh jwt")?; 605 + 606 + let active = user.status == "active"; 607 + let status = if active { None } else { Some(user.status) }; 608 + 609 + Ok(Json( 610 + server::refresh_session::OutputData { 611 + access_jwt: token, 612 + refresh_jwt: refresh_token, 613 + 614 + active: Some(active), // TODO? 615 + did: Did::new(did.to_owned()).expect("should be valid DID format"), 616 + did_doc: None, 617 + handle: Handle::new(user.handle).expect("should be valid handle"), 618 + status, 619 + } 620 + .into(), 621 + )) 622 + } 623 + 624 + /// Get a signed token on behalf of the requesting DID for the requested service. 625 + /// - GET /xrpc/com.atproto.server.getServiceAuth 626 + /// ### Request Query Parameters 627 + /// - `aud`: string - The DID of the service that the token will be used to authenticate with 628 + /// - `exp`: integer (optional) - The time in Unix Epoch seconds that the JWT expires. Defaults to 60 seconds in the future. The service may enforce certain time bounds on tokens depending on the requested scope. 629 + /// - `lxm`: string (optional) - Lexicon (XRPC) method to bind the requested token to 630 + /// ### Responses 631 + /// - 200 OK: {token: string} 632 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `BadExpiration`]} 633 + /// - 401 Unauthorized 634 + async fn get_service_auth( 635 + user: AuthenticatedUser, 636 + State(skey): State<SigningKey>, 637 + Query(input): Query<server::get_service_auth::ParametersData>, 638 + ) -> Result<Json<server::get_service_auth::Output>> { 639 + let user_did = user.did(); 640 + let aud = input.aud.as_str(); 641 + 642 + let exp = (chrono::Utc::now().checked_add_signed(chrono::Duration::minutes(1))) 643 + .context("should be valid expiration datetime")? 644 + .timestamp(); 645 + let jti = rand::thread_rng() 646 + .sample_iter(rand::distributions::Alphanumeric) 647 + .take(10) 648 + .map(char::from) 649 + .collect::<String>(); 650 + 651 + let mut claims = serde_json::json!({ 652 + "iss": user_did.as_str(), 653 + "aud": aud, 654 + "exp": exp, 655 + "jti": jti, 656 + }); 657 + 658 + if let Some(ref lxm) = input.lxm { 659 + claims = claims 660 + .as_object_mut() 661 + .context("should be a valid object")? 662 + .insert("lxm".to_owned(), serde_json::Value::String(lxm.to_string())) 663 + .context("should be able to insert lxm into claims")?; 664 + } 665 + 666 + // Mint a bearer token by signing a JSON web token. 667 + let token = auth::sign(&skey, "JWT", &claims).context("failed to sign jwt")?; 668 + 669 + Ok(Json(server::get_service_auth::OutputData { token }.into())) 670 + } 671 + 672 + /// Get information about the current auth session. Requires auth. 673 + /// - GET /xrpc/com.atproto.server.getSession 674 + /// ### Responses 675 + /// - 200 OK: {"handle": "string","did": "string","email": "string","emailConfirmed": true,"emailAuthFactor": true,"didDoc": {},"active": true,"status": "takendown"} 676 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 677 + /// - 401 Unauthorized 678 + async fn get_session( 679 + user: AuthenticatedUser, 680 + State(db): State<Db>, 681 + ) -> Result<Json<server::get_session::Output>> { 682 + let did = user.did(); 683 + #[expect(clippy::shadow_unrelated, reason = "is related")] 684 + if let Some(user) = sqlx::query!( 685 + r#" 686 + SELECT a.email, a.status, ( 687 + SELECT h.handle 688 + FROM handles h 689 + WHERE h.did = a.did 690 + ORDER BY h.created_at ASC 691 + LIMIT 1 692 + ) AS handle 693 + FROM accounts a 694 + WHERE a.did = ? 695 + "#, 696 + did 697 + ) 698 + .fetch_optional(&db) 699 + .await 700 + .context("failed to fetch session")? 701 + { 702 + let active = user.status == "active"; 703 + let status = if active { None } else { Some(user.status) }; 704 + 705 + Ok(Json( 706 + server::get_session::OutputData { 707 + active: Some(active), 708 + did: Did::from_str(&did).expect("should be valid DID format"), 709 + did_doc: None, 710 + email: Some(user.email), 711 + email_auth_factor: None, 712 + email_confirmed: None, 713 + handle: Handle::new(user.handle).expect("should be valid handle"), 714 + status, 715 + } 716 + .into(), 717 + )) 718 + } else { 719 + Err(Error::with_status( 720 + StatusCode::UNAUTHORIZED, 721 + anyhow!("user not found"), 722 + )) 723 + } 724 + } 725 + 726 + /// Describes the server's account creation requirements and capabilities. Implemented by PDS. 727 + /// - GET /xrpc/com.atproto.server.describeServer 728 + /// ### Responses 729 + /// - 200 OK: {"inviteCodeRequired": true,"phoneVerificationRequired": true,"availableUserDomains": [`string`],"links": {"privacyPolicy": "string","termsOfService": "string"},"contact": {"email": "string"},"did": "string"} 730 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 731 + /// - 401 Unauthorized 732 + async fn describe_server( 733 + State(config): State<AppConfig>, 734 + ) -> Result<Json<server::describe_server::Output>> { 735 + Ok(Json( 736 + server::describe_server::OutputData { 737 + available_user_domains: vec![], 738 + contact: None, 739 + did: Did::from_str(&format!("did:web:{}", config.host_name)) 740 + .expect("should be valid DID format"), 741 + invite_code_required: Some(true), 742 + links: None, 743 + phone_verification_required: Some(false), // email verification 744 + } 745 + .into(), 746 + )) 747 + } 748 + 749 + async fn todo() -> Result<()> { 750 + Err(Error::unimplemented(anyhow!("not implemented"))) 751 + } 752 + 753 + #[rustfmt::skip] 754 + /// These endpoints are part of the atproto PDS server and account management APIs. \ 755 + /// Requests often require authentication and are made directly to the user's own PDS instance. 756 + /// ### Routes 757 + /// - `POST /xrpc/com.atproto.server.createAccount` -> [`create_account`] 758 + /// - `POST /xrpc/com.atproto.server.createInviteCode` -> [`create_invite_code`] 759 + /// - `POST /xrpc/com.atproto.server.createSession` -> [`create_session`] 760 + /// - `GET /xrpc/com.atproto.server.describeServer` -> [`describe_server`] 761 + /// - `GET /xrpc/com.atproto.server.getServiceAuth` -> [`get_service_auth`] 762 + /// - `GET /xrpc/com.atproto.server.getSession` -> [`get_session`] 763 + /// - `POST /xrpc/com.atproto.server.refreshSession` -> [`refresh_session`] 764 + pub(super) fn routes() -> Router<AppState> { 765 + Router::new() 766 + .route(concat!("/", server::activate_account::NSID), post(todo)) 767 + .route(concat!("/", server::check_account_status::NSID), post(todo)) 768 + .route(concat!("/", server::confirm_email::NSID), post(todo)) 769 + .route(concat!("/", server::create_account::NSID), post(create_account)) 770 + .route(concat!("/", server::create_app_password::NSID), post(todo)) 771 + .route(concat!("/", server::create_invite_code::NSID), post(create_invite_code)) 772 + .route(concat!("/", server::create_invite_codes::NSID), post(todo)) 773 + .route(concat!("/", server::create_session::NSID), post(create_session)) 774 + .route(concat!("/", server::deactivate_account::NSID), post(todo)) 775 + .route(concat!("/", server::delete_account::NSID), post(todo)) 776 + .route(concat!("/", server::delete_session::NSID), post(todo)) 777 + .route(concat!("/", server::describe_server::NSID), get(describe_server)) 778 + .route(concat!("/", server::get_account_invite_codes::NSID), post(todo)) 779 + .route(concat!("/", server::get_service_auth::NSID), get(get_service_auth)) 780 + .route(concat!("/", server::get_session::NSID), get(get_session)) 781 + .route(concat!("/", server::list_app_passwords::NSID), post(todo)) 782 + .route(concat!("/", server::refresh_session::NSID), post(refresh_session)) 783 + .route(concat!("/", server::request_account_delete::NSID), post(todo)) 784 + .route(concat!("/", server::request_email_confirmation::NSID), post(todo)) 785 + .route(concat!("/", server::request_email_update::NSID), post(todo)) 786 + .route(concat!("/", server::request_password_reset::NSID), post(todo)) 787 + .route(concat!("/", server::reserve_signing_key::NSID), post(todo)) 788 + .route(concat!("/", server::reset_password::NSID), post(todo)) 789 + .route(concat!("/", server::revoke_app_password::NSID), post(todo)) 790 + .route(concat!("/", server::update_email::NSID), post(todo)) 791 + }
+428
src/apis/com/atproto/sync/sync.rs
··· 1 + //! Endpoints for the `ATProto` sync API. (/xrpc/com.atproto.sync.*) 2 + use std::str::FromStr as _; 3 + 4 + use anyhow::{Context as _, anyhow}; 5 + use atrium_api::{ 6 + com::atproto::sync, 7 + types::{LimitedNonZeroU16, string::Did}, 8 + }; 9 + use atrium_repo::{ 10 + Cid, 11 + blockstore::{ 12 + AsyncBlockStoreRead as _, AsyncBlockStoreWrite as _, CarStore, DAG_CBOR, SHA2_256, 13 + }, 14 + }; 15 + use axum::{ 16 + Json, Router, 17 + body::Body, 18 + extract::{Query, State, WebSocketUpgrade}, 19 + http::{self, Response, StatusCode}, 20 + response::IntoResponse, 21 + routing::get, 22 + }; 23 + use constcat::concat; 24 + use futures::stream::TryStreamExt as _; 25 + use tokio_util::io::ReaderStream; 26 + 27 + use crate::{ 28 + AppState, Db, Error, Result, 29 + config::AppConfig, 30 + firehose::FirehoseProducer, 31 + storage::{open_repo_db, open_store}, 32 + }; 33 + 34 + #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)] 35 + #[serde(rename_all = "camelCase")] 36 + /// Parameters for `/xrpc/com.atproto.sync.listBlobs` \ 37 + /// HACK: `limit` may be passed as a string, so we must treat it as one. 38 + pub(super) struct ListBlobsParameters { 39 + #[serde(skip_serializing_if = "core::option::Option::is_none")] 40 + /// Optional cursor to paginate through blobs. 41 + pub cursor: Option<String>, 42 + ///The DID of the repo. 43 + pub did: Did, 44 + #[serde(skip_serializing_if = "core::option::Option::is_none")] 45 + /// Optional limit of blobs to return. 46 + pub limit: Option<String>, 47 + ///Optional revision of the repo to list blobs since. 48 + #[serde(skip_serializing_if = "core::option::Option::is_none")] 49 + pub since: Option<String>, 50 + } 51 + #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)] 52 + #[serde(rename_all = "camelCase")] 53 + /// Parameters for `/xrpc/com.atproto.sync.listRepos` \ 54 + /// HACK: `limit` may be passed as a string, so we must treat it as one. 55 + pub(super) struct ListReposParameters { 56 + #[serde(skip_serializing_if = "core::option::Option::is_none")] 57 + /// Optional cursor to paginate through repos. 58 + pub cursor: Option<String>, 59 + #[serde(skip_serializing_if = "core::option::Option::is_none")] 60 + /// Optional limit of repos to return. 61 + pub limit: Option<String>, 62 + } 63 + #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)] 64 + #[serde(rename_all = "camelCase")] 65 + /// Parameters for `/xrpc/com.atproto.sync.subscribeRepos` \ 66 + /// HACK: `cursor` may be passed as a string, so we must treat it as one. 67 + pub(super) struct SubscribeReposParametersData { 68 + ///The last known event seq number to backfill from. 69 + #[serde(skip_serializing_if = "core::option::Option::is_none")] 70 + pub cursor: Option<String>, 71 + } 72 + 73 + async fn get_blob( 74 + State(config): State<AppConfig>, 75 + Query(input): Query<sync::get_blob::ParametersData>, 76 + ) -> Result<Response<Body>> { 77 + let blob = config 78 + .blob 79 + .path 80 + .join(format!("{}.blob", input.cid.as_ref())); 81 + 82 + let f = tokio::fs::File::open(blob) 83 + .await 84 + .context("blob not found")?; 85 + let len = f 86 + .metadata() 87 + .await 88 + .context("failed to query file metadata")? 89 + .len(); 90 + 91 + let s = ReaderStream::new(f); 92 + 93 + Ok(Response::builder() 94 + .header(http::header::CONTENT_LENGTH, format!("{len}")) 95 + .body(Body::from_stream(s)) 96 + .context("failed to construct response")?) 97 + } 98 + 99 + /// Enumerates which accounts the requesting account is currently blocking. Requires auth. 100 + /// - GET /xrpc/com.atproto.sync.getBlocks 101 + /// ### Query Parameters 102 + /// - `limit`: integer, optional, default: 50, >=1 and <=100 103 + /// - `cursor`: string, optional 104 + /// ### Responses 105 + /// - 200 OK: ... 106 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 107 + /// - 401 Unauthorized 108 + async fn get_blocks( 109 + State(config): State<AppConfig>, 110 + Query(input): Query<sync::get_blocks::ParametersData>, 111 + ) -> Result<Response<Body>> { 112 + let mut repo = open_store(&config.repo, input.did.as_str()) 113 + .await 114 + .context("failed to open repository")?; 115 + 116 + let mut mem = Vec::new(); 117 + let mut store = CarStore::create(std::io::Cursor::new(&mut mem)) 118 + .await 119 + .context("failed to create intermediate carstore")?; 120 + 121 + for cid in &input.cids { 122 + // SEC: This can potentially fetch stale blocks from a repository (e.g. those that were deleted). 123 + // We'll want to prevent accesses to stale blocks eventually just to respect a user's right to be forgotten. 124 + _ = store 125 + .write_block( 126 + DAG_CBOR, 127 + SHA2_256, 128 + &repo 129 + .read_block(*cid.as_ref()) 130 + .await 131 + .context("failed to read block")?, 132 + ) 133 + .await 134 + .context("failed to write block")?; 135 + } 136 + 137 + Ok(Response::builder() 138 + .header(http::header::CONTENT_TYPE, "application/vnd.ipld.car") 139 + .body(Body::from(mem)) 140 + .context("failed to construct response")?) 141 + } 142 + 143 + /// Get the current commit CID & revision of the specified repo. Does not require auth. 144 + /// ### Query Parameters 145 + /// - `did`: The DID of the repo. 146 + /// ### Responses 147 + /// - 200 OK: {"cid": "string","rev": "string"} 148 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RepoTakendown`, `RepoSuspended`, `RepoDeactivated`]} 149 + async fn get_latest_commit( 150 + State(config): State<AppConfig>, 151 + State(db): State<Db>, 152 + Query(input): Query<sync::get_latest_commit::ParametersData>, 153 + ) -> Result<Json<sync::get_latest_commit::Output>> { 154 + let repo = open_repo_db(&config.repo, &db, input.did.as_str()) 155 + .await 156 + .context("failed to open repository")?; 157 + 158 + let cid = repo.root(); 159 + let commit = repo.commit(); 160 + 161 + Ok(Json( 162 + sync::get_latest_commit::OutputData { 163 + cid: atrium_api::types::string::Cid::new(cid), 164 + rev: commit.rev(), 165 + } 166 + .into(), 167 + )) 168 + } 169 + 170 + /// Get data blocks needed to prove the existence or non-existence of record in the current version of repo. Does not require auth. 171 + /// ### Query Parameters 172 + /// - `did`: The DID of the repo. 173 + /// - `collection`: nsid 174 + /// - `rkey`: record-key 175 + /// ### Responses 176 + /// - 200 OK: ... 177 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RecordNotFound`, `RepoNotFound`, `RepoTakendown`, 178 + /// `RepoSuspended`, `RepoDeactivated`]} 179 + async fn get_record( 180 + State(config): State<AppConfig>, 181 + State(db): State<Db>, 182 + Query(input): Query<sync::get_record::ParametersData>, 183 + ) -> Result<Response<Body>> { 184 + let mut repo = open_repo_db(&config.repo, &db, input.did.as_str()) 185 + .await 186 + .context("failed to open repo")?; 187 + 188 + let key = format!("{}/{}", input.collection.as_str(), input.rkey.as_str()); 189 + 190 + let mut contents = Vec::new(); 191 + let mut ret_store = 192 + CarStore::create_with_roots(std::io::Cursor::new(&mut contents), [repo.root()]) 193 + .await 194 + .context("failed to create car store")?; 195 + 196 + repo.extract_raw_into(&key, &mut ret_store) 197 + .await 198 + .context("failed to extract records")?; 199 + 200 + Ok(Response::builder() 201 + .header(http::header::CONTENT_TYPE, "application/vnd.ipld.car") 202 + .body(Body::from(contents)) 203 + .context("failed to construct response")?) 204 + } 205 + 206 + /// Get the hosting status for a repository, on this server. Expected to be implemented by PDS and Relay. 207 + /// ### Query Parameters 208 + /// - `did`: The DID of the repo. 209 + /// ### Responses 210 + /// - 200 OK: {"did": "string","active": true,"status": "takendown","rev": "string"} 211 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RepoNotFound`]} 212 + async fn get_repo_status( 213 + State(db): State<Db>, 214 + Query(input): Query<sync::get_repo::ParametersData>, 215 + ) -> Result<Json<sync::get_repo_status::Output>> { 216 + let did = input.did.as_str(); 217 + let r = sqlx::query!(r#"SELECT rev, status FROM accounts WHERE did = ?"#, did) 218 + .fetch_optional(&db) 219 + .await 220 + .context("failed to execute query")?; 221 + 222 + let Some(r) = r else { 223 + return Err(Error::with_status( 224 + StatusCode::NOT_FOUND, 225 + anyhow!("account not found"), 226 + )); 227 + }; 228 + 229 + let active = r.status == "active"; 230 + let status = if active { None } else { Some(r.status) }; 231 + 232 + Ok(Json( 233 + sync::get_repo_status::OutputData { 234 + active, 235 + status, 236 + did: input.did.clone(), 237 + rev: Some( 238 + atrium_api::types::string::Tid::new(r.rev).expect("should be able to convert Tid"), 239 + ), 240 + } 241 + .into(), 242 + )) 243 + } 244 + 245 + /// Download a repository export as CAR file. Optionally only a 'diff' since a previous revision. 246 + /// Does not require auth; implemented by PDS. 247 + /// ### Query Parameters 248 + /// - `did`: The DID of the repo. 249 + /// - `since`: The revision ('rev') of the repo to create a diff from. 250 + /// ### Responses 251 + /// - 200 OK: ... 252 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RepoNotFound`, 253 + /// `RepoTakendown`, `RepoSuspended`, `RepoDeactivated`]} 254 + async fn get_repo( 255 + State(config): State<AppConfig>, 256 + State(db): State<Db>, 257 + Query(input): Query<sync::get_repo::ParametersData>, 258 + ) -> Result<Response<Body>> { 259 + let mut repo = open_repo_db(&config.repo, &db, input.did.as_str()) 260 + .await 261 + .context("failed to open repo")?; 262 + 263 + let mut contents = Vec::new(); 264 + let mut store = CarStore::create_with_roots(std::io::Cursor::new(&mut contents), [repo.root()]) 265 + .await 266 + .context("failed to create car store")?; 267 + 268 + repo.export_into(&mut store) 269 + .await 270 + .context("failed to extract records")?; 271 + 272 + Ok(Response::builder() 273 + .header(http::header::CONTENT_TYPE, "application/vnd.ipld.car") 274 + .body(Body::from(contents)) 275 + .context("failed to construct response")?) 276 + } 277 + 278 + /// List blob CIDs for an account, since some repo revision. Does not require auth; implemented by PDS. 279 + /// ### Query Parameters 280 + /// - `did`: The DID of the repo. Required. 281 + /// - `since`: Optional revision of the repo to list blobs since. 282 + /// - `limit`: >= 1 and <= 1000, default 500 283 + /// - `cursor`: string 284 + /// ### Responses 285 + /// - 200 OK: {"cursor": "string","cids": [string]} 286 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RepoNotFound`, `RepoTakendown`, 287 + /// `RepoSuspended`, `RepoDeactivated`]} 288 + async fn list_blobs( 289 + State(db): State<Db>, 290 + Query(input): Query<sync::list_blobs::ParametersData>, 291 + ) -> Result<Json<sync::list_blobs::Output>> { 292 + let did_str = input.did.as_str(); 293 + 294 + // TODO: `input.since` 295 + // TODO: `input.limit` 296 + // TODO: `input.cursor` 297 + 298 + let cids = sqlx::query_scalar!(r#"SELECT cid FROM blob_ref WHERE did = ?"#, did_str) 299 + .fetch_all(&db) 300 + .await 301 + .context("failed to query blobs")?; 302 + 303 + let cids = cids 304 + .into_iter() 305 + .map(|c| { 306 + Cid::from_str(&c) 307 + .map(atrium_api::types::string::Cid::new) 308 + .map_err(anyhow::Error::new) 309 + }) 310 + .collect::<anyhow::Result<Vec<_>>>() 311 + .context("failed to convert cids")?; 312 + 313 + Ok(Json( 314 + sync::list_blobs::OutputData { cursor: None, cids }.into(), 315 + )) 316 + } 317 + 318 + /// Enumerates all the DID, rev, and commit CID for all repos hosted by this service. 319 + /// Does not require auth; implemented by PDS and Relay. 320 + /// ### Query Parameters 321 + /// - `limit`: >= 1 and <= 1000, default 500 322 + /// - `cursor`: string 323 + /// ### Responses 324 + /// - 200 OK: {"cursor": "string","repos": [{"did": "string","head": "string","rev": "string","active": true,"status": "takendown"}]} 325 + /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 326 + async fn list_repos( 327 + State(db): State<Db>, 328 + Query(input): Query<sync::list_repos::ParametersData>, 329 + ) -> Result<Json<sync::list_repos::Output>> { 330 + struct Record { 331 + /// The DID of the repo. 332 + did: String, 333 + /// The commit CID of the repo. 334 + rev: String, 335 + /// The root CID of the repo. 336 + root: String, 337 + } 338 + 339 + let limit: u16 = input.limit.unwrap_or(LimitedNonZeroU16::MAX).into(); 340 + 341 + let r = if let Some(ref cursor) = input.cursor { 342 + let r = sqlx::query_as!( 343 + Record, 344 + r#"SELECT did, root, rev FROM accounts WHERE did > ? LIMIT ?"#, 345 + cursor, 346 + limit 347 + ) 348 + .fetch(&db); 349 + 350 + r.try_collect::<Vec<_>>() 351 + .await 352 + .context("failed to fetch profiles")? 353 + } else { 354 + let r = sqlx::query_as!( 355 + Record, 356 + r#"SELECT did, root, rev FROM accounts LIMIT ?"#, 357 + limit 358 + ) 359 + .fetch(&db); 360 + 361 + r.try_collect::<Vec<_>>() 362 + .await 363 + .context("failed to fetch profiles")? 364 + }; 365 + 366 + let cursor = r.last().map(|r| r.did.clone()); 367 + let repos = r 368 + .into_iter() 369 + .map(|r| { 370 + sync::list_repos::RepoData { 371 + active: Some(true), 372 + did: Did::new(r.did).expect("should be a valid DID"), 373 + head: atrium_api::types::string::Cid::new( 374 + Cid::from_str(&r.root).expect("should be a valid CID"), 375 + ), 376 + rev: atrium_api::types::string::Tid::new(r.rev) 377 + .expect("should be able to convert Tid"), 378 + status: None, 379 + } 380 + .into() 381 + }) 382 + .collect::<Vec<_>>(); 383 + 384 + Ok(Json(sync::list_repos::OutputData { cursor, repos }.into())) 385 + } 386 + 387 + /// Repository event stream, aka Firehose endpoint. Outputs repo commits with diff data, and identity update events, 388 + /// for all repositories on the current server. See the atproto specifications for details around stream sequencing, 389 + /// repo versioning, CAR diff format, and more. Public and does not require auth; implemented by PDS and Relay. 390 + /// ### Query Parameters 391 + /// - `cursor`: The last known event seq number to backfill from. 392 + /// ### Responses 393 + /// - 200 OK: ... 394 + async fn subscribe_repos( 395 + ws_up: WebSocketUpgrade, 396 + State(fh): State<FirehoseProducer>, 397 + Query(input): Query<sync::subscribe_repos::ParametersData>, 398 + ) -> impl IntoResponse { 399 + ws_up.on_upgrade(async move |ws| { 400 + fh.client_connection(ws, input.cursor).await; 401 + }) 402 + } 403 + 404 + #[rustfmt::skip] 405 + /// These endpoints are part of the atproto repository synchronization APIs. Requests usually do not require authentication, 406 + /// and can be made to PDS intances or Relay instances. 407 + /// ### Routes 408 + /// - `GET /xrpc/com.atproto.sync.getBlob` -> [`get_blob`] 409 + /// - `GET /xrpc/com.atproto.sync.getBlocks` -> [`get_blocks`] 410 + /// - `GET /xrpc/com.atproto.sync.getLatestCommit` -> [`get_latest_commit`] 411 + /// - `GET /xrpc/com.atproto.sync.getRecord` -> [`get_record`] 412 + /// - `GET /xrpc/com.atproto.sync.getRepoStatus` -> [`get_repo_status`] 413 + /// - `GET /xrpc/com.atproto.sync.getRepo` -> [`get_repo`] 414 + /// - `GET /xrpc/com.atproto.sync.listBlobs` -> [`list_blobs`] 415 + /// - `GET /xrpc/com.atproto.sync.listRepos` -> [`list_repos`] 416 + /// - `GET /xrpc/com.atproto.sync.subscribeRepos` -> [`subscribe_repos`] 417 + pub(super) fn routes() -> Router<AppState> { 418 + Router::new() 419 + .route(concat!("/", sync::get_blob::NSID), get(get_blob)) 420 + .route(concat!("/", sync::get_blocks::NSID), get(get_blocks)) 421 + .route(concat!("/", sync::get_latest_commit::NSID), get(get_latest_commit)) 422 + .route(concat!("/", sync::get_record::NSID), get(get_record)) 423 + .route(concat!("/", sync::get_repo_status::NSID), get(get_repo_status)) 424 + .route(concat!("/", sync::get_repo::NSID), get(get_repo)) 425 + .route(concat!("/", sync::list_blobs::NSID), get(list_blobs)) 426 + .route(concat!("/", sync::list_repos::NSID), get(list_repos)) 427 + .route(concat!("/", sync::subscribe_repos::NSID), get(subscribe_repos)) 428 + }
+1
src/apis/com/mod.rs
··· 1 + pub mod atproto;
+27
src/apis/mod.rs
··· 1 + //! Root module for all endpoints. 2 + // mod identity; 3 + mod com; 4 + // mod server; 5 + // mod sync; 6 + 7 + use axum::{Json, Router, routing::get}; 8 + use serde_json::json; 9 + 10 + use crate::serve::{AppState, Result}; 11 + 12 + /// Health check endpoint. Returns name and version of the service. 13 + pub(crate) async fn health() -> Result<Json<serde_json::Value>> { 14 + Ok(Json(json!({ 15 + "version": concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")), 16 + }))) 17 + } 18 + 19 + /// Register all root routes. 20 + pub(crate) fn routes() -> Router<AppState> { 21 + Router::new() 22 + .route("/_health", get(health)) 23 + // .merge(identity::routes()) // com.atproto.identity 24 + .merge(com::atproto::repo::routes()) // com.atproto.repo 25 + // .merge(server::routes()) // com.atproto.server 26 + // .merge(sync::routes()) // com.atproto.sync 27 + }
+9 -6
src/auth.rs
··· 8 8 use diesel::prelude::*; 9 9 use sha2::{Digest as _, Sha256}; 10 10 11 - use crate::{AppState, Error, error::ErrorMessage}; 11 + use crate::{ 12 + error::{Error, ErrorMessage}, 13 + serve::AppState, 14 + }; 12 15 13 16 /// Request extractor for authenticated users. 14 17 /// If specified in an API endpoint, this guarantees the API can only be called ··· 130 133 131 134 // Extract subject (DID) 132 135 if let Some(did) = claims.get("sub").and_then(serde_json::Value::as_str) { 133 - use rsky_pds::schema::pds::account::dsl as AccountSchema; 136 + use crate::schema::pds::account::dsl as AccountSchema; 134 137 let did_clone = did.to_owned(); 135 138 136 139 let _did = state ··· 341 344 use crate::schema::pds::oauth_used_jtis::dsl as JtiSchema; 342 345 343 346 // Check if JTI has been used before 344 - let jti_string = jti.to_string(); 347 + let jti_string = jti.to_owned(); 345 348 let jti_used = state 346 349 .db 347 350 .get() ··· 372 375 .unwrap_or_else(|| timestamp.checked_add(60).unwrap_or(timestamp)); 373 376 374 377 // Convert SQLx INSERT to Diesel 375 - let jti_str = jti.to_string(); 378 + let jti_str = jti.to_owned(); 376 379 let thumbprint_str = calculated_thumbprint.to_string(); 377 - state 380 + let _ = state 378 381 .db 379 382 .get() 380 383 .await ··· 395 398 396 399 // Extract subject (DID) from access token 397 400 if let Some(did) = claims.get("sub").and_then(|v| v.as_str()) { 398 - use rsky_pds::schema::pds::account::dsl as AccountSchema; 401 + use crate::schema::pds::account::dsl as AccountSchema; 399 402 400 403 let did_clone = did.to_owned(); 401 404
+1 -1
src/did.rs
··· 5 5 use serde::{Deserialize, Serialize}; 6 6 use url::Url; 7 7 8 - use crate::Client; 8 + use crate::serve::Client; 9 9 10 10 /// URL whitelist for DID document resolution. 11 11 const ALLOWED_URLS: &[&str] = &["bsky.app", "bsky.chat"];
-245
src/endpoints/identity.rs
··· 1 - //! Identity endpoints (/xrpc/com.atproto.identity.*) 2 - use std::collections::HashMap; 3 - 4 - use anyhow::{Context as _, anyhow}; 5 - use atrium_api::{ 6 - com::atproto::identity, 7 - types::string::{Datetime, Handle}, 8 - }; 9 - use atrium_crypto::keypair::Did as _; 10 - use atrium_repo::blockstore::{AsyncBlockStoreWrite as _, CarStore, DAG_CBOR, SHA2_256}; 11 - use axum::{ 12 - Json, Router, 13 - extract::{Query, State}, 14 - http::StatusCode, 15 - routing::{get, post}, 16 - }; 17 - use constcat::concat; 18 - 19 - use crate::{ 20 - AppState, Client, Db, Error, Result, RotationKey, SigningKey, 21 - auth::AuthenticatedUser, 22 - config::AppConfig, 23 - did, 24 - firehose::FirehoseProducer, 25 - plc::{self, PlcOperation, PlcService}, 26 - }; 27 - 28 - /// (GET) Resolves an atproto handle (hostname) to a DID. Does not necessarily bi-directionally verify against the the DID document. 29 - /// ### Query Parameters 30 - /// - handle: The handle to resolve. 31 - /// ### Responses 32 - /// - 200 OK: {did: did} 33 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `HandleNotFound`]} 34 - /// - 401 Unauthorized 35 - async fn resolve_handle( 36 - State(db): State<Db>, 37 - State(client): State<Client>, 38 - Query(input): Query<identity::resolve_handle::ParametersData>, 39 - ) -> Result<Json<identity::resolve_handle::Output>> { 40 - let handle = input.handle.as_str(); 41 - if let Ok(did) = sqlx::query_scalar!(r#"SELECT did FROM handles WHERE handle = ?"#, handle) 42 - .fetch_one(&db) 43 - .await 44 - { 45 - return Ok(Json( 46 - identity::resolve_handle::OutputData { 47 - did: atrium_api::types::string::Did::new(did).expect("should be valid DID format"), 48 - } 49 - .into(), 50 - )); 51 - } 52 - 53 - // HACK: Query bsky to see if they have this handle cached. 54 - let response = client 55 - .get(format!( 56 - "https://api.bsky.app/xrpc/com.atproto.identity.resolveHandle?handle={handle}" 57 - )) 58 - .send() 59 - .await 60 - .context("failed to query upstream server")? 61 - .json() 62 - .await 63 - .context("failed to decode response as JSON")?; 64 - 65 - Ok(Json(response)) 66 - } 67 - 68 - #[expect(unused_variables, clippy::todo, reason = "Not yet implemented")] 69 - /// Request an email with a code to in order to request a signed PLC operation. Requires Auth. 70 - /// - POST /xrpc/com.atproto.identity.requestPlcOperationSignature 71 - /// ### Responses 72 - /// - 200 OK 73 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 74 - /// - 401 Unauthorized 75 - async fn request_plc_operation_signature(user: AuthenticatedUser) -> Result<()> { 76 - todo!() 77 - } 78 - 79 - #[expect(unused_variables, clippy::todo, reason = "Not yet implemented")] 80 - /// Signs a PLC operation to update some value(s) in the requesting DID's document. 81 - /// - POST /xrpc/com.atproto.identity.signPlcOperation 82 - /// ### Request Body 83 - /// - token: string // A token received through com.atproto.identity.requestPlcOperationSignature 84 - /// - rotationKeys: string[] 85 - /// - alsoKnownAs: string[] 86 - /// - verificationMethods: services 87 - /// ### Responses 88 - /// - 200 OK: {operation: string} 89 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 90 - /// - 401 Unauthorized 91 - async fn sign_plc_operation( 92 - user: AuthenticatedUser, 93 - State(skey): State<SigningKey>, 94 - State(rkey): State<RotationKey>, 95 - State(config): State<AppConfig>, 96 - Json(input): Json<identity::sign_plc_operation::Input>, 97 - ) -> Result<Json<identity::sign_plc_operation::Output>> { 98 - todo!() 99 - } 100 - 101 - #[expect( 102 - clippy::too_many_arguments, 103 - reason = "Many parameters are required for this endpoint" 104 - )] 105 - /// Updates the current account's handle. Verifies handle validity, and updates did:plc document if necessary. Implemented by PDS, and requires auth. 106 - /// - POST /xrpc/com.atproto.identity.updateHandle 107 - /// ### Query Parameters 108 - /// - handle: handle // The new handle. 109 - /// ### Responses 110 - /// - 200 OK 111 - /// ## Errors 112 - /// - If the handle is already in use. 113 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 114 - /// - 401 Unauthorized 115 - /// ## Panics 116 - /// - If the handle is not valid. 117 - async fn update_handle( 118 - user: AuthenticatedUser, 119 - State(skey): State<SigningKey>, 120 - State(rkey): State<RotationKey>, 121 - State(client): State<Client>, 122 - State(config): State<AppConfig>, 123 - State(db): State<Db>, 124 - State(fhp): State<FirehoseProducer>, 125 - Json(input): Json<identity::update_handle::Input>, 126 - ) -> Result<()> { 127 - let handle = input.handle.as_str(); 128 - let did_str = user.did(); 129 - let did = atrium_api::types::string::Did::new(user.did()).expect("should be valid DID format"); 130 - 131 - if let Some(existing_did) = 132 - sqlx::query_scalar!(r#"SELECT did FROM handles WHERE handle = ?"#, handle) 133 - .fetch_optional(&db) 134 - .await 135 - .context("failed to query did count")? 136 - { 137 - if existing_did != did_str { 138 - return Err(Error::with_status( 139 - StatusCode::BAD_REQUEST, 140 - anyhow!("attempted to update handle to one that is already in use"), 141 - )); 142 - } 143 - } 144 - 145 - // Ensure the existing DID is resolvable. 146 - // If not, we need to register the original handle. 147 - let _did = did::resolve(&client, did.clone()) 148 - .await 149 - .with_context(|| format!("failed to resolve DID for {did_str}")) 150 - .context("should be able to resolve DID")?; 151 - 152 - let op = plc::sign_op( 153 - &rkey, 154 - PlcOperation { 155 - typ: "plc_operation".to_owned(), 156 - rotation_keys: vec![rkey.did()], 157 - verification_methods: HashMap::from([("atproto".to_owned(), skey.did())]), 158 - also_known_as: vec![input.handle.as_str().to_owned()], 159 - services: HashMap::from([( 160 - "atproto_pds".to_owned(), 161 - PlcService::Pds { 162 - endpoint: config.host_name.clone(), 163 - }, 164 - )]), 165 - prev: Some( 166 - sqlx::query_scalar!(r#"SELECT plc_root FROM accounts WHERE did = ?"#, did_str) 167 - .fetch_one(&db) 168 - .await 169 - .context("failed to fetch user PLC root")?, 170 - ), 171 - }, 172 - ) 173 - .context("failed to sign plc op")?; 174 - 175 - if !config.test { 176 - plc::submit(&client, did.as_str(), &op) 177 - .await 178 - .context("failed to submit PLC operation")?; 179 - } 180 - 181 - // FIXME: Properly abstract these implementation details. 182 - let did_hash = did_str 183 - .strip_prefix("did:plc:") 184 - .context("should be valid DID format")?; 185 - let doc = tokio::fs::File::options() 186 - .read(true) 187 - .write(true) 188 - .open(config.plc.path.join(format!("{did_hash}.car"))) 189 - .await 190 - .context("failed to open did doc")?; 191 - 192 - let op_bytes = serde_ipld_dagcbor::to_vec(&op).context("failed to encode plc op")?; 193 - 194 - let plc_cid = CarStore::open(doc) 195 - .await 196 - .context("failed to open did carstore")? 197 - .write_block(DAG_CBOR, SHA2_256, &op_bytes) 198 - .await 199 - .context("failed to write genesis commit")?; 200 - 201 - let cid_str = plc_cid.to_string(); 202 - 203 - _ = sqlx::query!( 204 - r#"UPDATE accounts SET plc_root = ? WHERE did = ?"#, 205 - cid_str, 206 - did_str 207 - ) 208 - .execute(&db) 209 - .await 210 - .context("failed to update account PLC root")?; 211 - 212 - // Broadcast the identity event now that the new identity is resolvable on the public directory. 213 - fhp.identity( 214 - atrium_api::com::atproto::sync::subscribe_repos::IdentityData { 215 - did: did.clone(), 216 - handle: Some(Handle::new(handle.to_owned()).expect("should be valid handle")), 217 - seq: 0, // Filled by firehose later. 218 - time: Datetime::now(), 219 - }, 220 - ) 221 - .await; 222 - 223 - Ok(()) 224 - } 225 - 226 - async fn todo() -> Result<()> { 227 - Err(Error::unimplemented(anyhow!("not implemented"))) 228 - } 229 - 230 - #[rustfmt::skip] 231 - /// Identity endpoints (/xrpc/com.atproto.identity.*) 232 - /// ### Routes 233 - /// - AP /xrpc/com.atproto.identity.updateHandle -> [`update_handle`] 234 - /// - AP /xrpc/com.atproto.identity.requestPlcOperationSignature -> [`request_plc_operation_signature`] 235 - /// - AP /xrpc/com.atproto.identity.signPlcOperation -> [`sign_plc_operation`] 236 - /// - UG /xrpc/com.atproto.identity.resolveHandle -> [`resolve_handle`] 237 - pub(super) fn routes() -> Router<AppState> { 238 - Router::new() 239 - .route(concat!("/", identity::get_recommended_did_credentials::NSID), get(todo)) 240 - .route(concat!("/", identity::request_plc_operation_signature::NSID), post(request_plc_operation_signature)) 241 - .route(concat!("/", identity::resolve_handle::NSID), get(resolve_handle)) 242 - .route(concat!("/", identity::sign_plc_operation::NSID), post(sign_plc_operation)) 243 - .route(concat!("/", identity::submit_plc_operation::NSID), post(todo)) 244 - .route(concat!("/", identity::update_handle::NSID), post(update_handle)) 245 - }
-26
src/endpoints/mod.rs
··· 1 - //! Root module for all endpoints. 2 - // mod identity; 3 - // mod repo; 4 - // mod server; 5 - // mod sync; 6 - 7 - use axum::{Json, Router, routing::get}; 8 - use serde_json::json; 9 - 10 - use crate::{AppState, Result}; 11 - 12 - /// Health check endpoint. Returns name and version of the service. 13 - pub(crate) async fn health() -> Result<Json<serde_json::Value>> { 14 - Ok(Json(json!({ 15 - "version": concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")), 16 - }))) 17 - } 18 - 19 - /// Register all root routes. 20 - pub(crate) fn routes() -> Router<AppState> { 21 - Router::new().route("/_health", get(health)) 22 - // .merge(identity::routes()) // com.atproto.identity 23 - // .merge(repo::routes()) // com.atproto.repo 24 - // .merge(server::routes()) // com.atproto.server 25 - // .merge(sync::routes()) // com.atproto.sync 26 - }
-195
src/endpoints/repo/apply_writes.rs
··· 1 - //! Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS. 2 - use crate::{ 3 - ActorPools, AppState, Db, Error, Result, SigningKey, 4 - actor_store::{ActorStore, sql_blob::BlobStoreSql}, 5 - auth::AuthenticatedUser, 6 - config::AppConfig, 7 - error::ErrorMessage, 8 - firehose::{self, FirehoseProducer, RepoOp}, 9 - metrics::{REPO_COMMITS, REPO_OP_CREATE, REPO_OP_DELETE, REPO_OP_UPDATE}, 10 - storage, 11 - }; 12 - use anyhow::bail; 13 - use anyhow::{Context as _, anyhow}; 14 - use atrium_api::com::atproto::repo::apply_writes::{self, InputWritesItem, OutputResultsItem}; 15 - use atrium_api::{ 16 - com::atproto::repo::{self, defs::CommitMetaData}, 17 - types::{ 18 - LimitedU32, Object, TryFromUnknown as _, TryIntoUnknown as _, Unknown, 19 - string::{AtIdentifier, Nsid, Tid}, 20 - }, 21 - }; 22 - use atrium_repo::blockstore::CarStore; 23 - use axum::{ 24 - Json, Router, 25 - body::Body, 26 - extract::{Query, Request, State}, 27 - http::{self, StatusCode}, 28 - routing::{get, post}, 29 - }; 30 - use cidv10::Cid; 31 - use constcat::concat; 32 - use futures::TryStreamExt as _; 33 - use futures::stream::{self, StreamExt}; 34 - use metrics::counter; 35 - use rsky_lexicon::com::atproto::repo::{ApplyWritesInput, ApplyWritesInputRefWrite}; 36 - use rsky_pds::SharedSequencer; 37 - use rsky_pds::account_manager::AccountManager; 38 - use rsky_pds::account_manager::helpers::account::AvailabilityFlags; 39 - use rsky_pds::apis::ApiError; 40 - use rsky_pds::auth_verifier::AccessStandardIncludeChecks; 41 - use rsky_pds::repo::prepare::{ 42 - PrepareCreateOpts, PrepareDeleteOpts, PrepareUpdateOpts, prepare_create, prepare_delete, 43 - prepare_update, 44 - }; 45 - use rsky_repo::types::PreparedWrite; 46 - use rsky_syntax::aturi::AtUri; 47 - use serde::Deserialize; 48 - use std::{collections::HashSet, str::FromStr}; 49 - use tokio::io::AsyncWriteExt as _; 50 - 51 - use super::resolve_did; 52 - 53 - /// Apply a batch transaction of repository creates, updates, and deletes. Requires auth, implemented by PDS. 54 - /// - POST /xrpc/com.atproto.repo.applyWrites 55 - /// ### Request Body 56 - /// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account). 57 - /// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data across all operations, 'true' to require it, or leave unset to validate only for known Lexicons. 58 - /// - `writes`: `object[]` // One of: 59 - /// - - com.atproto.repo.applyWrites.create 60 - /// - - com.atproto.repo.applyWrites.update 61 - /// - - com.atproto.repo.applyWrites.delete 62 - /// - `swap_commit`: `cid` // If provided, the entire operation will fail if the current repo commit CID does not match this value. Used to prevent conflicting repo mutations. 63 - pub(crate) async fn apply_writes( 64 - user: AuthenticatedUser, 65 - State(skey): State<SigningKey>, 66 - State(config): State<AppConfig>, 67 - State(db): State<Db>, 68 - State(db_actors): State<std::collections::HashMap<String, ActorPools>>, 69 - State(fhp): State<FirehoseProducer>, 70 - Json(input): Json<ApplyWritesInput>, 71 - ) -> Result<Json<repo::apply_writes::Output>> { 72 - let tx: ApplyWritesInput = input; 73 - let ApplyWritesInput { 74 - repo, 75 - validate, 76 - swap_commit, 77 - .. 78 - } = tx; 79 - let account = account_manager 80 - .get_account( 81 - &repo, 82 - Some(AvailabilityFlags { 83 - include_deactivated: Some(true), 84 - include_taken_down: None, 85 - }), 86 - ) 87 - .await?; 88 - 89 - if let Some(account) = account { 90 - if account.deactivated_at.is_some() { 91 - return Err(Error::with_message( 92 - StatusCode::FORBIDDEN, 93 - anyhow!("Account is deactivated"), 94 - ErrorMessage::new("AccountDeactivated", "Account is deactivated"), 95 - )); 96 - } 97 - let did = account.did; 98 - if did != user.did() { 99 - return Err(Error::with_message( 100 - StatusCode::FORBIDDEN, 101 - anyhow!("AuthRequiredError"), 102 - ErrorMessage::new("AuthRequiredError", "Auth required"), 103 - )); 104 - } 105 - let did: &String = &did; 106 - if tx.writes.len() > 200 { 107 - return Err(Error::with_message( 108 - StatusCode::BAD_REQUEST, 109 - anyhow!("Too many writes. Max: 200"), 110 - ErrorMessage::new("TooManyWrites", "Too many writes. Max: 200"), 111 - )); 112 - } 113 - 114 - let writes: Vec<PreparedWrite> = stream::iter(tx.writes) 115 - .then(|write| async move { 116 - Ok::<PreparedWrite, anyhow::Error>(match write { 117 - ApplyWritesInputRefWrite::Create(write) => PreparedWrite::Create( 118 - prepare_create(PrepareCreateOpts { 119 - did: did.clone(), 120 - collection: write.collection, 121 - rkey: write.rkey, 122 - swap_cid: None, 123 - record: serde_json::from_value(write.value)?, 124 - validate, 125 - }) 126 - .await?, 127 - ), 128 - ApplyWritesInputRefWrite::Update(write) => PreparedWrite::Update( 129 - prepare_update(PrepareUpdateOpts { 130 - did: did.clone(), 131 - collection: write.collection, 132 - rkey: write.rkey, 133 - swap_cid: None, 134 - record: serde_json::from_value(write.value)?, 135 - validate, 136 - }) 137 - .await?, 138 - ), 139 - ApplyWritesInputRefWrite::Delete(write) => { 140 - PreparedWrite::Delete(prepare_delete(PrepareDeleteOpts { 141 - did: did.clone(), 142 - collection: write.collection, 143 - rkey: write.rkey, 144 - swap_cid: None, 145 - })?) 146 - } 147 - }) 148 - }) 149 - .collect::<Vec<_>>() 150 - .await 151 - .into_iter() 152 - .collect::<Result<Vec<PreparedWrite>, _>>()?; 153 - 154 - let swap_commit_cid = match swap_commit { 155 - Some(swap_commit) => Some(Cid::from_str(&swap_commit)?), 156 - None => None, 157 - }; 158 - 159 - let actor_db = db_actors 160 - .get(did) 161 - .ok_or_else(|| anyhow!("Actor DB not found"))?; 162 - let conn = actor_db 163 - .repo 164 - .get() 165 - .await 166 - .context("Failed to get actor db connection")?; 167 - let mut actor_store = ActorStore::new( 168 - did.clone(), 169 - BlobStoreSql::new(did.clone(), actor_db.blob), 170 - actor_db.repo, 171 - conn, 172 - ); 173 - 174 - let commit = actor_store 175 - .process_writes(writes.clone(), swap_commit_cid) 176 - .await?; 177 - 178 - let mut lock = sequencer.sequencer.write().await; 179 - lock.sequence_commit(did.clone(), commit.clone()).await?; 180 - account_manager 181 - .update_repo_root( 182 - did.to_string(), 183 - commit.commit_data.cid, 184 - commit.commit_data.rev, 185 - ) 186 - .await?; 187 - Ok(()) 188 - } else { 189 - Err(Error::with_message( 190 - StatusCode::NOT_FOUND, 191 - anyhow!("Could not find repo: `{repo}`"), 192 - ErrorMessage::new("RepoNotFound", "Could not find repo"), 193 - )) 194 - } 195 - }
-514
src/endpoints/repo.rs
··· 1 - //! PDS repository endpoints /xrpc/com.atproto.repo.*) 2 - mod apply_writes; 3 - pub(crate) use apply_writes::apply_writes; 4 - 5 - use std::{collections::HashSet, str::FromStr}; 6 - 7 - use anyhow::{Context as _, anyhow}; 8 - use atrium_api::com::atproto::repo::apply_writes::{ 9 - self as atrium_apply_writes, InputWritesItem, OutputResultsItem, 10 - }; 11 - use atrium_api::{ 12 - com::atproto::repo::{self, defs::CommitMetaData}, 13 - types::{ 14 - LimitedU32, Object, TryFromUnknown as _, TryIntoUnknown as _, Unknown, 15 - string::{AtIdentifier, Nsid, Tid}, 16 - }, 17 - }; 18 - use atrium_repo::{Cid, blockstore::CarStore}; 19 - use axum::{ 20 - Json, Router, 21 - body::Body, 22 - extract::{Query, Request, State}, 23 - http::{self, StatusCode}, 24 - routing::{get, post}, 25 - }; 26 - use constcat::concat; 27 - use futures::TryStreamExt as _; 28 - use metrics::counter; 29 - use rsky_syntax::aturi::AtUri; 30 - use serde::Deserialize; 31 - use tokio::io::AsyncWriteExt as _; 32 - 33 - use crate::repo::block_map::cid_for_cbor; 34 - use crate::repo::types::PreparedCreateOrUpdate; 35 - use crate::{ 36 - AppState, Db, Error, Result, SigningKey, 37 - actor_store::{ActorStoreTransactor, ActorStoreWriter}, 38 - auth::AuthenticatedUser, 39 - config::AppConfig, 40 - error::ErrorMessage, 41 - firehose::{self, FirehoseProducer, RepoOp}, 42 - metrics::{REPO_COMMITS, REPO_OP_CREATE, REPO_OP_DELETE, REPO_OP_UPDATE}, 43 - repo::types::{PreparedWrite, WriteOpAction}, 44 - storage, 45 - }; 46 - 47 - #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)] 48 - #[serde(rename_all = "camelCase")] 49 - /// Parameters for [`list_records`]. 50 - pub(super) struct ListRecordsParameters { 51 - ///The NSID of the record type. 52 - pub collection: Nsid, 53 - /// The cursor to start from. 54 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 55 - pub cursor: Option<String>, 56 - ///The number of records to return. 57 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 58 - pub limit: Option<String>, 59 - ///The handle or DID of the repo. 60 - pub repo: AtIdentifier, 61 - ///Flag to reverse the order of the returned records. 62 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 63 - pub reverse: Option<bool>, 64 - ///DEPRECATED: The highest sort-ordered rkey to stop at (exclusive) 65 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 66 - pub rkey_end: Option<String>, 67 - ///DEPRECATED: The lowest sort-ordered rkey to start from (exclusive) 68 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 69 - pub rkey_start: Option<String>, 70 - } 71 - 72 - /// Resolve DID to DID document. Does not bi-directionally verify handle. 73 - /// - GET /xrpc/com.atproto.repo.resolveDid 74 - /// ### Query Parameters 75 - /// - `did`: DID to resolve. 76 - /// ### Responses 77 - /// - 200 OK: {`did_doc`: `did_doc`} 78 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `DidNotFound`, `DidDeactivated`]} 79 - async fn resolve_did( 80 - db: &Db, 81 - identifier: &AtIdentifier, 82 - ) -> anyhow::Result<( 83 - atrium_api::types::string::Did, 84 - atrium_api::types::string::Handle, 85 - )> { 86 - let (handle, did) = match *identifier { 87 - AtIdentifier::Handle(ref handle) => { 88 - let handle_as_str = &handle.as_str(); 89 - ( 90 - &handle.to_owned(), 91 - &atrium_api::types::string::Did::new( 92 - sqlx::query_scalar!( 93 - r#"SELECT did FROM handles WHERE handle = ?"#, 94 - handle_as_str 95 - ) 96 - .fetch_one(db) 97 - .await 98 - .context("failed to query did")?, 99 - ) 100 - .expect("should be valid DID"), 101 - ) 102 - } 103 - AtIdentifier::Did(ref did) => { 104 - let did_as_str = &did.as_str(); 105 - ( 106 - &atrium_api::types::string::Handle::new( 107 - sqlx::query_scalar!(r#"SELECT handle FROM handles WHERE did = ?"#, did_as_str) 108 - .fetch_one(db) 109 - .await 110 - .context("failed to query did")?, 111 - ) 112 - .expect("should be valid handle"), 113 - &did.to_owned(), 114 - ) 115 - } 116 - }; 117 - 118 - Ok((did.to_owned(), handle.to_owned())) 119 - } 120 - 121 - /// Create a single new repository record. Requires auth, implemented by PDS. 122 - /// - POST /xrpc/com.atproto.repo.createRecord 123 - /// ### Request Body 124 - /// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account). 125 - /// - `collection`: `nsid` // The NSID of the record collection. 126 - /// - `rkey`: `string` // The record key. <= 512 characters. 127 - /// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data, 'true' to require it, or leave unset to validate only for known Lexicons. 128 - /// - `record` 129 - /// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID. 130 - /// ### Responses 131 - /// - 200 OK: {`cid`: `cid`, `uri`: `at-uri`, `commit`: {`cid`: `cid`, `rev`: `tid`}, `validation_status`: [`valid`, `unknown`]} 132 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidSwap`]} 133 - /// - 401 Unauthorized 134 - async fn create_record( 135 - user: AuthenticatedUser, 136 - State(actor_store): State<ActorStore>, 137 - State(skey): State<SigningKey>, 138 - State(config): State<AppConfig>, 139 - State(db): State<Db>, 140 - State(fhp): State<FirehoseProducer>, 141 - Json(input): Json<repo::create_record::Input>, 142 - ) -> Result<Json<repo::create_record::Output>> { 143 - todo!(); 144 - // let write_result = apply_writes::apply_writes( 145 - // user, 146 - // State(actor_store), 147 - // State(skey), 148 - // State(config), 149 - // State(db), 150 - // State(fhp), 151 - // Json( 152 - // repo::apply_writes::InputData { 153 - // repo: input.repo.clone(), 154 - // validate: input.validate, 155 - // swap_commit: input.swap_commit.clone(), 156 - // writes: vec![repo::apply_writes::InputWritesItem::Create(Box::new( 157 - // repo::apply_writes::CreateData { 158 - // collection: input.collection.clone(), 159 - // rkey: input.rkey.clone(), 160 - // value: input.record.clone(), 161 - // } 162 - // .into(), 163 - // ))], 164 - // } 165 - // .into(), 166 - // ), 167 - // ) 168 - // .await 169 - // .context("failed to apply writes")?; 170 - 171 - // let create_result = if let repo::apply_writes::OutputResultsItem::CreateResult(create_result) = 172 - // write_result 173 - // .results 174 - // .clone() 175 - // .and_then(|result| result.first().cloned()) 176 - // .context("unexpected output from apply_writes")? 177 - // { 178 - // Some(create_result) 179 - // } else { 180 - // None 181 - // } 182 - // .context("unexpected result from apply_writes")?; 183 - 184 - // Ok(Json( 185 - // repo::create_record::OutputData { 186 - // cid: create_result.cid.clone(), 187 - // commit: write_result.commit.clone(), 188 - // uri: create_result.uri.clone(), 189 - // validation_status: Some("unknown".to_owned()), 190 - // } 191 - // .into(), 192 - // )) 193 - } 194 - 195 - /// Write a repository record, creating or updating it as needed. Requires auth, implemented by PDS. 196 - /// - POST /xrpc/com.atproto.repo.putRecord 197 - /// ### Request Body 198 - /// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account). 199 - /// - `collection`: `nsid` // The NSID of the record collection. 200 - /// - `rkey`: `string` // The record key. <= 512 characters. 201 - /// - `validate`: `boolean` // Can be set to 'false' to skip Lexicon schema validation of record data, 'true' to require it, or leave unset to validate only for known Lexicons. 202 - /// - `record` 203 - /// - `swap_record`: `boolean` // Compare and swap with the previous record by CID. WARNING: nullable and optional field; may cause problems with golang implementation 204 - /// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID. 205 - /// ### Responses 206 - /// - 200 OK: {"uri": "string","cid": "string","commit": {"cid": "string","rev": "string"},"validationStatus": "valid | unknown"} 207 - /// - 400 Bad Request: {error:"`InvalidRequest` | `ExpiredToken` | `InvalidToken` | `InvalidSwap`"} 208 - /// - 401 Unauthorized 209 - async fn put_record( 210 - user: AuthenticatedUser, 211 - State(actor_store): State<ActorStore>, 212 - State(skey): State<SigningKey>, 213 - State(config): State<AppConfig>, 214 - State(db): State<Db>, 215 - State(fhp): State<FirehoseProducer>, 216 - Json(input): Json<repo::put_record::Input>, 217 - ) -> Result<Json<repo::put_record::Output>> { 218 - todo!(); 219 - // // TODO: `input.swap_record` 220 - // // FIXME: "put" implies that we will create the record if it does not exist. 221 - // // We currently only update existing records and/or throw an error if one doesn't exist. 222 - // let input = (*input).clone(); 223 - // let input = repo::apply_writes::InputData { 224 - // repo: input.repo, 225 - // validate: input.validate, 226 - // swap_commit: input.swap_commit, 227 - // writes: vec![repo::apply_writes::InputWritesItem::Update(Box::new( 228 - // repo::apply_writes::UpdateData { 229 - // collection: input.collection, 230 - // rkey: input.rkey, 231 - // value: input.record, 232 - // } 233 - // .into(), 234 - // ))], 235 - // } 236 - // .into(); 237 - 238 - // let write_result = apply_writes::apply_writes( 239 - // user, 240 - // State(actor_store), 241 - // State(skey), 242 - // State(config), 243 - // State(db), 244 - // State(fhp), 245 - // Json(input), 246 - // ) 247 - // .await 248 - // .context("failed to apply writes")?; 249 - 250 - // let update_result = write_result 251 - // .results 252 - // .clone() 253 - // .and_then(|result| result.first().cloned()) 254 - // .context("unexpected output from apply_writes")?; 255 - // let (cid, uri) = match update_result { 256 - // repo::apply_writes::OutputResultsItem::CreateResult(create_result) => ( 257 - // Some(create_result.cid.clone()), 258 - // Some(create_result.uri.clone()), 259 - // ), 260 - // repo::apply_writes::OutputResultsItem::UpdateResult(update_result) => ( 261 - // Some(update_result.cid.clone()), 262 - // Some(update_result.uri.clone()), 263 - // ), 264 - // repo::apply_writes::OutputResultsItem::DeleteResult(_) => (None, None), 265 - // }; 266 - // Ok(Json( 267 - // repo::put_record::OutputData { 268 - // cid: cid.context("missing cid")?, 269 - // commit: write_result.commit.clone(), 270 - // uri: uri.context("missing uri")?, 271 - // validation_status: Some("unknown".to_owned()), 272 - // } 273 - // .into(), 274 - // )) 275 - } 276 - 277 - /// Delete a repository record, or ensure it doesn't exist. Requires auth, implemented by PDS. 278 - /// - POST /xrpc/com.atproto.repo.deleteRecord 279 - /// ### Request Body 280 - /// - `repo`: `at-identifier` // The handle or DID of the repo (aka, current account). 281 - /// - `collection`: `nsid` // The NSID of the record collection. 282 - /// - `rkey`: `string` // The record key. <= 512 characters. 283 - /// - `swap_record`: `boolean` // Compare and swap with the previous record by CID. 284 - /// - `swap_commit`: `cid` // Compare and swap with the previous commit by CID. 285 - /// ### Responses 286 - /// - 200 OK: {"commit": {"cid": "string","rev": "string"}} 287 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidSwap`]} 288 - /// - 401 Unauthorized 289 - async fn delete_record( 290 - user: AuthenticatedUser, 291 - State(actor_store): State<ActorStore>, 292 - State(skey): State<SigningKey>, 293 - State(config): State<AppConfig>, 294 - State(db): State<Db>, 295 - State(fhp): State<FirehoseProducer>, 296 - Json(input): Json<repo::delete_record::Input>, 297 - ) -> Result<Json<repo::delete_record::Output>> { 298 - todo!(); 299 - // // TODO: `input.swap_record` 300 - 301 - // Ok(Json( 302 - // repo::delete_record::OutputData { 303 - // commit: apply_writes::apply_writes( 304 - // user, 305 - // State(actor_store), 306 - // State(skey), 307 - // State(config), 308 - // State(db), 309 - // State(fhp), 310 - // Json( 311 - // repo::apply_writes::InputData { 312 - // repo: input.repo.clone(), 313 - // swap_commit: input.swap_commit.clone(), 314 - // validate: None, 315 - // writes: vec![repo::apply_writes::InputWritesItem::Delete(Box::new( 316 - // repo::apply_writes::DeleteData { 317 - // collection: input.collection.clone(), 318 - // rkey: input.rkey.clone(), 319 - // } 320 - // .into(), 321 - // ))], 322 - // } 323 - // .into(), 324 - // ), 325 - // ) 326 - // .await 327 - // .context("failed to apply writes")? 328 - // .commit 329 - // .clone(), 330 - // } 331 - // .into(), 332 - // )) 333 - } 334 - 335 - /// Get information about an account and repository, including the list of collections. Does not require auth. 336 - /// - GET /xrpc/com.atproto.repo.describeRepo 337 - /// ### Query Parameters 338 - /// - `repo`: `at-identifier` // The handle or DID of the repo. 339 - /// ### Responses 340 - /// - 200 OK: {"handle": "string","did": "string","didDoc": {},"collections": [string],"handleIsCorrect": true} \ 341 - /// handeIsCorrect - boolean - Indicates if handle is currently valid (resolves bi-directionally) 342 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 343 - /// - 401 Unauthorized 344 - async fn describe_repo( 345 - State(actor_store): State<ActorStore>, 346 - State(config): State<AppConfig>, 347 - State(db): State<Db>, 348 - Query(input): Query<repo::describe_repo::ParametersData>, 349 - ) -> Result<Json<repo::describe_repo::Output>> { 350 - // Lookup the DID by the provided handle. 351 - let (did, handle) = resolve_did(&db, &input.repo) 352 - .await 353 - .context("failed to resolve handle")?; 354 - 355 - // Use Actor Store to get the collections 356 - todo!(); 357 - } 358 - 359 - /// Get a single record from a repository. Does not require auth. 360 - /// - GET /xrpc/com.atproto.repo.getRecord 361 - /// ### Query Parameters 362 - /// - `repo`: `at-identifier` // The handle or DID of the repo. 363 - /// - `collection`: `nsid` // The NSID of the record collection. 364 - /// - `rkey`: `string` // The record key. <= 512 characters. 365 - /// - `cid`: `cid` // The CID of the version of the record. If not specified, then return the most recent version. 366 - /// ### Responses 367 - /// - 200 OK: {"uri": "string","cid": "string","value": {}} 368 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RecordNotFound`]} 369 - /// - 401 Unauthorized 370 - async fn get_record( 371 - State(actor_store): State<ActorStore>, 372 - State(config): State<AppConfig>, 373 - State(db): State<Db>, 374 - Query(input): Query<repo::get_record::ParametersData>, 375 - ) -> Result<Json<repo::get_record::Output>> { 376 - if input.cid.is_some() { 377 - return Err(Error::unimplemented(anyhow!( 378 - "looking up old records is unsupported" 379 - ))); 380 - } 381 - 382 - // Lookup the DID by the provided handle. 383 - let (did, _handle) = resolve_did(&db, &input.repo) 384 - .await 385 - .context("failed to resolve handle")?; 386 - 387 - // Create a URI from the parameters 388 - let uri = format!( 389 - "at://{}/{}/{}", 390 - did.as_str(), 391 - input.collection.as_str(), 392 - input.rkey.as_str() 393 - ); 394 - 395 - // Use Actor Store to get the record 396 - todo!(); 397 - } 398 - 399 - /// List a range of records in a repository, matching a specific collection. Does not require auth. 400 - /// - GET /xrpc/com.atproto.repo.listRecords 401 - /// ### Query Parameters 402 - /// - `repo`: `at-identifier` // The handle or DID of the repo. 403 - /// - `collection`: `nsid` // The NSID of the record type. 404 - /// - `limit`: `integer` // The maximum number of records to return. Default 50, >=1 and <=100. 405 - /// - `cursor`: `string` 406 - /// - `reverse`: `boolean` // Flag to reverse the order of the returned records. 407 - /// ### Responses 408 - /// - 200 OK: {"cursor": "string","records": [{"uri": "string","cid": "string","value": {}}]} 409 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 410 - /// - 401 Unauthorized 411 - async fn list_records( 412 - State(actor_store): State<ActorStore>, 413 - State(config): State<AppConfig>, 414 - State(db): State<Db>, 415 - Query(input): Query<Object<ListRecordsParameters>>, 416 - ) -> Result<Json<repo::list_records::Output>> { 417 - // Lookup the DID by the provided handle. 418 - let (did, _handle) = resolve_did(&db, &input.repo) 419 - .await 420 - .context("failed to resolve handle")?; 421 - 422 - // Use Actor Store to list records for the collection 423 - todo!(); 424 - } 425 - 426 - /// Upload a new blob, to be referenced from a repository record. \ 427 - /// The blob will be deleted if it is not referenced within a time window (eg, minutes). \ 428 - /// Blob restrictions (mimetype, size, etc) are enforced when the reference is created. \ 429 - /// Requires auth, implemented by PDS. 430 - /// - POST /xrpc/com.atproto.repo.uploadBlob 431 - /// ### Request Body 432 - /// ### Responses 433 - /// - 200 OK: {"blob": "binary"} 434 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 435 - /// - 401 Unauthorized 436 - async fn upload_blob( 437 - user: AuthenticatedUser, 438 - State(actor_store): State<ActorStore>, 439 - State(config): State<AppConfig>, 440 - State(db): State<Db>, 441 - request: Request<Body>, 442 - ) -> Result<Json<repo::upload_blob::Output>> { 443 - let length = request 444 - .headers() 445 - .get(http::header::CONTENT_LENGTH) 446 - .context("no content length provided")? 447 - .to_str() 448 - .map_err(anyhow::Error::from) 449 - .and_then(|content_length| content_length.parse::<u64>().map_err(anyhow::Error::from)) 450 - .context("invalid content-length header")?; 451 - let mime = request 452 - .headers() 453 - .get(http::header::CONTENT_TYPE) 454 - .context("no content-type provided")? 455 - .to_str() 456 - .context("invalid content-type provided")? 457 - .to_owned(); 458 - 459 - if length > config.blob.limit { 460 - return Err(Error::with_status( 461 - StatusCode::PAYLOAD_TOO_LARGE, 462 - anyhow!("size {} above limit {}", length, config.blob.limit), 463 - )); 464 - } 465 - 466 - // Read the blob data 467 - let mut body_data = Vec::new(); 468 - let mut stream = request.into_body().into_data_stream(); 469 - while let Some(bytes) = stream.try_next().await.context("failed to receive file")? { 470 - body_data.extend_from_slice(&bytes); 471 - 472 - // Check size limit incrementally 473 - if body_data.len() as u64 > config.blob.limit { 474 - return Err(Error::with_status( 475 - StatusCode::PAYLOAD_TOO_LARGE, 476 - anyhow!("size above limit and content-length header was wrong"), 477 - )); 478 - } 479 - } 480 - 481 - // Use Actor Store to upload the blob 482 - todo!(); 483 - } 484 - 485 - async fn todo() -> Result<()> { 486 - Err(Error::unimplemented(anyhow!("not implemented"))) 487 - } 488 - 489 - /// These endpoints are part of the atproto PDS repository management APIs. \ 490 - /// Requests usually require authentication (unlike the com.atproto.sync.* endpoints), and are made directly to the user's own PDS instance. 491 - /// ### Routes 492 - /// - AP /xrpc/com.atproto.repo.applyWrites -> [`apply_writes`] 493 - /// - AP /xrpc/com.atproto.repo.createRecord -> [`create_record`] 494 - /// - AP /xrpc/com.atproto.repo.putRecord -> [`put_record`] 495 - /// - AP /xrpc/com.atproto.repo.deleteRecord -> [`delete_record`] 496 - /// - AP /xrpc/com.atproto.repo.uploadBlob -> [`upload_blob`] 497 - /// - UG /xrpc/com.atproto.repo.describeRepo -> [`describe_repo`] 498 - /// - UG /xrpc/com.atproto.repo.getRecord -> [`get_record`] 499 - /// - UG /xrpc/com.atproto.repo.listRecords -> [`list_records`] 500 - /// - [ ] xx /xrpc/com.atproto.repo.importRepo 501 - // - [ ] xx /xrpc/com.atproto.repo.listMissingBlobs 502 - pub(super) fn routes() -> Router<AppState> { 503 - Router::new() 504 - .route(concat!("/", repo::apply_writes::NSID), post(apply_writes)) 505 - // .route(concat!("/", repo::create_record::NSID), post(create_record)) 506 - // .route(concat!("/", repo::put_record::NSID), post(put_record)) 507 - // .route(concat!("/", repo::delete_record::NSID), post(delete_record)) 508 - // .route(concat!("/", repo::upload_blob::NSID), post(upload_blob)) 509 - // .route(concat!("/", repo::describe_repo::NSID), get(describe_repo)) 510 - // .route(concat!("/", repo::get_record::NSID), get(get_record)) 511 - .route(concat!("/", repo::import_repo::NSID), post(todo)) 512 - .route(concat!("/", repo::list_missing_blobs::NSID), get(todo)) 513 - // .route(concat!("/", repo::list_records::NSID), get(list_records)) 514 - }
-791
src/endpoints/server.rs
··· 1 - //! Server endpoints. (/xrpc/com.atproto.server.*) 2 - use std::{collections::HashMap, str::FromStr as _}; 3 - 4 - use anyhow::{Context as _, anyhow}; 5 - use argon2::{ 6 - Argon2, PasswordHash, PasswordHasher as _, PasswordVerifier as _, password_hash::SaltString, 7 - }; 8 - use atrium_api::{ 9 - com::atproto::server, 10 - types::string::{Datetime, Did, Handle, Tid}, 11 - }; 12 - use atrium_crypto::keypair::Did as _; 13 - use atrium_repo::{ 14 - Cid, Repository, 15 - blockstore::{AsyncBlockStoreWrite as _, CarStore, DAG_CBOR, SHA2_256}, 16 - }; 17 - use axum::{ 18 - Json, Router, 19 - extract::{Query, Request, State}, 20 - http::StatusCode, 21 - routing::{get, post}, 22 - }; 23 - use constcat::concat; 24 - use metrics::counter; 25 - use rand::Rng as _; 26 - use sha2::Digest as _; 27 - use uuid::Uuid; 28 - 29 - use crate::{ 30 - AppState, Client, Db, Error, Result, RotationKey, SigningKey, 31 - auth::{self, AuthenticatedUser}, 32 - config::AppConfig, 33 - firehose::{Commit, FirehoseProducer}, 34 - metrics::AUTH_FAILED, 35 - plc::{self, PlcOperation, PlcService}, 36 - storage, 37 - }; 38 - 39 - /// This is a dummy password that can be used in absence of a real password. 40 - const DUMMY_PASSWORD: &str = "$argon2id$v=19$m=19456,t=2,p=1$En2LAfHjeO0SZD5IUU1Abg$RpS8nHhhqY4qco2uyd41p9Y/1C+Lvi214MAWukzKQMI"; 41 - 42 - /// Create an invite code. 43 - /// - POST /xrpc/com.atproto.server.createInviteCode 44 - /// ### Request Body 45 - /// - `useCount`: integer 46 - /// - `forAccount`: string (optional) 47 - /// ### Responses 48 - /// - 200 OK: {code: string} 49 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 50 - /// - 401 Unauthorized 51 - async fn create_invite_code( 52 - _user: AuthenticatedUser, 53 - State(db): State<Db>, 54 - Json(input): Json<server::create_invite_code::Input>, 55 - ) -> Result<Json<server::create_invite_code::Output>> { 56 - let uuid = Uuid::new_v4().to_string(); 57 - let did = input.for_account.as_deref(); 58 - let count = std::cmp::min(input.use_count, 100); // Maximum of 100 uses for any code. 59 - 60 - if count <= 0 { 61 - return Err(anyhow!("use_count must be greater than 0").into()); 62 - } 63 - 64 - Ok(Json( 65 - server::create_invite_code::OutputData { 66 - code: sqlx::query_scalar!( 67 - r#" 68 - INSERT INTO invites (id, did, count, created_at) 69 - VALUES (?, ?, ?, datetime('now')) 70 - RETURNING id 71 - "#, 72 - uuid, 73 - did, 74 - count, 75 - ) 76 - .fetch_one(&db) 77 - .await 78 - .context("failed to create new invite code")?, 79 - } 80 - .into(), 81 - )) 82 - } 83 - 84 - #[expect(clippy::too_many_lines, reason = "TODO: refactor")] 85 - /// Create an account. Implemented by PDS. 86 - /// - POST /xrpc/com.atproto.server.createAccount 87 - /// ### Request Body 88 - /// - `email`: string 89 - /// - `handle`: string (required) 90 - /// - `did`: string - Pre-existing atproto DID, being imported to a new account. 91 - /// - `inviteCode`: string 92 - /// - `verificationCode`: string 93 - /// - `verificationPhone`: string 94 - /// - `password`: string - Initial account password. May need to meet instance-specific password strength requirements. 95 - /// - `recoveryKey`: string - DID PLC rotation key (aka, recovery key) to be included in PLC creation operation. 96 - /// - `plcOp`: object 97 - /// ## Responses 98 - /// - 200 OK: {"accessJwt": "string","refreshJwt": "string","handle": "string","did": "string","didDoc": {}} 99 - /// - 400 Bad Request: {error: [`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `InvalidHandle`, `InvalidPassword`, \ 100 - /// `InvalidInviteCode`, `HandleNotAvailable`, `UnsupportedDomain`, `UnresolvableDid`, `IncompatibleDidDoc`)} 101 - /// - 401 Unauthorized 102 - async fn create_account( 103 - State(db): State<Db>, 104 - State(skey): State<SigningKey>, 105 - State(rkey): State<RotationKey>, 106 - State(client): State<Client>, 107 - State(config): State<AppConfig>, 108 - State(fhp): State<FirehoseProducer>, 109 - Json(input): Json<server::create_account::Input>, 110 - ) -> Result<Json<server::create_account::Output>> { 111 - let email = input.email.as_deref().context("no email provided")?; 112 - // Hash the user's password. 113 - let pass = Argon2::default() 114 - .hash_password( 115 - input 116 - .password 117 - .as_deref() 118 - .context("no password provided")? 119 - .as_bytes(), 120 - SaltString::generate(&mut rand::thread_rng()).as_salt(), 121 - ) 122 - .context("failed to hash password")? 123 - .to_string(); 124 - let handle = input.handle.as_str().to_owned(); 125 - 126 - // TODO: Handle the account migration flow. 127 - // Users will hit this endpoint with a service-level authentication token. 128 - // 129 - // https://github.com/bluesky-social/pds/blob/main/ACCOUNT_MIGRATION.md 130 - 131 - // TODO: `input.plc_op` 132 - if input.plc_op.is_some() { 133 - return Err(Error::unimplemented(anyhow!("plc_op"))); 134 - } 135 - 136 - let recovery_keys = if let Some(ref key) = input.recovery_key { 137 - // Ensure the provided recovery key is valid. 138 - if let Err(error) = atrium_crypto::did::parse_did_key(key) { 139 - return Err(Error::with_status( 140 - StatusCode::BAD_REQUEST, 141 - anyhow::Error::new(error).context("provided recovery key is in invalid format"), 142 - )); 143 - } 144 - 145 - // Enroll the user-provided recovery key at a higher priority than our own. 146 - vec![key.clone(), rkey.did()] 147 - } else { 148 - vec![rkey.did()] 149 - }; 150 - 151 - // Begin a new transaction to actually create the user's profile. 152 - // Unless committed, the transaction will be automatically rolled back. 153 - let mut tx = db.begin().await.context("failed to begin transaction")?; 154 - 155 - // TODO: Make this its own toggle instead of tied to test mode 156 - if !config.test { 157 - let _invite = match input.invite_code { 158 - Some(ref code) => { 159 - let invite: Option<String> = sqlx::query_scalar!( 160 - r#" 161 - UPDATE invites 162 - SET count = count - 1 163 - WHERE id = ? 164 - AND count > 0 165 - RETURNING id 166 - "#, 167 - code 168 - ) 169 - .fetch_optional(&mut *tx) 170 - .await 171 - .context("failed to check invite code")?; 172 - 173 - invite.context("invalid invite code")? 174 - } 175 - None => { 176 - return Err(anyhow!("invite code required").into()); 177 - } 178 - }; 179 - } 180 - 181 - // Account can be created. Synthesize a new DID for the user. 182 - // https://github.com/did-method-plc/did-method-plc?tab=readme-ov-file#did-creation 183 - let op = plc::sign_op( 184 - &rkey, 185 - PlcOperation { 186 - typ: "plc_operation".to_owned(), 187 - rotation_keys: recovery_keys, 188 - verification_methods: HashMap::from([("atproto".to_owned(), skey.did())]), 189 - also_known_as: vec![format!("at://{}", input.handle.as_str())], 190 - services: HashMap::from([( 191 - "atproto_pds".to_owned(), 192 - PlcService::Pds { 193 - endpoint: format!("https://{}", config.host_name), 194 - }, 195 - )]), 196 - prev: None, 197 - }, 198 - ) 199 - .context("failed to sign genesis op")?; 200 - let op_bytes = serde_ipld_dagcbor::to_vec(&op).context("failed to encode genesis op")?; 201 - 202 - let did_hash = { 203 - let digest = base32::encode( 204 - base32::Alphabet::Rfc4648Lower { padding: false }, 205 - sha2::Sha256::digest(&op_bytes).as_slice(), 206 - ); 207 - if digest.len() < 24 { 208 - return Err(anyhow!("digest too short").into()); 209 - } 210 - #[expect(clippy::string_slice, reason = "digest length confirmed")] 211 - digest[..24].to_owned() 212 - }; 213 - let did = format!("did:plc:{did_hash}"); 214 - 215 - let doc = tokio::fs::File::create(config.plc.path.join(format!("{did_hash}.car"))) 216 - .await 217 - .context("failed to create did doc")?; 218 - 219 - let mut plc_doc = CarStore::create(doc) 220 - .await 221 - .context("failed to create did doc")?; 222 - 223 - let plc_cid = plc_doc 224 - .write_block(DAG_CBOR, SHA2_256, &op_bytes) 225 - .await 226 - .context("failed to write genesis commit")? 227 - .to_string(); 228 - 229 - if !config.test { 230 - // Send the new account's data to the PLC directory. 231 - plc::submit(&client, &did, &op) 232 - .await 233 - .context("failed to submit PLC operation to directory")?; 234 - } 235 - 236 - // Write out an initial commit for the user. 237 - // https://atproto.com/guides/account-lifecycle 238 - let (cid, rev, store) = async { 239 - let store = storage::create_storage_for_did(&config.repo, &did_hash) 240 - .await 241 - .context("failed to create storage")?; 242 - 243 - // Initialize the repository with the storage 244 - let repo_builder = Repository::create( 245 - store, 246 - Did::from_str(&did).expect("should be valid DID format"), 247 - ) 248 - .await 249 - .context("failed to initialize user repo")?; 250 - 251 - // Sign the root commit. 252 - let sig = skey 253 - .sign(&repo_builder.bytes()) 254 - .context("failed to sign root commit")?; 255 - let mut repo = repo_builder 256 - .finalize(sig) 257 - .await 258 - .context("failed to attach signature to root commit")?; 259 - 260 - let root = repo.root(); 261 - let rev = repo.commit().rev(); 262 - 263 - // Create a temporary CAR store for firehose events 264 - let mut mem = Vec::new(); 265 - let mut firehose_store = 266 - CarStore::create_with_roots(std::io::Cursor::new(&mut mem), [repo.root()]) 267 - .await 268 - .context("failed to create temp carstore")?; 269 - 270 - repo.export_into(&mut firehose_store) 271 - .await 272 - .context("failed to export repository")?; 273 - 274 - Ok::<(Cid, Tid, Vec<u8>), anyhow::Error>((root, rev, mem)) 275 - } 276 - .await 277 - .context("failed to create user repo")?; 278 - 279 - let cid_str = cid.to_string(); 280 - let rev_str = rev.as_str(); 281 - 282 - _ = sqlx::query!( 283 - r#" 284 - INSERT INTO accounts (did, email, password, root, plc_root, rev, created_at) 285 - VALUES (?, ?, ?, ?, ?, ?, datetime('now')); 286 - 287 - INSERT INTO handles (did, handle, created_at) 288 - VALUES (?, ?, datetime('now')); 289 - 290 - -- Cleanup stale invite codes 291 - DELETE FROM invites 292 - WHERE count <= 0; 293 - "#, 294 - did, 295 - email, 296 - pass, 297 - cid_str, 298 - plc_cid, 299 - rev_str, 300 - did, 301 - handle 302 - ) 303 - .execute(&mut *tx) 304 - .await 305 - .context("failed to create new account")?; 306 - 307 - // The account is fully created. Commit the SQL transaction to the database. 308 - tx.commit().await.context("failed to commit transaction")?; 309 - 310 - // Broadcast the identity event now that the new identity is resolvable on the public directory. 311 - fhp.identity( 312 - atrium_api::com::atproto::sync::subscribe_repos::IdentityData { 313 - did: Did::from_str(&did).expect("should be valid DID format"), 314 - handle: Some(Handle::new(handle).expect("should be valid handle")), 315 - seq: 0, // Filled by firehose later. 316 - time: Datetime::now(), 317 - }, 318 - ) 319 - .await; 320 - 321 - // The new account is now active on this PDS, so we can broadcast the account firehose event. 322 - fhp.account( 323 - atrium_api::com::atproto::sync::subscribe_repos::AccountData { 324 - active: true, 325 - did: Did::from_str(&did).expect("should be valid DID format"), 326 - seq: 0, // Filled by firehose later. 327 - status: None, // "takedown" / "suspended" / "deactivated" 328 - time: Datetime::now(), 329 - }, 330 - ) 331 - .await; 332 - 333 - let did = Did::from_str(&did).expect("should be valid DID format"); 334 - 335 - fhp.commit(Commit { 336 - car: store, 337 - ops: Vec::new(), 338 - cid, 339 - rev: rev.to_string(), 340 - did: did.clone(), 341 - pcid: None, 342 - blobs: Vec::new(), 343 - }) 344 - .await; 345 - 346 - // Finally, sign some authentication tokens for the new user. 347 - let token = auth::sign( 348 - &skey, 349 - "at+jwt", 350 - &serde_json::json!({ 351 - "scope": "com.atproto.access", 352 - "sub": did, 353 - "iat": chrono::Utc::now().timestamp(), 354 - "exp": chrono::Utc::now().checked_add_signed(chrono::Duration::hours(4)).context("should be valid time")?.timestamp(), 355 - "aud": format!("did:web:{}", config.host_name) 356 - }), 357 - ) 358 - .context("failed to sign jwt")?; 359 - 360 - let refresh_token = auth::sign( 361 - &skey, 362 - "refresh+jwt", 363 - &serde_json::json!({ 364 - "scope": "com.atproto.refresh", 365 - "sub": did, 366 - "iat": chrono::Utc::now().timestamp(), 367 - "exp": chrono::Utc::now().checked_add_days(chrono::Days::new(90)).context("should be valid time")?.timestamp(), 368 - "aud": format!("did:web:{}", config.host_name) 369 - }), 370 - ) 371 - .context("failed to sign refresh jwt")?; 372 - 373 - Ok(Json( 374 - server::create_account::OutputData { 375 - access_jwt: token, 376 - did, 377 - did_doc: None, 378 - handle: input.handle.clone(), 379 - refresh_jwt: refresh_token, 380 - } 381 - .into(), 382 - )) 383 - } 384 - 385 - /// Create an authentication session. 386 - /// - POST /xrpc/com.atproto.server.createSession 387 - /// ### Request Body 388 - /// - `identifier`: string - Handle or other identifier supported by the server for the authenticating user. 389 - /// - `password`: string - Password for the authenticating user. 390 - /// - `authFactorToken` - string (optional) 391 - /// - `allowTakedown` - boolean (optional) - When true, instead of throwing error for takendown accounts, a valid response with a narrow scoped token will be returned 392 - /// ### Responses 393 - /// - 200 OK: {"accessJwt": "string","refreshJwt": "string","handle": "string","did": "string","didDoc": {},"email": "string","emailConfirmed": true,"emailAuthFactor": true,"active": true,"status": "takendown"} 394 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `AccountTakedown`, `AuthFactorTokenRequired`]} 395 - /// - 401 Unauthorized 396 - async fn create_session( 397 - State(db): State<Db>, 398 - State(skey): State<SigningKey>, 399 - State(config): State<AppConfig>, 400 - Json(input): Json<server::create_session::Input>, 401 - ) -> Result<Json<server::create_session::Output>> { 402 - let handle = &input.identifier; 403 - let password = &input.password; 404 - 405 - // TODO: `input.allow_takedown` 406 - // TODO: `input.auth_factor_token` 407 - 408 - let Some(account) = sqlx::query!( 409 - r#" 410 - WITH LatestHandles AS ( 411 - SELECT did, handle 412 - FROM handles 413 - WHERE (did, created_at) IN ( 414 - SELECT did, MAX(created_at) AS max_created_at 415 - FROM handles 416 - GROUP BY did 417 - ) 418 - ) 419 - SELECT a.did, a.password, h.handle 420 - FROM accounts a 421 - LEFT JOIN LatestHandles h ON a.did = h.did 422 - WHERE h.handle = ? 423 - "#, 424 - handle 425 - ) 426 - .fetch_optional(&db) 427 - .await 428 - .context("failed to authenticate")? 429 - else { 430 - counter!(AUTH_FAILED).increment(1); 431 - 432 - // SEC: Call argon2's `verify_password` to simulate password verification and discard the result. 433 - // We do this to avoid exposing a timing attack where attackers can measure the response time to 434 - // determine whether or not an account exists. 435 - _ = Argon2::default().verify_password( 436 - password.as_bytes(), 437 - &PasswordHash::new(DUMMY_PASSWORD).context("should be valid password hash")?, 438 - ); 439 - 440 - return Err(Error::with_status( 441 - StatusCode::UNAUTHORIZED, 442 - anyhow!("failed to validate credentials"), 443 - )); 444 - }; 445 - 446 - match Argon2::default().verify_password( 447 - password.as_bytes(), 448 - &PasswordHash::new(account.password.as_str()).context("invalid password hash in db")?, 449 - ) { 450 - Ok(()) => {} 451 - Err(_e) => { 452 - counter!(AUTH_FAILED).increment(1); 453 - 454 - return Err(Error::with_status( 455 - StatusCode::UNAUTHORIZED, 456 - anyhow!("failed to validate credentials"), 457 - )); 458 - } 459 - } 460 - 461 - let did = account.did; 462 - 463 - let token = auth::sign( 464 - &skey, 465 - "at+jwt", 466 - &serde_json::json!({ 467 - "scope": "com.atproto.access", 468 - "sub": did, 469 - "iat": chrono::Utc::now().timestamp(), 470 - "exp": chrono::Utc::now().checked_add_signed(chrono::Duration::hours(4)).context("should be valid time")?.timestamp(), 471 - "aud": format!("did:web:{}", config.host_name) 472 - }), 473 - ) 474 - .context("failed to sign jwt")?; 475 - 476 - let refresh_token = auth::sign( 477 - &skey, 478 - "refresh+jwt", 479 - &serde_json::json!({ 480 - "scope": "com.atproto.refresh", 481 - "sub": did, 482 - "iat": chrono::Utc::now().timestamp(), 483 - "exp": chrono::Utc::now().checked_add_days(chrono::Days::new(90)).context("should be valid time")?.timestamp(), 484 - "aud": format!("did:web:{}", config.host_name) 485 - }), 486 - ) 487 - .context("failed to sign refresh jwt")?; 488 - 489 - Ok(Json( 490 - server::create_session::OutputData { 491 - access_jwt: token, 492 - refresh_jwt: refresh_token, 493 - 494 - active: Some(true), 495 - did: Did::from_str(&did).expect("should be valid DID format"), 496 - did_doc: None, 497 - email: None, 498 - email_auth_factor: None, 499 - email_confirmed: None, 500 - handle: Handle::new(account.handle).expect("should be valid handle"), 501 - status: None, 502 - } 503 - .into(), 504 - )) 505 - } 506 - 507 - /// Refresh an authentication session. Requires auth using the 'refreshJwt' (not the 'accessJwt'). 508 - /// - POST /xrpc/com.atproto.server.refreshSession 509 - /// ### Responses 510 - /// - 200 OK: {"accessJwt": "string","refreshJwt": "string","handle": "string","did": "string","didDoc": {},"active": true,"status": "takendown"} 511 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `AccountTakedown`]} 512 - /// - 401 Unauthorized 513 - async fn refresh_session( 514 - State(db): State<Db>, 515 - State(skey): State<SigningKey>, 516 - State(config): State<AppConfig>, 517 - req: Request, 518 - ) -> Result<Json<server::refresh_session::Output>> { 519 - // TODO: store hashes of refresh tokens and enforce single-use 520 - let auth_token = req 521 - .headers() 522 - .get(axum::http::header::AUTHORIZATION) 523 - .context("no authorization header provided")? 524 - .to_str() 525 - .ok() 526 - .and_then(|auth| auth.strip_prefix("Bearer ")) 527 - .context("invalid authentication token")?; 528 - 529 - let (typ, claims) = 530 - auth::verify(&skey.did(), auth_token).context("failed to verify refresh token")?; 531 - if typ != "refresh+jwt" { 532 - return Err(Error::with_status( 533 - StatusCode::UNAUTHORIZED, 534 - anyhow!("invalid refresh token"), 535 - )); 536 - } 537 - if claims 538 - .get("exp") 539 - .and_then(serde_json::Value::as_i64) 540 - .context("failed to get `exp`")? 541 - < chrono::Utc::now().timestamp() 542 - { 543 - return Err(Error::with_status( 544 - StatusCode::UNAUTHORIZED, 545 - anyhow!("refresh token expired"), 546 - )); 547 - } 548 - if claims 549 - .get("aud") 550 - .and_then(|audience| audience.as_str()) 551 - .context("invalid jwt")? 552 - != format!("did:web:{}", config.host_name) 553 - { 554 - return Err(Error::with_status( 555 - StatusCode::UNAUTHORIZED, 556 - anyhow!("invalid audience"), 557 - )); 558 - } 559 - 560 - let did = claims 561 - .get("sub") 562 - .and_then(|subject| subject.as_str()) 563 - .context("invalid jwt")?; 564 - 565 - let user = sqlx::query!( 566 - r#" 567 - SELECT a.status, h.handle 568 - FROM accounts a 569 - JOIN handles h ON a.did = h.did 570 - WHERE a.did = ? 571 - ORDER BY h.created_at ASC 572 - LIMIT 1 573 - "#, 574 - did 575 - ) 576 - .fetch_one(&db) 577 - .await 578 - .context("failed to fetch user account")?; 579 - 580 - let token = auth::sign( 581 - &skey, 582 - "at+jwt", 583 - &serde_json::json!({ 584 - "scope": "com.atproto.access", 585 - "sub": did, 586 - "iat": chrono::Utc::now().timestamp(), 587 - "exp": chrono::Utc::now().checked_add_signed(chrono::Duration::hours(4)).context("should be valid time")?.timestamp(), 588 - "aud": format!("did:web:{}", config.host_name) 589 - }), 590 - ) 591 - .context("failed to sign jwt")?; 592 - 593 - let refresh_token = auth::sign( 594 - &skey, 595 - "refresh+jwt", 596 - &serde_json::json!({ 597 - "scope": "com.atproto.refresh", 598 - "sub": did, 599 - "iat": chrono::Utc::now().timestamp(), 600 - "exp": chrono::Utc::now().checked_add_days(chrono::Days::new(90)).context("should be valid time")?.timestamp(), 601 - "aud": format!("did:web:{}", config.host_name) 602 - }), 603 - ) 604 - .context("failed to sign refresh jwt")?; 605 - 606 - let active = user.status == "active"; 607 - let status = if active { None } else { Some(user.status) }; 608 - 609 - Ok(Json( 610 - server::refresh_session::OutputData { 611 - access_jwt: token, 612 - refresh_jwt: refresh_token, 613 - 614 - active: Some(active), // TODO? 615 - did: Did::new(did.to_owned()).expect("should be valid DID format"), 616 - did_doc: None, 617 - handle: Handle::new(user.handle).expect("should be valid handle"), 618 - status, 619 - } 620 - .into(), 621 - )) 622 - } 623 - 624 - /// Get a signed token on behalf of the requesting DID for the requested service. 625 - /// - GET /xrpc/com.atproto.server.getServiceAuth 626 - /// ### Request Query Parameters 627 - /// - `aud`: string - The DID of the service that the token will be used to authenticate with 628 - /// - `exp`: integer (optional) - The time in Unix Epoch seconds that the JWT expires. Defaults to 60 seconds in the future. The service may enforce certain time bounds on tokens depending on the requested scope. 629 - /// - `lxm`: string (optional) - Lexicon (XRPC) method to bind the requested token to 630 - /// ### Responses 631 - /// - 200 OK: {token: string} 632 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `BadExpiration`]} 633 - /// - 401 Unauthorized 634 - async fn get_service_auth( 635 - user: AuthenticatedUser, 636 - State(skey): State<SigningKey>, 637 - Query(input): Query<server::get_service_auth::ParametersData>, 638 - ) -> Result<Json<server::get_service_auth::Output>> { 639 - let user_did = user.did(); 640 - let aud = input.aud.as_str(); 641 - 642 - let exp = (chrono::Utc::now().checked_add_signed(chrono::Duration::minutes(1))) 643 - .context("should be valid expiration datetime")? 644 - .timestamp(); 645 - let jti = rand::thread_rng() 646 - .sample_iter(rand::distributions::Alphanumeric) 647 - .take(10) 648 - .map(char::from) 649 - .collect::<String>(); 650 - 651 - let mut claims = serde_json::json!({ 652 - "iss": user_did.as_str(), 653 - "aud": aud, 654 - "exp": exp, 655 - "jti": jti, 656 - }); 657 - 658 - if let Some(ref lxm) = input.lxm { 659 - claims = claims 660 - .as_object_mut() 661 - .context("should be a valid object")? 662 - .insert("lxm".to_owned(), serde_json::Value::String(lxm.to_string())) 663 - .context("should be able to insert lxm into claims")?; 664 - } 665 - 666 - // Mint a bearer token by signing a JSON web token. 667 - let token = auth::sign(&skey, "JWT", &claims).context("failed to sign jwt")?; 668 - 669 - Ok(Json(server::get_service_auth::OutputData { token }.into())) 670 - } 671 - 672 - /// Get information about the current auth session. Requires auth. 673 - /// - GET /xrpc/com.atproto.server.getSession 674 - /// ### Responses 675 - /// - 200 OK: {"handle": "string","did": "string","email": "string","emailConfirmed": true,"emailAuthFactor": true,"didDoc": {},"active": true,"status": "takendown"} 676 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 677 - /// - 401 Unauthorized 678 - async fn get_session( 679 - user: AuthenticatedUser, 680 - State(db): State<Db>, 681 - ) -> Result<Json<server::get_session::Output>> { 682 - let did = user.did(); 683 - #[expect(clippy::shadow_unrelated, reason = "is related")] 684 - if let Some(user) = sqlx::query!( 685 - r#" 686 - SELECT a.email, a.status, ( 687 - SELECT h.handle 688 - FROM handles h 689 - WHERE h.did = a.did 690 - ORDER BY h.created_at ASC 691 - LIMIT 1 692 - ) AS handle 693 - FROM accounts a 694 - WHERE a.did = ? 695 - "#, 696 - did 697 - ) 698 - .fetch_optional(&db) 699 - .await 700 - .context("failed to fetch session")? 701 - { 702 - let active = user.status == "active"; 703 - let status = if active { None } else { Some(user.status) }; 704 - 705 - Ok(Json( 706 - server::get_session::OutputData { 707 - active: Some(active), 708 - did: Did::from_str(&did).expect("should be valid DID format"), 709 - did_doc: None, 710 - email: Some(user.email), 711 - email_auth_factor: None, 712 - email_confirmed: None, 713 - handle: Handle::new(user.handle).expect("should be valid handle"), 714 - status, 715 - } 716 - .into(), 717 - )) 718 - } else { 719 - Err(Error::with_status( 720 - StatusCode::UNAUTHORIZED, 721 - anyhow!("user not found"), 722 - )) 723 - } 724 - } 725 - 726 - /// Describes the server's account creation requirements and capabilities. Implemented by PDS. 727 - /// - GET /xrpc/com.atproto.server.describeServer 728 - /// ### Responses 729 - /// - 200 OK: {"inviteCodeRequired": true,"phoneVerificationRequired": true,"availableUserDomains": [`string`],"links": {"privacyPolicy": "string","termsOfService": "string"},"contact": {"email": "string"},"did": "string"} 730 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 731 - /// - 401 Unauthorized 732 - async fn describe_server( 733 - State(config): State<AppConfig>, 734 - ) -> Result<Json<server::describe_server::Output>> { 735 - Ok(Json( 736 - server::describe_server::OutputData { 737 - available_user_domains: vec![], 738 - contact: None, 739 - did: Did::from_str(&format!("did:web:{}", config.host_name)) 740 - .expect("should be valid DID format"), 741 - invite_code_required: Some(true), 742 - links: None, 743 - phone_verification_required: Some(false), // email verification 744 - } 745 - .into(), 746 - )) 747 - } 748 - 749 - async fn todo() -> Result<()> { 750 - Err(Error::unimplemented(anyhow!("not implemented"))) 751 - } 752 - 753 - #[rustfmt::skip] 754 - /// These endpoints are part of the atproto PDS server and account management APIs. \ 755 - /// Requests often require authentication and are made directly to the user's own PDS instance. 756 - /// ### Routes 757 - /// - `POST /xrpc/com.atproto.server.createAccount` -> [`create_account`] 758 - /// - `POST /xrpc/com.atproto.server.createInviteCode` -> [`create_invite_code`] 759 - /// - `POST /xrpc/com.atproto.server.createSession` -> [`create_session`] 760 - /// - `GET /xrpc/com.atproto.server.describeServer` -> [`describe_server`] 761 - /// - `GET /xrpc/com.atproto.server.getServiceAuth` -> [`get_service_auth`] 762 - /// - `GET /xrpc/com.atproto.server.getSession` -> [`get_session`] 763 - /// - `POST /xrpc/com.atproto.server.refreshSession` -> [`refresh_session`] 764 - pub(super) fn routes() -> Router<AppState> { 765 - Router::new() 766 - .route(concat!("/", server::activate_account::NSID), post(todo)) 767 - .route(concat!("/", server::check_account_status::NSID), post(todo)) 768 - .route(concat!("/", server::confirm_email::NSID), post(todo)) 769 - .route(concat!("/", server::create_account::NSID), post(create_account)) 770 - .route(concat!("/", server::create_app_password::NSID), post(todo)) 771 - .route(concat!("/", server::create_invite_code::NSID), post(create_invite_code)) 772 - .route(concat!("/", server::create_invite_codes::NSID), post(todo)) 773 - .route(concat!("/", server::create_session::NSID), post(create_session)) 774 - .route(concat!("/", server::deactivate_account::NSID), post(todo)) 775 - .route(concat!("/", server::delete_account::NSID), post(todo)) 776 - .route(concat!("/", server::delete_session::NSID), post(todo)) 777 - .route(concat!("/", server::describe_server::NSID), get(describe_server)) 778 - .route(concat!("/", server::get_account_invite_codes::NSID), post(todo)) 779 - .route(concat!("/", server::get_service_auth::NSID), get(get_service_auth)) 780 - .route(concat!("/", server::get_session::NSID), get(get_session)) 781 - .route(concat!("/", server::list_app_passwords::NSID), post(todo)) 782 - .route(concat!("/", server::refresh_session::NSID), post(refresh_session)) 783 - .route(concat!("/", server::request_account_delete::NSID), post(todo)) 784 - .route(concat!("/", server::request_email_confirmation::NSID), post(todo)) 785 - .route(concat!("/", server::request_email_update::NSID), post(todo)) 786 - .route(concat!("/", server::request_password_reset::NSID), post(todo)) 787 - .route(concat!("/", server::reserve_signing_key::NSID), post(todo)) 788 - .route(concat!("/", server::reset_password::NSID), post(todo)) 789 - .route(concat!("/", server::revoke_app_password::NSID), post(todo)) 790 - .route(concat!("/", server::update_email::NSID), post(todo)) 791 - }
-428
src/endpoints/sync.rs
··· 1 - //! Endpoints for the `ATProto` sync API. (/xrpc/com.atproto.sync.*) 2 - use std::str::FromStr as _; 3 - 4 - use anyhow::{Context as _, anyhow}; 5 - use atrium_api::{ 6 - com::atproto::sync, 7 - types::{LimitedNonZeroU16, string::Did}, 8 - }; 9 - use atrium_repo::{ 10 - Cid, 11 - blockstore::{ 12 - AsyncBlockStoreRead as _, AsyncBlockStoreWrite as _, CarStore, DAG_CBOR, SHA2_256, 13 - }, 14 - }; 15 - use axum::{ 16 - Json, Router, 17 - body::Body, 18 - extract::{Query, State, WebSocketUpgrade}, 19 - http::{self, Response, StatusCode}, 20 - response::IntoResponse, 21 - routing::get, 22 - }; 23 - use constcat::concat; 24 - use futures::stream::TryStreamExt as _; 25 - use tokio_util::io::ReaderStream; 26 - 27 - use crate::{ 28 - AppState, Db, Error, Result, 29 - config::AppConfig, 30 - firehose::FirehoseProducer, 31 - storage::{open_repo_db, open_store}, 32 - }; 33 - 34 - #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)] 35 - #[serde(rename_all = "camelCase")] 36 - /// Parameters for `/xrpc/com.atproto.sync.listBlobs` \ 37 - /// HACK: `limit` may be passed as a string, so we must treat it as one. 38 - pub(super) struct ListBlobsParameters { 39 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 40 - /// Optional cursor to paginate through blobs. 41 - pub cursor: Option<String>, 42 - ///The DID of the repo. 43 - pub did: Did, 44 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 45 - /// Optional limit of blobs to return. 46 - pub limit: Option<String>, 47 - ///Optional revision of the repo to list blobs since. 48 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 49 - pub since: Option<String>, 50 - } 51 - #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)] 52 - #[serde(rename_all = "camelCase")] 53 - /// Parameters for `/xrpc/com.atproto.sync.listRepos` \ 54 - /// HACK: `limit` may be passed as a string, so we must treat it as one. 55 - pub(super) struct ListReposParameters { 56 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 57 - /// Optional cursor to paginate through repos. 58 - pub cursor: Option<String>, 59 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 60 - /// Optional limit of repos to return. 61 - pub limit: Option<String>, 62 - } 63 - #[derive(serde::Serialize, serde::Deserialize, Debug, Clone, PartialEq, Eq)] 64 - #[serde(rename_all = "camelCase")] 65 - /// Parameters for `/xrpc/com.atproto.sync.subscribeRepos` \ 66 - /// HACK: `cursor` may be passed as a string, so we must treat it as one. 67 - pub(super) struct SubscribeReposParametersData { 68 - ///The last known event seq number to backfill from. 69 - #[serde(skip_serializing_if = "core::option::Option::is_none")] 70 - pub cursor: Option<String>, 71 - } 72 - 73 - async fn get_blob( 74 - State(config): State<AppConfig>, 75 - Query(input): Query<sync::get_blob::ParametersData>, 76 - ) -> Result<Response<Body>> { 77 - let blob = config 78 - .blob 79 - .path 80 - .join(format!("{}.blob", input.cid.as_ref())); 81 - 82 - let f = tokio::fs::File::open(blob) 83 - .await 84 - .context("blob not found")?; 85 - let len = f 86 - .metadata() 87 - .await 88 - .context("failed to query file metadata")? 89 - .len(); 90 - 91 - let s = ReaderStream::new(f); 92 - 93 - Ok(Response::builder() 94 - .header(http::header::CONTENT_LENGTH, format!("{len}")) 95 - .body(Body::from_stream(s)) 96 - .context("failed to construct response")?) 97 - } 98 - 99 - /// Enumerates which accounts the requesting account is currently blocking. Requires auth. 100 - /// - GET /xrpc/com.atproto.sync.getBlocks 101 - /// ### Query Parameters 102 - /// - `limit`: integer, optional, default: 50, >=1 and <=100 103 - /// - `cursor`: string, optional 104 - /// ### Responses 105 - /// - 200 OK: ... 106 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 107 - /// - 401 Unauthorized 108 - async fn get_blocks( 109 - State(config): State<AppConfig>, 110 - Query(input): Query<sync::get_blocks::ParametersData>, 111 - ) -> Result<Response<Body>> { 112 - let mut repo = open_store(&config.repo, input.did.as_str()) 113 - .await 114 - .context("failed to open repository")?; 115 - 116 - let mut mem = Vec::new(); 117 - let mut store = CarStore::create(std::io::Cursor::new(&mut mem)) 118 - .await 119 - .context("failed to create intermediate carstore")?; 120 - 121 - for cid in &input.cids { 122 - // SEC: This can potentially fetch stale blocks from a repository (e.g. those that were deleted). 123 - // We'll want to prevent accesses to stale blocks eventually just to respect a user's right to be forgotten. 124 - _ = store 125 - .write_block( 126 - DAG_CBOR, 127 - SHA2_256, 128 - &repo 129 - .read_block(*cid.as_ref()) 130 - .await 131 - .context("failed to read block")?, 132 - ) 133 - .await 134 - .context("failed to write block")?; 135 - } 136 - 137 - Ok(Response::builder() 138 - .header(http::header::CONTENT_TYPE, "application/vnd.ipld.car") 139 - .body(Body::from(mem)) 140 - .context("failed to construct response")?) 141 - } 142 - 143 - /// Get the current commit CID & revision of the specified repo. Does not require auth. 144 - /// ### Query Parameters 145 - /// - `did`: The DID of the repo. 146 - /// ### Responses 147 - /// - 200 OK: {"cid": "string","rev": "string"} 148 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RepoTakendown`, `RepoSuspended`, `RepoDeactivated`]} 149 - async fn get_latest_commit( 150 - State(config): State<AppConfig>, 151 - State(db): State<Db>, 152 - Query(input): Query<sync::get_latest_commit::ParametersData>, 153 - ) -> Result<Json<sync::get_latest_commit::Output>> { 154 - let repo = open_repo_db(&config.repo, &db, input.did.as_str()) 155 - .await 156 - .context("failed to open repository")?; 157 - 158 - let cid = repo.root(); 159 - let commit = repo.commit(); 160 - 161 - Ok(Json( 162 - sync::get_latest_commit::OutputData { 163 - cid: atrium_api::types::string::Cid::new(cid), 164 - rev: commit.rev(), 165 - } 166 - .into(), 167 - )) 168 - } 169 - 170 - /// Get data blocks needed to prove the existence or non-existence of record in the current version of repo. Does not require auth. 171 - /// ### Query Parameters 172 - /// - `did`: The DID of the repo. 173 - /// - `collection`: nsid 174 - /// - `rkey`: record-key 175 - /// ### Responses 176 - /// - 200 OK: ... 177 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RecordNotFound`, `RepoNotFound`, `RepoTakendown`, 178 - /// `RepoSuspended`, `RepoDeactivated`]} 179 - async fn get_record( 180 - State(config): State<AppConfig>, 181 - State(db): State<Db>, 182 - Query(input): Query<sync::get_record::ParametersData>, 183 - ) -> Result<Response<Body>> { 184 - let mut repo = open_repo_db(&config.repo, &db, input.did.as_str()) 185 - .await 186 - .context("failed to open repo")?; 187 - 188 - let key = format!("{}/{}", input.collection.as_str(), input.rkey.as_str()); 189 - 190 - let mut contents = Vec::new(); 191 - let mut ret_store = 192 - CarStore::create_with_roots(std::io::Cursor::new(&mut contents), [repo.root()]) 193 - .await 194 - .context("failed to create car store")?; 195 - 196 - repo.extract_raw_into(&key, &mut ret_store) 197 - .await 198 - .context("failed to extract records")?; 199 - 200 - Ok(Response::builder() 201 - .header(http::header::CONTENT_TYPE, "application/vnd.ipld.car") 202 - .body(Body::from(contents)) 203 - .context("failed to construct response")?) 204 - } 205 - 206 - /// Get the hosting status for a repository, on this server. Expected to be implemented by PDS and Relay. 207 - /// ### Query Parameters 208 - /// - `did`: The DID of the repo. 209 - /// ### Responses 210 - /// - 200 OK: {"did": "string","active": true,"status": "takendown","rev": "string"} 211 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RepoNotFound`]} 212 - async fn get_repo_status( 213 - State(db): State<Db>, 214 - Query(input): Query<sync::get_repo::ParametersData>, 215 - ) -> Result<Json<sync::get_repo_status::Output>> { 216 - let did = input.did.as_str(); 217 - let r = sqlx::query!(r#"SELECT rev, status FROM accounts WHERE did = ?"#, did) 218 - .fetch_optional(&db) 219 - .await 220 - .context("failed to execute query")?; 221 - 222 - let Some(r) = r else { 223 - return Err(Error::with_status( 224 - StatusCode::NOT_FOUND, 225 - anyhow!("account not found"), 226 - )); 227 - }; 228 - 229 - let active = r.status == "active"; 230 - let status = if active { None } else { Some(r.status) }; 231 - 232 - Ok(Json( 233 - sync::get_repo_status::OutputData { 234 - active, 235 - status, 236 - did: input.did.clone(), 237 - rev: Some( 238 - atrium_api::types::string::Tid::new(r.rev).expect("should be able to convert Tid"), 239 - ), 240 - } 241 - .into(), 242 - )) 243 - } 244 - 245 - /// Download a repository export as CAR file. Optionally only a 'diff' since a previous revision. 246 - /// Does not require auth; implemented by PDS. 247 - /// ### Query Parameters 248 - /// - `did`: The DID of the repo. 249 - /// - `since`: The revision ('rev') of the repo to create a diff from. 250 - /// ### Responses 251 - /// - 200 OK: ... 252 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RepoNotFound`, 253 - /// `RepoTakendown`, `RepoSuspended`, `RepoDeactivated`]} 254 - async fn get_repo( 255 - State(config): State<AppConfig>, 256 - State(db): State<Db>, 257 - Query(input): Query<sync::get_repo::ParametersData>, 258 - ) -> Result<Response<Body>> { 259 - let mut repo = open_repo_db(&config.repo, &db, input.did.as_str()) 260 - .await 261 - .context("failed to open repo")?; 262 - 263 - let mut contents = Vec::new(); 264 - let mut store = CarStore::create_with_roots(std::io::Cursor::new(&mut contents), [repo.root()]) 265 - .await 266 - .context("failed to create car store")?; 267 - 268 - repo.export_into(&mut store) 269 - .await 270 - .context("failed to extract records")?; 271 - 272 - Ok(Response::builder() 273 - .header(http::header::CONTENT_TYPE, "application/vnd.ipld.car") 274 - .body(Body::from(contents)) 275 - .context("failed to construct response")?) 276 - } 277 - 278 - /// List blob CIDs for an account, since some repo revision. Does not require auth; implemented by PDS. 279 - /// ### Query Parameters 280 - /// - `did`: The DID of the repo. Required. 281 - /// - `since`: Optional revision of the repo to list blobs since. 282 - /// - `limit`: >= 1 and <= 1000, default 500 283 - /// - `cursor`: string 284 - /// ### Responses 285 - /// - 200 OK: {"cursor": "string","cids": [string]} 286 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`, `RepoNotFound`, `RepoTakendown`, 287 - /// `RepoSuspended`, `RepoDeactivated`]} 288 - async fn list_blobs( 289 - State(db): State<Db>, 290 - Query(input): Query<sync::list_blobs::ParametersData>, 291 - ) -> Result<Json<sync::list_blobs::Output>> { 292 - let did_str = input.did.as_str(); 293 - 294 - // TODO: `input.since` 295 - // TODO: `input.limit` 296 - // TODO: `input.cursor` 297 - 298 - let cids = sqlx::query_scalar!(r#"SELECT cid FROM blob_ref WHERE did = ?"#, did_str) 299 - .fetch_all(&db) 300 - .await 301 - .context("failed to query blobs")?; 302 - 303 - let cids = cids 304 - .into_iter() 305 - .map(|c| { 306 - Cid::from_str(&c) 307 - .map(atrium_api::types::string::Cid::new) 308 - .map_err(anyhow::Error::new) 309 - }) 310 - .collect::<anyhow::Result<Vec<_>>>() 311 - .context("failed to convert cids")?; 312 - 313 - Ok(Json( 314 - sync::list_blobs::OutputData { cursor: None, cids }.into(), 315 - )) 316 - } 317 - 318 - /// Enumerates all the DID, rev, and commit CID for all repos hosted by this service. 319 - /// Does not require auth; implemented by PDS and Relay. 320 - /// ### Query Parameters 321 - /// - `limit`: >= 1 and <= 1000, default 500 322 - /// - `cursor`: string 323 - /// ### Responses 324 - /// - 200 OK: {"cursor": "string","repos": [{"did": "string","head": "string","rev": "string","active": true,"status": "takendown"}]} 325 - /// - 400 Bad Request: {error:[`InvalidRequest`, `ExpiredToken`, `InvalidToken`]} 326 - async fn list_repos( 327 - State(db): State<Db>, 328 - Query(input): Query<sync::list_repos::ParametersData>, 329 - ) -> Result<Json<sync::list_repos::Output>> { 330 - struct Record { 331 - /// The DID of the repo. 332 - did: String, 333 - /// The commit CID of the repo. 334 - rev: String, 335 - /// The root CID of the repo. 336 - root: String, 337 - } 338 - 339 - let limit: u16 = input.limit.unwrap_or(LimitedNonZeroU16::MAX).into(); 340 - 341 - let r = if let Some(ref cursor) = input.cursor { 342 - let r = sqlx::query_as!( 343 - Record, 344 - r#"SELECT did, root, rev FROM accounts WHERE did > ? LIMIT ?"#, 345 - cursor, 346 - limit 347 - ) 348 - .fetch(&db); 349 - 350 - r.try_collect::<Vec<_>>() 351 - .await 352 - .context("failed to fetch profiles")? 353 - } else { 354 - let r = sqlx::query_as!( 355 - Record, 356 - r#"SELECT did, root, rev FROM accounts LIMIT ?"#, 357 - limit 358 - ) 359 - .fetch(&db); 360 - 361 - r.try_collect::<Vec<_>>() 362 - .await 363 - .context("failed to fetch profiles")? 364 - }; 365 - 366 - let cursor = r.last().map(|r| r.did.clone()); 367 - let repos = r 368 - .into_iter() 369 - .map(|r| { 370 - sync::list_repos::RepoData { 371 - active: Some(true), 372 - did: Did::new(r.did).expect("should be a valid DID"), 373 - head: atrium_api::types::string::Cid::new( 374 - Cid::from_str(&r.root).expect("should be a valid CID"), 375 - ), 376 - rev: atrium_api::types::string::Tid::new(r.rev) 377 - .expect("should be able to convert Tid"), 378 - status: None, 379 - } 380 - .into() 381 - }) 382 - .collect::<Vec<_>>(); 383 - 384 - Ok(Json(sync::list_repos::OutputData { cursor, repos }.into())) 385 - } 386 - 387 - /// Repository event stream, aka Firehose endpoint. Outputs repo commits with diff data, and identity update events, 388 - /// for all repositories on the current server. See the atproto specifications for details around stream sequencing, 389 - /// repo versioning, CAR diff format, and more. Public and does not require auth; implemented by PDS and Relay. 390 - /// ### Query Parameters 391 - /// - `cursor`: The last known event seq number to backfill from. 392 - /// ### Responses 393 - /// - 200 OK: ... 394 - async fn subscribe_repos( 395 - ws_up: WebSocketUpgrade, 396 - State(fh): State<FirehoseProducer>, 397 - Query(input): Query<sync::subscribe_repos::ParametersData>, 398 - ) -> impl IntoResponse { 399 - ws_up.on_upgrade(async move |ws| { 400 - fh.client_connection(ws, input.cursor).await; 401 - }) 402 - } 403 - 404 - #[rustfmt::skip] 405 - /// These endpoints are part of the atproto repository synchronization APIs. Requests usually do not require authentication, 406 - /// and can be made to PDS intances or Relay instances. 407 - /// ### Routes 408 - /// - `GET /xrpc/com.atproto.sync.getBlob` -> [`get_blob`] 409 - /// - `GET /xrpc/com.atproto.sync.getBlocks` -> [`get_blocks`] 410 - /// - `GET /xrpc/com.atproto.sync.getLatestCommit` -> [`get_latest_commit`] 411 - /// - `GET /xrpc/com.atproto.sync.getRecord` -> [`get_record`] 412 - /// - `GET /xrpc/com.atproto.sync.getRepoStatus` -> [`get_repo_status`] 413 - /// - `GET /xrpc/com.atproto.sync.getRepo` -> [`get_repo`] 414 - /// - `GET /xrpc/com.atproto.sync.listBlobs` -> [`list_blobs`] 415 - /// - `GET /xrpc/com.atproto.sync.listRepos` -> [`list_repos`] 416 - /// - `GET /xrpc/com.atproto.sync.subscribeRepos` -> [`subscribe_repos`] 417 - pub(super) fn routes() -> Router<AppState> { 418 - Router::new() 419 - .route(concat!("/", sync::get_blob::NSID), get(get_blob)) 420 - .route(concat!("/", sync::get_blocks::NSID), get(get_blocks)) 421 - .route(concat!("/", sync::get_latest_commit::NSID), get(get_latest_commit)) 422 - .route(concat!("/", sync::get_record::NSID), get(get_record)) 423 - .route(concat!("/", sync::get_repo_status::NSID), get(get_repo_status)) 424 - .route(concat!("/", sync::get_repo::NSID), get(get_repo)) 425 - .route(concat!("/", sync::list_blobs::NSID), get(list_blobs)) 426 - .route(concat!("/", sync::list_repos::NSID), get(list_repos)) 427 - .route(concat!("/", sync::subscribe_repos::NSID), get(subscribe_repos)) 428 - }
+151
src/error.rs
··· 4 4 http::StatusCode, 5 5 response::{IntoResponse, Response}, 6 6 }; 7 + use rsky_pds::handle::{self, errors::ErrorKind}; 7 8 use thiserror::Error; 8 9 use tracing::error; 9 10 ··· 118 119 } 119 120 } 120 121 } 122 + 123 + /// API error types that can be returned to clients 124 + #[derive(Clone, Debug)] 125 + pub enum ApiError { 126 + RuntimeError, 127 + InvalidLogin, 128 + AccountTakendown, 129 + InvalidRequest(String), 130 + ExpiredToken, 131 + InvalidToken, 132 + RecordNotFound, 133 + InvalidHandle, 134 + InvalidEmail, 135 + InvalidPassword, 136 + InvalidInviteCode, 137 + HandleNotAvailable, 138 + EmailNotAvailable, 139 + UnsupportedDomain, 140 + UnresolvableDid, 141 + IncompatibleDidDoc, 142 + WellKnownNotFound, 143 + AccountNotFound, 144 + BlobNotFound, 145 + BadRequest(String, String), 146 + AuthRequiredError(String), 147 + } 148 + 149 + impl ApiError { 150 + /// Get the appropriate HTTP status code for this error 151 + const fn status_code(&self) -> StatusCode { 152 + match self { 153 + Self::RuntimeError => StatusCode::INTERNAL_SERVER_ERROR, 154 + Self::InvalidLogin 155 + | Self::ExpiredToken 156 + | Self::InvalidToken 157 + | Self::AuthRequiredError(_) => StatusCode::UNAUTHORIZED, 158 + Self::AccountTakendown => StatusCode::FORBIDDEN, 159 + Self::RecordNotFound 160 + | Self::WellKnownNotFound 161 + | Self::AccountNotFound 162 + | Self::BlobNotFound => StatusCode::NOT_FOUND, 163 + // All bad requests grouped together 164 + _ => StatusCode::BAD_REQUEST, 165 + } 166 + } 167 + 168 + /// Get the error type string for API responses 169 + fn error_type(&self) -> String { 170 + match self { 171 + Self::RuntimeError => "InternalServerError", 172 + Self::InvalidLogin => "InvalidLogin", 173 + Self::AccountTakendown => "AccountTakendown", 174 + Self::InvalidRequest(_) => "InvalidRequest", 175 + Self::ExpiredToken => "ExpiredToken", 176 + Self::InvalidToken => "InvalidToken", 177 + Self::RecordNotFound => "RecordNotFound", 178 + Self::InvalidHandle => "InvalidHandle", 179 + Self::InvalidEmail => "InvalidEmail", 180 + Self::InvalidPassword => "InvalidPassword", 181 + Self::InvalidInviteCode => "InvalidInviteCode", 182 + Self::HandleNotAvailable => "HandleNotAvailable", 183 + Self::EmailNotAvailable => "EmailNotAvailable", 184 + Self::UnsupportedDomain => "UnsupportedDomain", 185 + Self::UnresolvableDid => "UnresolvableDid", 186 + Self::IncompatibleDidDoc => "IncompatibleDidDoc", 187 + Self::WellKnownNotFound => "WellKnownNotFound", 188 + Self::AccountNotFound => "AccountNotFound", 189 + Self::BlobNotFound => "BlobNotFound", 190 + Self::BadRequest(error, _) => error, 191 + Self::AuthRequiredError(_) => "AuthRequiredError", 192 + } 193 + .to_owned() 194 + } 195 + 196 + /// Get the user-facing error message 197 + fn message(&self) -> String { 198 + match self { 199 + Self::RuntimeError => "Something went wrong", 200 + Self::InvalidLogin => "Invalid identifier or password", 201 + Self::AccountTakendown => "Account has been taken down", 202 + Self::InvalidRequest(msg) => msg, 203 + Self::ExpiredToken => "Token is expired", 204 + Self::InvalidToken => "Token is invalid", 205 + Self::RecordNotFound => "Record could not be found", 206 + Self::InvalidHandle => "Handle is invalid", 207 + Self::InvalidEmail => "Invalid email", 208 + Self::InvalidPassword => "Invalid Password", 209 + Self::InvalidInviteCode => "Invalid invite code", 210 + Self::HandleNotAvailable => "Handle not available", 211 + Self::EmailNotAvailable => "Email not available", 212 + Self::UnsupportedDomain => "Unsupported domain", 213 + Self::UnresolvableDid => "Unresolved Did", 214 + Self::IncompatibleDidDoc => "IncompatibleDidDoc", 215 + Self::WellKnownNotFound => "User not found", 216 + Self::AccountNotFound => "Account could not be found", 217 + Self::BlobNotFound => "Blob could not be found", 218 + Self::BadRequest(_, msg) => msg, 219 + Self::AuthRequiredError(msg) => msg, 220 + } 221 + .to_owned() 222 + } 223 + } 224 + 225 + impl From<Error> for ApiError { 226 + fn from(_value: Error) -> Self { 227 + Self::RuntimeError 228 + } 229 + } 230 + 231 + impl From<anyhow::Error> for ApiError { 232 + fn from(_value: anyhow::Error) -> Self { 233 + Self::RuntimeError 234 + } 235 + } 236 + 237 + impl From<handle::errors::Error> for ApiError { 238 + fn from(value: handle::errors::Error) -> Self { 239 + match value.kind { 240 + ErrorKind::InvalidHandle => Self::InvalidHandle, 241 + ErrorKind::HandleNotAvailable => Self::HandleNotAvailable, 242 + ErrorKind::UnsupportedDomain => Self::UnsupportedDomain, 243 + ErrorKind::InternalError => Self::RuntimeError, 244 + } 245 + } 246 + } 247 + 248 + impl IntoResponse for ApiError { 249 + fn into_response(self) -> Response { 250 + let status = self.status_code(); 251 + let error_type = self.error_type(); 252 + let message = self.message(); 253 + 254 + if cfg!(debug_assertions) { 255 + error!("API Error: {}: {}", error_type, message); 256 + } 257 + 258 + // Create the error message and serialize to JSON 259 + let error_message = ErrorMessage::new(error_type, message); 260 + let body = serde_json::to_string(&error_message).unwrap_or_else(|_| { 261 + r#"{"error":"InternalServerError","message":"Error serializing response"}"#.to_owned() 262 + }); 263 + 264 + // Build the response 265 + Response::builder() 266 + .status(status) 267 + .header("Content-Type", "application/json") 268 + .body(Body::new(body)) 269 + .expect("should be a valid response") 270 + } 271 + }
-426
src/firehose.rs
··· 1 - //! The firehose module. 2 - use std::{collections::VecDeque, time::Duration}; 3 - 4 - use anyhow::{Result, bail}; 5 - use atrium_api::{ 6 - com::atproto::sync::{self}, 7 - types::string::{Datetime, Did, Tid}, 8 - }; 9 - use atrium_repo::Cid; 10 - use axum::extract::ws::{Message, WebSocket}; 11 - use metrics::{counter, gauge}; 12 - use rand::Rng as _; 13 - use serde::{Serialize, ser::SerializeMap as _}; 14 - use tracing::{debug, error, info, warn}; 15 - 16 - use crate::{ 17 - Client, 18 - config::AppConfig, 19 - metrics::{FIREHOSE_HISTORY, FIREHOSE_LISTENERS, FIREHOSE_MESSAGES, FIREHOSE_SEQUENCE}, 20 - }; 21 - 22 - enum FirehoseMessage { 23 - Broadcast(sync::subscribe_repos::Message), 24 - Connect(Box<(WebSocket, Option<i64>)>), 25 - } 26 - 27 - enum FrameHeader { 28 - Error, 29 - Message(String), 30 - } 31 - 32 - impl Serialize for FrameHeader { 33 - #[expect(clippy::question_mark_used, reason = "returns a Result")] 34 - fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error> 35 - where 36 - S: serde::Serializer, 37 - { 38 - let mut map = serializer.serialize_map(None)?; 39 - 40 - match *self { 41 - Self::Message(ref s) => { 42 - map.serialize_key("op")?; 43 - map.serialize_value(&1_i32)?; 44 - map.serialize_key("t")?; 45 - map.serialize_value(s.as_str())?; 46 - } 47 - Self::Error => { 48 - map.serialize_key("op")?; 49 - map.serialize_value(&-1_i32)?; 50 - } 51 - } 52 - 53 - map.end() 54 - } 55 - } 56 - 57 - /// A repository operation. 58 - pub(crate) enum RepoOp { 59 - /// Create a new record. 60 - Create { 61 - /// The CID of the record. 62 - cid: Cid, 63 - /// The path of the record. 64 - path: String, 65 - }, 66 - /// Delete an existing record. 67 - Delete { 68 - /// The path of the record. 69 - path: String, 70 - /// The previous CID of the record. 71 - prev: Cid, 72 - }, 73 - /// Update an existing record. 74 - Update { 75 - /// The CID of the record. 76 - cid: Cid, 77 - /// The path of the record. 78 - path: String, 79 - /// The previous CID of the record. 80 - prev: Cid, 81 - }, 82 - } 83 - 84 - impl From<RepoOp> for sync::subscribe_repos::RepoOp { 85 - fn from(val: RepoOp) -> Self { 86 - let (action, cid, prev, path) = match val { 87 - RepoOp::Create { cid, path } => ("create", Some(cid), None, path), 88 - RepoOp::Update { cid, path, prev } => ("update", Some(cid), Some(prev), path), 89 - RepoOp::Delete { path, prev } => ("delete", None, Some(prev), path), 90 - }; 91 - 92 - sync::subscribe_repos::RepoOpData { 93 - action: action.to_owned(), 94 - cid: cid.map(atrium_api::types::CidLink), 95 - prev: prev.map(atrium_api::types::CidLink), 96 - path, 97 - } 98 - .into() 99 - } 100 - } 101 - 102 - /// A commit to the repository. 103 - pub(crate) struct Commit { 104 - /// Blobs that were created in this commit. 105 - pub blobs: Vec<Cid>, 106 - /// The car file containing the commit blocks. 107 - pub car: Vec<u8>, 108 - /// The CID of the commit. 109 - pub cid: Cid, 110 - /// The DID of the repository changed. 111 - pub did: Did, 112 - /// The operations performed in this commit. 113 - pub ops: Vec<RepoOp>, 114 - /// The previous commit's CID (if applicable). 115 - pub pcid: Option<Cid>, 116 - /// The revision of the commit. 117 - pub rev: String, 118 - } 119 - 120 - impl From<Commit> for sync::subscribe_repos::Commit { 121 - fn from(val: Commit) -> Self { 122 - sync::subscribe_repos::CommitData { 123 - blobs: val 124 - .blobs 125 - .into_iter() 126 - .map(atrium_api::types::CidLink) 127 - .collect::<Vec<_>>(), 128 - blocks: val.car, 129 - commit: atrium_api::types::CidLink(val.cid), 130 - ops: val.ops.into_iter().map(Into::into).collect::<Vec<_>>(), 131 - prev_data: val.pcid.map(atrium_api::types::CidLink), 132 - rebase: false, 133 - repo: val.did, 134 - rev: Tid::new(val.rev).expect("should be valid revision"), 135 - seq: 0, 136 - since: None, 137 - time: Datetime::now(), 138 - too_big: false, 139 - } 140 - .into() 141 - } 142 - } 143 - 144 - /// A firehose producer. This is used to transmit messages to the firehose for broadcast. 145 - #[derive(Clone, Debug)] 146 - pub(crate) struct FirehoseProducer { 147 - /// The channel to send messages to the firehose. 148 - tx: tokio::sync::mpsc::Sender<FirehoseMessage>, 149 - } 150 - 151 - impl FirehoseProducer { 152 - /// Broadcast an `#account` event. 153 - pub(crate) async fn account(&self, account: impl Into<sync::subscribe_repos::Account>) { 154 - drop( 155 - self.tx 156 - .send(FirehoseMessage::Broadcast( 157 - sync::subscribe_repos::Message::Account(Box::new(account.into())), 158 - )) 159 - .await, 160 - ); 161 - } 162 - /// Handle client connection. 163 - pub(crate) async fn client_connection(&self, ws: WebSocket, cursor: Option<i64>) { 164 - drop( 165 - self.tx 166 - .send(FirehoseMessage::Connect(Box::new((ws, cursor)))) 167 - .await, 168 - ); 169 - } 170 - /// Broadcast a `#commit` event. 171 - pub(crate) async fn commit(&self, commit: impl Into<sync::subscribe_repos::Commit>) { 172 - drop( 173 - self.tx 174 - .send(FirehoseMessage::Broadcast( 175 - sync::subscribe_repos::Message::Commit(Box::new(commit.into())), 176 - )) 177 - .await, 178 - ); 179 - } 180 - /// Broadcast an `#identity` event. 181 - pub(crate) async fn identity(&self, identity: impl Into<sync::subscribe_repos::Identity>) { 182 - drop( 183 - self.tx 184 - .send(FirehoseMessage::Broadcast( 185 - sync::subscribe_repos::Message::Identity(Box::new(identity.into())), 186 - )) 187 - .await, 188 - ); 189 - } 190 - } 191 - 192 - #[expect( 193 - clippy::as_conversions, 194 - clippy::cast_possible_truncation, 195 - clippy::cast_sign_loss, 196 - clippy::cast_precision_loss, 197 - clippy::arithmetic_side_effects 198 - )] 199 - /// Convert a `usize` to a `f64`. 200 - const fn convert_usize_f64(x: usize) -> Result<f64, &'static str> { 201 - let result = x as f64; 202 - if result as usize - x > 0 { 203 - return Err("cannot convert"); 204 - } 205 - Ok(result) 206 - } 207 - 208 - /// Serialize a message. 209 - fn serialize_message(seq: u64, mut msg: sync::subscribe_repos::Message) -> (&'static str, Vec<u8>) { 210 - let mut dummy_seq = 0_i64; 211 - #[expect(clippy::pattern_type_mismatch)] 212 - let (ty, nseq) = match &mut msg { 213 - sync::subscribe_repos::Message::Account(m) => ("#account", &mut m.seq), 214 - sync::subscribe_repos::Message::Commit(m) => ("#commit", &mut m.seq), 215 - sync::subscribe_repos::Message::Identity(m) => ("#identity", &mut m.seq), 216 - sync::subscribe_repos::Message::Sync(m) => ("#sync", &mut m.seq), 217 - sync::subscribe_repos::Message::Info(_m) => ("#info", &mut dummy_seq), 218 - }; 219 - // Set the sequence number. 220 - *nseq = i64::try_from(seq).expect("should find seq"); 221 - 222 - let hdr = FrameHeader::Message(ty.to_owned()); 223 - 224 - let mut frame = Vec::new(); 225 - serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header"); 226 - serde_ipld_dagcbor::to_writer(&mut frame, &msg).expect("should serialize message"); 227 - 228 - (ty, frame) 229 - } 230 - 231 - /// Broadcast a message out to all clients. 232 - async fn broadcast_message(clients: &mut Vec<WebSocket>, msg: Message) -> Result<()> { 233 - counter!(FIREHOSE_MESSAGES).increment(1); 234 - 235 - for i in (0..clients.len()).rev() { 236 - let client = clients.get_mut(i).expect("should find client"); 237 - if let Err(e) = client.send(msg.clone()).await { 238 - debug!("Firehose client disconnected: {e}"); 239 - drop(clients.remove(i)); 240 - } 241 - } 242 - 243 - gauge!(FIREHOSE_LISTENERS) 244 - .set(convert_usize_f64(clients.len()).expect("should find clients length")); 245 - Ok(()) 246 - } 247 - 248 - /// Handle a new connection from a websocket client created by subscribeRepos. 249 - async fn handle_connect( 250 - mut ws: WebSocket, 251 - seq: u64, 252 - history: &VecDeque<(u64, &str, sync::subscribe_repos::Message)>, 253 - cursor: Option<i64>, 254 - ) -> Result<WebSocket> { 255 - if let Some(cursor) = cursor { 256 - let mut frame = Vec::new(); 257 - let cursor = u64::try_from(cursor); 258 - if cursor.is_err() { 259 - tracing::warn!("cursor is not a valid u64"); 260 - return Ok(ws); 261 - } 262 - let cursor = cursor.expect("should be valid u64"); 263 - // Cursor specified; attempt to backfill the consumer. 264 - if cursor > seq { 265 - let hdr = FrameHeader::Error; 266 - let msg = sync::subscribe_repos::Error::FutureCursor(Some(format!( 267 - "cursor {cursor} is greater than the current sequence number {seq}" 268 - ))); 269 - serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header"); 270 - serde_ipld_dagcbor::to_writer(&mut frame, &msg).expect("should serialize message"); 271 - // Drop the connection. 272 - drop(ws.send(Message::binary(frame)).await); 273 - bail!( 274 - "connection dropped: cursor {cursor} is greater than the current sequence number {seq}" 275 - ); 276 - } 277 - 278 - for &(historical_seq, ty, ref msg) in history { 279 - if cursor > historical_seq { 280 - continue; 281 - } 282 - let hdr = FrameHeader::Message(ty.to_owned()); 283 - serde_ipld_dagcbor::to_writer(&mut frame, &hdr).expect("should serialize header"); 284 - serde_ipld_dagcbor::to_writer(&mut frame, msg).expect("should serialize message"); 285 - if let Err(e) = ws.send(Message::binary(frame.clone())).await { 286 - debug!("Firehose client disconnected during backfill: {e}"); 287 - break; 288 - } 289 - // Clear out the frame to begin a new one. 290 - frame.clear(); 291 - } 292 - } 293 - 294 - Ok(ws) 295 - } 296 - 297 - /// Reconnect to upstream relays. 298 - pub(crate) async fn reconnect_relays(client: &Client, config: &AppConfig) { 299 - // Avoid connecting to upstream relays in test mode. 300 - if config.test { 301 - return; 302 - } 303 - 304 - info!("attempting to reconnect to upstream relays"); 305 - for relay in &config.firehose.relays { 306 - let Some(host) = relay.host_str() else { 307 - warn!("relay {} has no host specified", relay); 308 - continue; 309 - }; 310 - 311 - let r = client 312 - .post(format!("https://{host}/xrpc/com.atproto.sync.requestCrawl")) 313 - .json(&serde_json::json!({ 314 - "hostname": format!("https://{}", config.host_name) 315 - })) 316 - .send() 317 - .await; 318 - 319 - let r = match r { 320 - Ok(r) => r, 321 - Err(e) => { 322 - error!("failed to hit upstream relay {host}: {e}"); 323 - continue; 324 - } 325 - }; 326 - 327 - let s = r.status(); 328 - if let Err(e) = r.error_for_status_ref() { 329 - error!("failed to hit upstream relay {host}: {e}"); 330 - } 331 - 332 - let b = r.json::<serde_json::Value>().await; 333 - if let Ok(b) = b { 334 - info!("relay {host}: {} {}", s, b); 335 - } else { 336 - info!("relay {host}: {}", s); 337 - } 338 - } 339 - } 340 - 341 - /// The main entrypoint for the firehose. 342 - /// 343 - /// This will broadcast all updates in this PDS out to anyone who is listening. 344 - /// 345 - /// Reference: <https://atproto.com/specs/sync> 346 - pub(crate) fn spawn( 347 - client: Client, 348 - config: AppConfig, 349 - ) -> (tokio::task::JoinHandle<()>, FirehoseProducer) { 350 - let (tx, mut rx) = tokio::sync::mpsc::channel(1000); 351 - let handle = tokio::spawn(async move { 352 - fn time_since_inception() -> u64 { 353 - chrono::Utc::now() 354 - .timestamp_micros() 355 - .checked_sub(1_743_442_000_000_000) 356 - .expect("should not wrap") 357 - .unsigned_abs() 358 - } 359 - let mut clients: Vec<WebSocket> = Vec::new(); 360 - let mut history = VecDeque::with_capacity(1000); 361 - let mut seq = time_since_inception(); 362 - 363 - loop { 364 - if let Ok(msg) = tokio::time::timeout(Duration::from_secs(30), rx.recv()).await { 365 - match msg { 366 - Some(FirehoseMessage::Broadcast(msg)) => { 367 - let (ty, by) = serialize_message(seq, msg.clone()); 368 - 369 - history.push_back((seq, ty, msg)); 370 - gauge!(FIREHOSE_HISTORY).set( 371 - convert_usize_f64(history.len()).expect("should find history length"), 372 - ); 373 - 374 - info!( 375 - "Broadcasting message {} {} to {} clients", 376 - seq, 377 - ty, 378 - clients.len() 379 - ); 380 - 381 - counter!(FIREHOSE_SEQUENCE).absolute(seq); 382 - let now = time_since_inception(); 383 - if now > seq { 384 - seq = now; 385 - } else { 386 - seq = seq.checked_add(1).expect("should not wrap"); 387 - } 388 - 389 - drop(broadcast_message(&mut clients, Message::binary(by)).await); 390 - } 391 - Some(FirehoseMessage::Connect(ws_cursor)) => { 392 - let (ws, cursor) = *ws_cursor; 393 - match handle_connect(ws, seq, &history, cursor).await { 394 - Ok(r) => { 395 - gauge!(FIREHOSE_LISTENERS).increment(1_i32); 396 - clients.push(r); 397 - } 398 - Err(e) => { 399 - error!("failed to connect new client: {e}"); 400 - } 401 - } 402 - } 403 - // All producers have been destroyed. 404 - None => break, 405 - } 406 - } else { 407 - if clients.is_empty() { 408 - reconnect_relays(&client, &config).await; 409 - } 410 - 411 - let contents = rand::thread_rng() 412 - .sample_iter(rand::distributions::Alphanumeric) 413 - .take(15) 414 - .map(char::from) 415 - .collect::<String>(); 416 - 417 - // Send a websocket ping message. 418 - // Reference: https://developer.mozilla.org/en-US/docs/Web/API/WebSockets_API/Writing_WebSocket_servers#pings_and_pongs_the_heartbeat_of_websockets 419 - let message = Message::Ping(axum::body::Bytes::from_owner(contents)); 420 - drop(broadcast_message(&mut clients, message).await); 421 - } 422 - } 423 - }); 424 - 425 - (handle, FirehoseProducer { tx }) 426 - }
+42
src/lib.rs
··· 1 + //! PDS implementation. 2 + mod account_manager; 3 + mod actor_endpoints; 4 + mod actor_store; 5 + mod apis; 6 + mod auth; 7 + mod config; 8 + mod db; 9 + mod did; 10 + pub mod error; 11 + mod metrics; 12 + mod models; 13 + mod oauth; 14 + mod pipethrough; 15 + mod schema; 16 + mod serve; 17 + mod service_proxy; 18 + 19 + pub use serve::run; 20 + 21 + /// The index (/) route. 22 + async fn index() -> impl axum::response::IntoResponse { 23 + r" 24 + __ __ 25 + /\ \__ /\ \__ 26 + __ \ \ ,_\ _____ _ __ ___\ \ ,_\ ___ 27 + /'__'\ \ \ \/ /\ '__'\/\''__\/ __'\ \ \/ / __'\ 28 + /\ \L\.\_\ \ \_\ \ \L\ \ \ \//\ \L\ \ \ \_/\ \L\ \ 29 + \ \__/.\_\\ \__\\ \ ,__/\ \_\\ \____/\ \__\ \____/ 30 + \/__/\/_/ \/__/ \ \ \/ \/_/ \/___/ \/__/\/___/ 31 + \ \_\ 32 + \/_/ 33 + 34 + 35 + This is an AT Protocol Personal Data Server (aka, an atproto PDS) 36 + 37 + Most API routes are under /xrpc/ 38 + 39 + Code: https://github.com/DrChat/bluepds 40 + Protocol: https://atproto.com 41 + " 42 + }
+3 -558
src/main.rs
··· 1 - //! PDS implementation. 2 - mod account_manager; 3 - mod actor_store; 4 - mod auth; 5 - mod config; 6 - mod db; 7 - mod did; 8 - mod endpoints; 9 - mod error; 10 - mod firehose; 11 - mod metrics; 12 - mod mmap; 13 - mod oauth; 14 - mod plc; 15 - mod schema; 16 - #[cfg(test)] 17 - mod tests; 18 - 19 - /// HACK: store private user preferences in the PDS. 20 - /// 21 - /// We shouldn't have to know about any bsky endpoints to store private user data. 22 - /// This will _very likely_ be changed in the future. 23 - mod actor_endpoints; 1 + //! BluePDS binary entry point. 24 2 25 - use anyhow::{Context as _, anyhow}; 26 - use atrium_api::types::string::Did; 27 - use atrium_crypto::keypair::{Export as _, Secp256k1Keypair}; 28 - use auth::AuthenticatedUser; 29 - use axum::{ 30 - Router, 31 - body::Body, 32 - extract::{FromRef, Request, State}, 33 - http::{self, HeaderMap, Response, StatusCode, Uri}, 34 - response::IntoResponse, 35 - routing::get, 36 - }; 37 - use azure_core::credentials::TokenCredential; 38 - use clap::Parser; 39 - use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter}; 40 - use config::AppConfig; 41 - use db::establish_pool; 42 - use deadpool_diesel::sqlite::Pool; 43 - use diesel::prelude::*; 44 - use diesel_migrations::{EmbeddedMigrations, embed_migrations}; 45 - #[expect(clippy::pub_use, clippy::useless_attribute)] 46 - pub use error::Error; 47 - use figment::{Figment, providers::Format as _}; 48 - use firehose::FirehoseProducer; 49 - use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager}; 50 - use rand::Rng as _; 51 - use serde::{Deserialize, Serialize}; 52 - use std::{ 53 - net::{IpAddr, Ipv4Addr, SocketAddr}, 54 - path::PathBuf, 55 - str::FromStr as _, 56 - sync::Arc, 57 - }; 58 - use tokio::net::TcpListener; 59 - use tower_http::{cors::CorsLayer, trace::TraceLayer}; 60 - use tracing::{info, warn}; 61 - use uuid::Uuid; 62 - 63 - /// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`. 64 - pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); 65 - 66 - /// Embedded migrations 67 - pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations"); 68 - 69 - /// The application-wide result type. 70 - pub type Result<T> = std::result::Result<T, Error>; 71 - /// The reqwest client type with middleware. 72 - pub type Client = reqwest_middleware::ClientWithMiddleware; 73 - /// The Azure credential type. 74 - pub type Cred = Arc<dyn TokenCredential>; 75 - 76 - #[expect( 77 - clippy::arbitrary_source_item_ordering, 78 - reason = "serialized data might be structured" 79 - )] 80 - #[derive(Serialize, Deserialize, Debug, Clone)] 81 - /// The key data structure. 82 - struct KeyData { 83 - /// Primary signing key for all repo operations. 84 - skey: Vec<u8>, 85 - /// Primary signing (rotation) key for all PLC operations. 86 - rkey: Vec<u8>, 87 - } 88 - 89 - // FIXME: We should use P256Keypair instead. SecP256K1 is primarily used for cryptocurrencies, 90 - // and the implementations of this algorithm are much more limited as compared to P256. 91 - // 92 - // Reference: https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/ 93 - #[derive(Clone)] 94 - /// The signing key for PLC/DID operations. 95 - pub struct SigningKey(Arc<Secp256k1Keypair>); 96 - #[derive(Clone)] 97 - /// The rotation key for PLC operations. 98 - pub struct RotationKey(Arc<Secp256k1Keypair>); 99 - 100 - impl std::ops::Deref for SigningKey { 101 - type Target = Secp256k1Keypair; 102 - 103 - fn deref(&self) -> &Self::Target { 104 - &self.0 105 - } 106 - } 107 - 108 - impl SigningKey { 109 - /// Import from a private key. 110 - pub fn import(key: &[u8]) -> Result<Self> { 111 - let key = Secp256k1Keypair::import(key).context("failed to import signing key")?; 112 - Ok(Self(Arc::new(key))) 113 - } 114 - } 115 - 116 - impl std::ops::Deref for RotationKey { 117 - type Target = Secp256k1Keypair; 118 - 119 - fn deref(&self) -> &Self::Target { 120 - &self.0 121 - } 122 - } 123 - 124 - #[derive(Parser, Debug, Clone)] 125 - /// Command line arguments. 126 - struct Args { 127 - /// Path to the configuration file 128 - #[arg(short, long, default_value = "default.toml")] 129 - config: PathBuf, 130 - /// The verbosity level. 131 - #[command(flatten)] 132 - verbosity: Verbosity<InfoLevel>, 133 - } 134 - 135 - struct ActorPools { 136 - repo: Pool, 137 - blob: Pool, 138 - } 139 - impl Clone for ActorPools { 140 - fn clone(&self) -> Self { 141 - Self { 142 - repo: self.repo.clone(), 143 - blob: self.blob.clone(), 144 - } 145 - } 146 - } 147 - 148 - #[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")] 149 - #[derive(Clone, FromRef)] 150 - struct AppState { 151 - /// The application configuration. 152 - config: AppConfig, 153 - /// The Azure credential. 154 - cred: Cred, 155 - /// The main database connection pool. Used for common PDS data, like invite codes. 156 - db: Pool, 157 - /// Actor-specific database connection pools. Hashed by DID. 158 - db_actors: std::collections::HashMap<String, ActorPools>, 159 - 160 - /// The HTTP client with middleware. 161 - client: Client, 162 - /// The simple HTTP client. 163 - simple_client: reqwest::Client, 164 - /// The firehose producer. 165 - firehose: FirehoseProducer, 166 - 167 - /// The signing key. 168 - signing_key: SigningKey, 169 - /// The rotation key. 170 - rotation_key: RotationKey, 171 - } 172 - 173 - /// The index (/) route. 174 - async fn index() -> impl IntoResponse { 175 - r" 176 - __ __ 177 - /\ \__ /\ \__ 178 - __ \ \ ,_\ _____ _ __ ___\ \ ,_\ ___ 179 - /'__'\ \ \ \/ /\ '__'\/\''__\/ __'\ \ \/ / __'\ 180 - /\ \L\.\_\ \ \_\ \ \L\ \ \ \//\ \L\ \ \ \_/\ \L\ \ 181 - \ \__/.\_\\ \__\\ \ ,__/\ \_\\ \____/\ \__\ \____/ 182 - \/__/\/_/ \/__/ \ \ \/ \/_/ \/___/ \/__/\/___/ 183 - \ \_\ 184 - \/_/ 185 - 186 - 187 - This is an AT Protocol Personal Data Server (aka, an atproto PDS) 188 - 189 - Most API routes are under /xrpc/ 190 - 191 - Code: https://github.com/DrChat/bluepds 192 - Protocol: https://atproto.com 193 - " 194 - } 195 - 196 - /// Service proxy. 197 - /// 198 - /// Reference: <https://atproto.com/specs/xrpc#service-proxying> 199 - async fn service_proxy( 200 - uri: Uri, 201 - user: AuthenticatedUser, 202 - State(skey): State<SigningKey>, 203 - State(client): State<reqwest::Client>, 204 - headers: HeaderMap, 205 - request: Request<Body>, 206 - ) -> Result<Response<Body>> { 207 - let url_path = uri.path_and_query().context("invalid service proxy url")?; 208 - let lxm = url_path 209 - .path() 210 - .strip_prefix("/") 211 - .with_context(|| format!("invalid service proxy url prefix: {}", url_path.path()))?; 212 - 213 - let user_did = user.did(); 214 - let (did, id) = match headers.get("atproto-proxy") { 215 - Some(val) => { 216 - let val = 217 - std::str::from_utf8(val.as_bytes()).context("proxy header not valid utf-8")?; 218 - 219 - let (did, id) = val.split_once('#').context("invalid proxy header")?; 220 - 221 - let did = 222 - Did::from_str(did).map_err(|e| anyhow!("atproto proxy not a valid DID: {e}"))?; 223 - 224 - (did, format!("#{id}")) 225 - } 226 - // HACK: Assume the bluesky appview by default. 227 - None => ( 228 - Did::new("did:web:api.bsky.app".to_owned()) 229 - .expect("service proxy should be a valid DID"), 230 - "#bsky_appview".to_owned(), 231 - ), 232 - }; 233 - 234 - let did_doc = did::resolve(&Client::new(client.clone(), []), did.clone()) 235 - .await 236 - .with_context(|| format!("failed to resolve did document {}", did.as_str()))?; 237 - 238 - let Some(service) = did_doc.service.iter().find(|s| s.id == id) else { 239 - return Err(Error::with_status( 240 - StatusCode::BAD_REQUEST, 241 - anyhow!("could not find resolve service #{id}"), 242 - )); 243 - }; 244 - 245 - let target_url: url::Url = service 246 - .service_endpoint 247 - .join(&format!("/xrpc{url_path}")) 248 - .context("failed to construct target url")?; 249 - 250 - let exp = (chrono::Utc::now().checked_add_signed(chrono::Duration::minutes(1))) 251 - .context("should be valid expiration datetime")? 252 - .timestamp(); 253 - let jti = rand::thread_rng() 254 - .sample_iter(rand::distributions::Alphanumeric) 255 - .take(10) 256 - .map(char::from) 257 - .collect::<String>(); 258 - 259 - // Mint a bearer token by signing a JSON web token. 260 - // https://github.com/DavidBuchanan314/millipds/blob/5c7529a739d394e223c0347764f1cf4e8fd69f94/src/millipds/appview_proxy.py#L47-L59 261 - let token = auth::sign( 262 - &skey, 263 - "JWT", 264 - &serde_json::json!({ 265 - "iss": user_did.as_str(), 266 - "aud": did.as_str(), 267 - "lxm": lxm, 268 - "exp": exp, 269 - "jti": jti, 270 - }), 271 - ) 272 - .context("failed to sign jwt")?; 273 - 274 - let mut h = HeaderMap::new(); 275 - if let Some(hdr) = request.headers().get("atproto-accept-labelers") { 276 - drop(h.insert("atproto-accept-labelers", hdr.clone())); 277 - } 278 - if let Some(hdr) = request.headers().get(http::header::CONTENT_TYPE) { 279 - drop(h.insert(http::header::CONTENT_TYPE, hdr.clone())); 280 - } 281 - 282 - let r = client 283 - .request(request.method().clone(), target_url) 284 - .headers(h) 285 - .header(http::header::AUTHORIZATION, format!("Bearer {token}")) 286 - .body(reqwest::Body::wrap_stream( 287 - request.into_body().into_data_stream(), 288 - )) 289 - .send() 290 - .await 291 - .context("failed to send request")?; 292 - 293 - let mut resp = Response::builder().status(r.status()); 294 - if let Some(hdrs) = resp.headers_mut() { 295 - *hdrs = r.headers().clone(); 296 - } 297 - 298 - let resp = resp 299 - .body(Body::from_stream(r.bytes_stream())) 300 - .context("failed to construct response")?; 301 - 302 - Ok(resp) 303 - } 304 - 305 - /// The main application entry point. 306 - #[expect( 307 - clippy::cognitive_complexity, 308 - clippy::too_many_lines, 309 - unused_qualifications, 310 - reason = "main function has high complexity" 311 - )] 312 - async fn run() -> anyhow::Result<()> { 313 - let args = Args::parse(); 314 - 315 - // Set up trace logging to console and account for the user-provided verbosity flag. 316 - if args.verbosity.log_level_filter() != LevelFilter::Off { 317 - let lvl = match args.verbosity.log_level_filter() { 318 - LevelFilter::Error => tracing::Level::ERROR, 319 - LevelFilter::Warn => tracing::Level::WARN, 320 - LevelFilter::Info | LevelFilter::Off => tracing::Level::INFO, 321 - LevelFilter::Debug => tracing::Level::DEBUG, 322 - LevelFilter::Trace => tracing::Level::TRACE, 323 - }; 324 - tracing_subscriber::fmt().with_max_level(lvl).init(); 325 - } 326 - 327 - if !args.config.exists() { 328 - // Throw up a warning if the config file does not exist. 329 - // 330 - // This is not fatal because users can specify all configuration settings via 331 - // the environment, but the most likely scenario here is that a user accidentally 332 - // omitted the config file for some reason (e.g. forgot to mount it into Docker). 333 - warn!( 334 - "configuration file {} does not exist", 335 - args.config.display() 336 - ); 337 - } 338 - 339 - // Read and parse the user-provided configuration. 340 - let config: AppConfig = Figment::new() 341 - .admerge(figment::providers::Toml::file(args.config)) 342 - .admerge(figment::providers::Env::prefixed("BLUEPDS_")) 343 - .extract() 344 - .context("failed to load configuration")?; 345 - 346 - if config.test { 347 - warn!("BluePDS starting up in TEST mode."); 348 - warn!("This means the application will not federate with the rest of the network."); 349 - warn!( 350 - "If you want to turn this off, either set `test` to false in the config or define `BLUEPDS_TEST = false`" 351 - ); 352 - } 353 - 354 - // Initialize metrics reporting. 355 - metrics::setup(config.metrics.as_ref()).context("failed to set up metrics exporter")?; 356 - 357 - // Create a reqwest client that will be used for all outbound requests. 358 - let simple_client = reqwest::Client::builder() 359 - .user_agent(APP_USER_AGENT) 360 - .build() 361 - .context("failed to build requester client")?; 362 - let client = reqwest_middleware::ClientBuilder::new(simple_client.clone()) 363 - .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache { 364 - mode: CacheMode::Default, 365 - manager: MokaManager::default(), 366 - options: HttpCacheOptions::default(), 367 - })) 368 - .build(); 369 - 370 - tokio::fs::create_dir_all(&config.key.parent().context("should have parent")?) 371 - .await 372 - .context("failed to create key directory")?; 373 - 374 - // Check if crypto keys exist. If not, create new ones. 375 - let (skey, rkey) = if let Ok(f) = std::fs::File::open(&config.key) { 376 - let keys: KeyData = serde_ipld_dagcbor::from_reader(std::io::BufReader::new(f)) 377 - .context("failed to deserialize crypto keys")?; 378 - 379 - let skey = Secp256k1Keypair::import(&keys.skey).context("failed to import signing key")?; 380 - let rkey = Secp256k1Keypair::import(&keys.rkey).context("failed to import rotation key")?; 381 - 382 - (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 383 - } else { 384 - info!("signing keys not found, generating new ones"); 385 - 386 - let skey = Secp256k1Keypair::create(&mut rand::thread_rng()); 387 - let rkey = Secp256k1Keypair::create(&mut rand::thread_rng()); 388 - 389 - let keys = KeyData { 390 - skey: skey.export(), 391 - rkey: rkey.export(), 392 - }; 393 - 394 - let mut f = std::fs::File::create(&config.key).context("failed to create key file")?; 395 - serde_ipld_dagcbor::to_writer(&mut f, &keys).context("failed to serialize crypto keys")?; 396 - 397 - (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 398 - }; 399 - 400 - tokio::fs::create_dir_all(&config.repo.path).await?; 401 - tokio::fs::create_dir_all(&config.plc.path).await?; 402 - tokio::fs::create_dir_all(&config.blob.path).await?; 403 - 404 - let cred = azure_identity::DefaultAzureCredential::new() 405 - .context("failed to create Azure credential")?; 406 - 407 - // Create a database connection manager and pool for the main database. 408 - let pool = 409 - establish_pool(&config.db).context("failed to establish database connection pool")?; 410 - // Create a dictionary of database connection pools for each actor. 411 - let mut actor_pools = std::collections::HashMap::new(); 412 - // let mut actor_blob_pools = std::collections::HashMap::new(); 413 - // We'll determine actors by looking in the data/repo dir for .db files. 414 - let mut actor_dbs = tokio::fs::read_dir(&config.repo.path) 415 - .await 416 - .context("failed to read repo directory")?; 417 - while let Some(entry) = actor_dbs 418 - .next_entry() 419 - .await 420 - .context("failed to read repo dir")? 421 - { 422 - let path = entry.path(); 423 - if path.extension().and_then(|s| s.to_str()) == Some("db") { 424 - let did = path 425 - .file_stem() 426 - .and_then(|s| s.to_str()) 427 - .context("failed to get actor DID")?; 428 - let did = Did::from_str(did).expect("should be able to parse actor DID"); 429 - 430 - // Create a new database connection manager and pool for the actor. 431 - // The path for the SQLite connection needs to look like "sqlite://data/repo/<actor>.db" 432 - let path_repo = format!("sqlite://{}", path.display()); 433 - let actor_repo_pool = 434 - establish_pool(&path_repo).context("failed to create database connection pool")?; 435 - // Create a new database connection manager and pool for the actor blobs. 436 - // The path for the SQLite connection needs to look like "sqlite://data/blob/<actor>.db" 437 - let path_blob = path_repo.replace("repo", "blob"); 438 - let actor_blob_pool = 439 - establish_pool(&path_blob).context("failed to create database connection pool")?; 440 - actor_pools.insert( 441 - did.to_string(), 442 - ActorPools { 443 - repo: actor_repo_pool, 444 - blob: actor_blob_pool, 445 - }, 446 - ); 447 - } 448 - } 449 - // Apply pending migrations 450 - // let conn = pool.get().await?; 451 - // conn.run_pending_migrations(MIGRATIONS) 452 - // .expect("should be able to run migrations"); 453 - 454 - let (_fh, fhp) = firehose::spawn(client.clone(), config.clone()); 455 - 456 - let addr = config 457 - .listen_address 458 - .unwrap_or(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)); 459 - 460 - let app = Router::new() 461 - .route("/", get(index)) 462 - .merge(oauth::routes()) 463 - .nest( 464 - "/xrpc", 465 - endpoints::routes() 466 - .merge(actor_endpoints::routes()) 467 - .fallback(service_proxy), 468 - ) 469 - // .layer(RateLimitLayer::new(30, Duration::from_secs(30))) 470 - .layer(CorsLayer::permissive()) 471 - .layer(TraceLayer::new_for_http()) 472 - .with_state(AppState { 473 - cred, 474 - config: config.clone(), 475 - db: pool.clone(), 476 - db_actors: actor_pools.clone(), 477 - client: client.clone(), 478 - simple_client, 479 - firehose: fhp, 480 - signing_key: skey, 481 - rotation_key: rkey, 482 - }); 483 - 484 - info!("listening on {addr}"); 485 - info!("connect to: http://127.0.0.1:{}", addr.port()); 486 - 487 - // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created). 488 - // If so, create an invite code and share it via the console. 489 - let conn = pool.get().await.context("failed to get db connection")?; 490 - 491 - #[derive(QueryableByName)] 492 - struct TotalCount { 493 - #[diesel(sql_type = diesel::sql_types::Integer)] 494 - total_count: i32, 495 - } 496 - 497 - // let result = diesel::sql_query( 498 - // "SELECT (SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites) AS total_count", 499 - // ) 500 - // .get_result::<TotalCount>(conn) 501 - // .context("failed to query database")?; 502 - let result = conn.interact(move |conn| { 503 - diesel::sql_query( 504 - "SELECT (SELECT COUNT(*) FROM accounts) + (SELECT COUNT(*) FROM invites) AS total_count", 505 - ) 506 - .get_result::<TotalCount>(conn) 507 - }) 508 - .await 509 - .expect("should be able to query database")?; 510 - 511 - let c = result.total_count; 512 - 513 - #[expect(clippy::print_stdout)] 514 - if c == 0 { 515 - let uuid = Uuid::new_v4().to_string(); 516 - 517 - let uuid_clone = uuid.clone(); 518 - conn.interact(move |conn| { 519 - diesel::sql_query( 520 - "INSERT INTO invites (id, did, count, created_at) VALUES (?, NULL, 1, datetime('now'))", 521 - ) 522 - .bind::<diesel::sql_types::Text, _>(uuid_clone) 523 - .execute(conn) 524 - .context("failed to create new invite code") 525 - .expect("should be able to create invite code") 526 - }); 527 - 528 - // N.B: This is a sensitive message, so we're bypassing `tracing` here and 529 - // logging it directly to console. 530 - println!("====================================="); 531 - println!(" FIRST STARTUP "); 532 - println!("====================================="); 533 - println!("Use this code to create an account:"); 534 - println!("{uuid}"); 535 - println!("====================================="); 536 - } 537 - 538 - let listener = TcpListener::bind(&addr) 539 - .await 540 - .context("failed to bind address")?; 541 - 542 - // Serve the app, and request crawling from upstream relays. 543 - let serve = tokio::spawn(async move { 544 - axum::serve(listener, app.into_make_service()) 545 - .await 546 - .context("failed to serve app") 547 - }); 548 - 549 - // Now that the app is live, request a crawl from upstream relays. 550 - firehose::reconnect_relays(&client, &config).await; 551 - 552 - serve 553 - .await 554 - .map_err(Into::into) 555 - .and_then(|r| r) 556 - .context("failed to serve app") 557 - } 3 + use anyhow::Context as _; 558 4 559 5 #[tokio::main(flavor = "multi_thread")] 560 6 async fn main() -> anyhow::Result<()> { 561 - // Dispatch out to a separate function without a derive macro to help rust-analyzer along. 562 - run().await 7 + bluepds::run().await.context("failed to run application") 563 8 }
-274
src/mmap.rs
··· 1 - #![allow(clippy::arbitrary_source_item_ordering)] 2 - use std::io::{ErrorKind, Read as _, Seek as _, Write as _}; 3 - 4 - #[cfg(unix)] 5 - use std::os::fd::AsRawFd as _; 6 - #[cfg(windows)] 7 - use std::os::windows::io::AsRawHandle; 8 - 9 - use memmap2::{MmapMut, MmapOptions}; 10 - 11 - pub(crate) struct MappedFile { 12 - /// The underlying file handle. 13 - file: std::fs::File, 14 - /// The length of the file. 15 - len: u64, 16 - /// The mapped memory region. 17 - map: MmapMut, 18 - /// Our current offset into the file. 19 - off: u64, 20 - } 21 - 22 - impl MappedFile { 23 - pub(crate) fn new(mut f: std::fs::File) -> std::io::Result<Self> { 24 - let len = f.seek(std::io::SeekFrom::End(0))?; 25 - 26 - #[cfg(windows)] 27 - let raw = f.as_raw_handle(); 28 - #[cfg(unix)] 29 - let raw = f.as_raw_fd(); 30 - 31 - #[expect(unsafe_code)] 32 - Ok(Self { 33 - // SAFETY: 34 - // All file-backed memory map constructors are marked \ 35 - // unsafe because of the potential for Undefined Behavior (UB) \ 36 - // using the map if the underlying file is subsequently modified, in or out of process. 37 - map: unsafe { MmapOptions::new().map_mut(raw)? }, 38 - file: f, 39 - len, 40 - off: 0, 41 - }) 42 - } 43 - 44 - /// Resize the memory-mapped file. This will reallocate the memory mapping. 45 - #[expect(unsafe_code)] 46 - fn resize(&mut self, len: u64) -> std::io::Result<()> { 47 - // Resize the file. 48 - self.file.set_len(len)?; 49 - 50 - #[cfg(windows)] 51 - let raw = self.file.as_raw_handle(); 52 - #[cfg(unix)] 53 - let raw = self.file.as_raw_fd(); 54 - 55 - // SAFETY: 56 - // All file-backed memory map constructors are marked \ 57 - // unsafe because of the potential for Undefined Behavior (UB) \ 58 - // using the map if the underlying file is subsequently modified, in or out of process. 59 - self.map = unsafe { MmapOptions::new().map_mut(raw)? }; 60 - self.len = len; 61 - 62 - Ok(()) 63 - } 64 - } 65 - 66 - impl std::io::Read for MappedFile { 67 - fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> { 68 - if self.off == self.len { 69 - // If we're at EOF, return an EOF error code. `Ok(0)` tends to trip up some implementations. 70 - return Err(std::io::Error::new(ErrorKind::UnexpectedEof, "eof")); 71 - } 72 - 73 - // Calculate the number of bytes we're going to read. 74 - let remaining_bytes = self.len.saturating_sub(self.off); 75 - let buf_len = u64::try_from(buf.len()).unwrap_or(u64::MAX); 76 - let len = usize::try_from(std::cmp::min(remaining_bytes, buf_len)).unwrap_or(usize::MAX); 77 - 78 - let off = usize::try_from(self.off).map_err(|e| { 79 - std::io::Error::new( 80 - ErrorKind::InvalidInput, 81 - format!("offset too large for this platform: {e}"), 82 - ) 83 - })?; 84 - 85 - if let (Some(dest), Some(src)) = ( 86 - buf.get_mut(..len), 87 - self.map.get(off..off.saturating_add(len)), 88 - ) { 89 - dest.copy_from_slice(src); 90 - self.off = self.off.saturating_add(u64::try_from(len).unwrap_or(0)); 91 - Ok(len) 92 - } else { 93 - Err(std::io::Error::new( 94 - ErrorKind::InvalidInput, 95 - "invalid buffer range", 96 - )) 97 - } 98 - } 99 - } 100 - 101 - impl std::io::Write for MappedFile { 102 - fn flush(&mut self) -> std::io::Result<()> { 103 - // This is done by the system. 104 - Ok(()) 105 - } 106 - fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> { 107 - // Determine if we need to resize the file. 108 - let buf_len = u64::try_from(buf.len()).map_err(|e| { 109 - std::io::Error::new( 110 - ErrorKind::InvalidInput, 111 - format!("buffer length too large for this platform: {e}"), 112 - ) 113 - })?; 114 - 115 - if self.off.saturating_add(buf_len) >= self.len { 116 - self.resize(self.off.saturating_add(buf_len))?; 117 - } 118 - 119 - let off = usize::try_from(self.off).map_err(|e| { 120 - std::io::Error::new( 121 - ErrorKind::InvalidInput, 122 - format!("offset too large for this platform: {e}"), 123 - ) 124 - })?; 125 - let len = buf.len(); 126 - 127 - if let Some(dest) = self.map.get_mut(off..off.saturating_add(len)) { 128 - dest.copy_from_slice(buf); 129 - self.off = self.off.saturating_add(buf_len); 130 - Ok(len) 131 - } else { 132 - Err(std::io::Error::new( 133 - ErrorKind::InvalidInput, 134 - "invalid buffer range", 135 - )) 136 - } 137 - } 138 - } 139 - 140 - impl std::io::Seek for MappedFile { 141 - fn seek(&mut self, pos: std::io::SeekFrom) -> std::io::Result<u64> { 142 - let off = match pos { 143 - std::io::SeekFrom::Start(i) => i, 144 - std::io::SeekFrom::End(i) => { 145 - if i <= 0 { 146 - // If i is negative or zero, we're seeking backwards from the end 147 - // or exactly at the end 148 - self.len.saturating_sub(i.unsigned_abs()) 149 - } else { 150 - // If i is positive, we're seeking beyond the end, which is allowed 151 - // but requires extending the file 152 - self.len.saturating_add(i.unsigned_abs()) 153 - } 154 - } 155 - std::io::SeekFrom::Current(i) => { 156 - if i >= 0 { 157 - self.off.saturating_add(i.unsigned_abs()) 158 - } else { 159 - self.off.saturating_sub(i.unsigned_abs()) 160 - } 161 - } 162 - }; 163 - 164 - // If the offset is beyond EOF, extend the file to the new size. 165 - if off > self.len { 166 - self.resize(off)?; 167 - } 168 - 169 - self.off = off; 170 - Ok(off) 171 - } 172 - } 173 - 174 - impl tokio::io::AsyncRead for MappedFile { 175 - fn poll_read( 176 - mut self: std::pin::Pin<&mut Self>, 177 - _cx: &mut std::task::Context<'_>, 178 - buf: &mut tokio::io::ReadBuf<'_>, 179 - ) -> std::task::Poll<std::io::Result<()>> { 180 - let wbuf = buf.initialize_unfilled(); 181 - let len = wbuf.len(); 182 - 183 - std::task::Poll::Ready(match self.read(wbuf) { 184 - Ok(_) => { 185 - buf.advance(len); 186 - Ok(()) 187 - } 188 - Err(e) => Err(e), 189 - }) 190 - } 191 - } 192 - 193 - impl tokio::io::AsyncWrite for MappedFile { 194 - fn poll_flush( 195 - self: std::pin::Pin<&mut Self>, 196 - _cx: &mut std::task::Context<'_>, 197 - ) -> std::task::Poll<Result<(), std::io::Error>> { 198 - std::task::Poll::Ready(Ok(())) 199 - } 200 - 201 - fn poll_shutdown( 202 - self: std::pin::Pin<&mut Self>, 203 - _cx: &mut std::task::Context<'_>, 204 - ) -> std::task::Poll<Result<(), std::io::Error>> { 205 - std::task::Poll::Ready(Ok(())) 206 - } 207 - 208 - fn poll_write( 209 - mut self: std::pin::Pin<&mut Self>, 210 - _cx: &mut std::task::Context<'_>, 211 - buf: &[u8], 212 - ) -> std::task::Poll<Result<usize, std::io::Error>> { 213 - std::task::Poll::Ready(self.write(buf)) 214 - } 215 - } 216 - 217 - impl tokio::io::AsyncSeek for MappedFile { 218 - fn poll_complete( 219 - self: std::pin::Pin<&mut Self>, 220 - _cx: &mut std::task::Context<'_>, 221 - ) -> std::task::Poll<std::io::Result<u64>> { 222 - std::task::Poll::Ready(Ok(self.off)) 223 - } 224 - 225 - fn start_seek( 226 - mut self: std::pin::Pin<&mut Self>, 227 - position: std::io::SeekFrom, 228 - ) -> std::io::Result<()> { 229 - self.seek(position).map(|_p| ()) 230 - } 231 - } 232 - 233 - #[cfg(test)] 234 - mod test { 235 - use rand::Rng as _; 236 - use std::io::Write as _; 237 - 238 - use super::*; 239 - 240 - #[test] 241 - fn basic_rw() { 242 - let tmp = std::env::temp_dir().join( 243 - rand::thread_rng() 244 - .sample_iter(rand::distributions::Alphanumeric) 245 - .take(10) 246 - .map(char::from) 247 - .collect::<String>(), 248 - ); 249 - 250 - let mut m = MappedFile::new( 251 - std::fs::File::options() 252 - .create(true) 253 - .truncate(true) 254 - .read(true) 255 - .write(true) 256 - .open(&tmp) 257 - .expect("Failed to open temporary file"), 258 - ) 259 - .expect("Failed to create MappedFile"); 260 - 261 - m.write_all(b"abcd123").expect("Failed to write data"); 262 - let _: u64 = m 263 - .seek(std::io::SeekFrom::Start(0)) 264 - .expect("Failed to seek to start"); 265 - 266 - let mut buf = [0_u8; 7]; 267 - m.read_exact(&mut buf).expect("Failed to read data"); 268 - 269 - assert_eq!(&buf, b"abcd123"); 270 - 271 - drop(m); 272 - std::fs::remove_file(tmp).expect("Failed to remove temporary file"); 273 - } 274 - }
+809
src/models.rs
··· 1 + // Generated by diesel_ext 2 + 3 + #![allow(unused, non_snake_case)] 4 + #![allow(clippy::all)] 5 + 6 + pub mod pds { 7 + 8 + #![allow(unnameable_types, unused_qualifications)] 9 + use anyhow::{Result, bail}; 10 + use chrono::DateTime; 11 + use chrono::offset::Utc; 12 + use diesel::backend::Backend; 13 + use diesel::deserialize::FromSql; 14 + use diesel::prelude::*; 15 + use diesel::serialize::{Output, ToSql}; 16 + use diesel::sql_types::Text; 17 + use diesel::sqlite::Sqlite; 18 + use diesel::*; 19 + use serde::{Deserialize, Serialize}; 20 + 21 + #[derive( 22 + Queryable, 23 + Identifiable, 24 + Selectable, 25 + Clone, 26 + Debug, 27 + PartialEq, 28 + Default, 29 + Serialize, 30 + Deserialize, 31 + )] 32 + #[diesel(primary_key(request_uri))] 33 + #[diesel(table_name = crate::schema::pds::oauth_par_requests)] 34 + #[diesel(check_for_backend(Sqlite))] 35 + pub struct OauthParRequest { 36 + pub request_uri: String, 37 + pub client_id: String, 38 + pub response_type: String, 39 + pub code_challenge: String, 40 + pub code_challenge_method: String, 41 + pub state: Option<String>, 42 + pub login_hint: Option<String>, 43 + pub scope: Option<String>, 44 + pub redirect_uri: Option<String>, 45 + pub response_mode: Option<String>, 46 + pub display: Option<String>, 47 + pub created_at: i64, 48 + pub expires_at: i64, 49 + } 50 + 51 + #[derive( 52 + Queryable, 53 + Identifiable, 54 + Selectable, 55 + Clone, 56 + Debug, 57 + PartialEq, 58 + Default, 59 + Serialize, 60 + Deserialize, 61 + )] 62 + #[diesel(primary_key(code))] 63 + #[diesel(table_name = crate::schema::pds::oauth_authorization_codes)] 64 + #[diesel(check_for_backend(Sqlite))] 65 + pub struct OauthAuthorizationCode { 66 + pub code: String, 67 + pub client_id: String, 68 + pub subject: String, 69 + pub code_challenge: String, 70 + pub code_challenge_method: String, 71 + pub redirect_uri: String, 72 + pub scope: Option<String>, 73 + pub created_at: i64, 74 + pub expires_at: i64, 75 + pub used: bool, 76 + } 77 + 78 + #[derive( 79 + Queryable, 80 + Identifiable, 81 + Selectable, 82 + Clone, 83 + Debug, 84 + PartialEq, 85 + Default, 86 + Serialize, 87 + Deserialize, 88 + )] 89 + #[diesel(primary_key(token))] 90 + #[diesel(table_name = crate::schema::pds::oauth_refresh_tokens)] 91 + #[diesel(check_for_backend(Sqlite))] 92 + pub struct OauthRefreshToken { 93 + pub token: String, 94 + pub client_id: String, 95 + pub subject: String, 96 + pub dpop_thumbprint: String, 97 + pub scope: Option<String>, 98 + pub created_at: i64, 99 + pub expires_at: i64, 100 + pub revoked: bool, 101 + } 102 + 103 + #[derive( 104 + Queryable, 105 + Identifiable, 106 + Selectable, 107 + Clone, 108 + Debug, 109 + PartialEq, 110 + Default, 111 + Serialize, 112 + Deserialize, 113 + )] 114 + #[diesel(primary_key(jti))] 115 + #[diesel(table_name = crate::schema::pds::oauth_used_jtis)] 116 + #[diesel(check_for_backend(Sqlite))] 117 + pub struct OauthUsedJti { 118 + pub jti: String, 119 + pub issuer: String, 120 + pub created_at: i64, 121 + pub expires_at: i64, 122 + } 123 + 124 + #[derive( 125 + Queryable, 126 + Identifiable, 127 + Selectable, 128 + Clone, 129 + Debug, 130 + PartialEq, 131 + Default, 132 + Serialize, 133 + Deserialize, 134 + )] 135 + #[diesel(primary_key(did))] 136 + #[diesel(table_name = crate::schema::pds::account)] 137 + #[diesel(check_for_backend(Sqlite))] 138 + pub struct Account { 139 + pub did: String, 140 + pub email: String, 141 + #[diesel(column_name = recoveryKey)] 142 + #[serde(rename = "recoveryKey")] 143 + pub recovery_key: Option<String>, 144 + pub password: String, 145 + #[diesel(column_name = createdAt)] 146 + #[serde(rename = "createdAt")] 147 + pub created_at: String, 148 + #[diesel(column_name = invitesDisabled)] 149 + #[serde(rename = "invitesDisabled")] 150 + pub invites_disabled: i16, 151 + #[diesel(column_name = emailConfirmedAt)] 152 + #[serde(rename = "emailConfirmedAt")] 153 + pub email_confirmed_at: Option<String>, 154 + } 155 + 156 + #[derive( 157 + Queryable, 158 + Identifiable, 159 + Selectable, 160 + Clone, 161 + Debug, 162 + PartialEq, 163 + Default, 164 + Serialize, 165 + Deserialize, 166 + )] 167 + #[diesel(primary_key(did))] 168 + #[diesel(table_name = crate::schema::pds::actor)] 169 + #[diesel(check_for_backend(Sqlite))] 170 + pub struct Actor { 171 + pub did: String, 172 + pub handle: Option<String>, 173 + #[diesel(column_name = createdAt)] 174 + #[serde(rename = "createdAt")] 175 + pub created_at: String, 176 + #[diesel(column_name = takedownRef)] 177 + #[serde(rename = "takedownRef")] 178 + pub takedown_ref: Option<String>, 179 + #[diesel(column_name = deactivatedAt)] 180 + #[serde(rename = "deactivatedAt")] 181 + pub deactivated_at: Option<String>, 182 + #[diesel(column_name = deleteAfter)] 183 + #[serde(rename = "deleteAfter")] 184 + pub delete_after: Option<String>, 185 + } 186 + 187 + #[derive( 188 + Queryable, 189 + Identifiable, 190 + Selectable, 191 + Clone, 192 + Debug, 193 + PartialEq, 194 + Default, 195 + Serialize, 196 + Deserialize, 197 + )] 198 + #[diesel(primary_key(did, name))] 199 + #[diesel(table_name = crate::schema::pds::app_password)] 200 + #[diesel(check_for_backend(Sqlite))] 201 + pub struct AppPassword { 202 + pub did: String, 203 + pub name: String, 204 + pub password: String, 205 + #[diesel(column_name = createdAt)] 206 + #[serde(rename = "createdAt")] 207 + pub created_at: String, 208 + } 209 + 210 + #[derive( 211 + Queryable, 212 + Identifiable, 213 + Selectable, 214 + Clone, 215 + Debug, 216 + PartialEq, 217 + Default, 218 + Serialize, 219 + Deserialize, 220 + )] 221 + #[diesel(primary_key(did))] 222 + #[diesel(table_name = crate::schema::pds::did_doc)] 223 + #[diesel(check_for_backend(Sqlite))] 224 + pub struct DidDoc { 225 + pub did: String, 226 + pub doc: String, 227 + #[diesel(column_name = updatedAt)] 228 + #[serde(rename = "updatedAt")] 229 + pub updated_at: i64, 230 + } 231 + 232 + #[derive( 233 + Clone, Copy, Debug, PartialEq, Eq, Hash, Default, Serialize, Deserialize, AsExpression, 234 + )] 235 + #[diesel(sql_type = Text)] 236 + pub enum EmailTokenPurpose { 237 + #[default] 238 + ConfirmEmail, 239 + UpdateEmail, 240 + ResetPassword, 241 + DeleteAccount, 242 + PlcOperation, 243 + } 244 + 245 + impl EmailTokenPurpose { 246 + pub fn as_str(&self) -> &'static str { 247 + match self { 248 + EmailTokenPurpose::ConfirmEmail => "confirm_email", 249 + EmailTokenPurpose::UpdateEmail => "update_email", 250 + EmailTokenPurpose::ResetPassword => "reset_password", 251 + EmailTokenPurpose::DeleteAccount => "delete_account", 252 + EmailTokenPurpose::PlcOperation => "plc_operation", 253 + } 254 + } 255 + 256 + pub fn from_str(s: &str) -> Result<Self> { 257 + match s { 258 + "confirm_email" => Ok(EmailTokenPurpose::ConfirmEmail), 259 + "update_email" => Ok(EmailTokenPurpose::UpdateEmail), 260 + "reset_password" => Ok(EmailTokenPurpose::ResetPassword), 261 + "delete_account" => Ok(EmailTokenPurpose::DeleteAccount), 262 + "plc_operation" => Ok(EmailTokenPurpose::PlcOperation), 263 + _ => bail!("Unable to parse as EmailTokenPurpose: `{s:?}`"), 264 + } 265 + } 266 + } 267 + 268 + impl<DB> Queryable<sql_types::Text, DB> for EmailTokenPurpose 269 + where 270 + DB: backend::Backend, 271 + String: deserialize::FromSql<sql_types::Text, DB>, 272 + { 273 + type Row = String; 274 + 275 + fn build(s: String) -> deserialize::Result<Self> { 276 + Ok(Self::from_str(&s)?) 277 + } 278 + } 279 + 280 + impl serialize::ToSql<sql_types::Text, sqlite::Sqlite> for EmailTokenPurpose 281 + where 282 + String: serialize::ToSql<sql_types::Text, sqlite::Sqlite>, 283 + { 284 + fn to_sql<'lifetime>( 285 + &'lifetime self, 286 + out: &mut serialize::Output<'lifetime, '_, sqlite::Sqlite>, 287 + ) -> serialize::Result { 288 + serialize::ToSql::<sql_types::Text, sqlite::Sqlite>::to_sql( 289 + match self { 290 + Self::ConfirmEmail => "confirm_email", 291 + Self::UpdateEmail => "update_email", 292 + Self::ResetPassword => "reset_password", 293 + Self::DeleteAccount => "delete_account", 294 + Self::PlcOperation => "plc_operation", 295 + }, 296 + out, 297 + ) 298 + } 299 + } 300 + 301 + #[derive( 302 + Queryable, 303 + Identifiable, 304 + Selectable, 305 + Clone, 306 + Debug, 307 + PartialEq, 308 + Default, 309 + Serialize, 310 + Deserialize, 311 + )] 312 + #[diesel(primary_key(purpose, did))] 313 + #[diesel(table_name = crate::schema::pds::email_token)] 314 + #[diesel(check_for_backend(Sqlite))] 315 + pub struct EmailToken { 316 + pub purpose: EmailTokenPurpose, 317 + pub did: String, 318 + pub token: String, 319 + #[diesel(column_name = requestedAt)] 320 + #[serde(rename = "requestedAt")] 321 + pub requested_at: String, 322 + } 323 + 324 + #[derive( 325 + Queryable, 326 + Identifiable, 327 + Insertable, 328 + Selectable, 329 + Clone, 330 + Debug, 331 + PartialEq, 332 + Default, 333 + Serialize, 334 + Deserialize, 335 + )] 336 + #[diesel(primary_key(code))] 337 + #[diesel(table_name = crate::schema::pds::invite_code)] 338 + #[diesel(check_for_backend(Sqlite))] 339 + pub struct InviteCode { 340 + pub code: String, 341 + #[diesel(column_name = availableUses)] 342 + #[serde(rename = "availableUses")] 343 + pub available_uses: i32, 344 + pub disabled: i16, 345 + #[diesel(column_name = forAccount)] 346 + #[serde(rename = "forAccount")] 347 + pub for_account: String, 348 + #[diesel(column_name = createdBy)] 349 + #[serde(rename = "createdBy")] 350 + pub created_by: String, 351 + #[diesel(column_name = createdAt)] 352 + #[serde(rename = "createdAt")] 353 + pub created_at: String, 354 + } 355 + 356 + #[derive( 357 + Queryable, 358 + Identifiable, 359 + Selectable, 360 + Clone, 361 + Debug, 362 + PartialEq, 363 + Default, 364 + Serialize, 365 + Deserialize, 366 + )] 367 + #[diesel(primary_key(code, usedBy))] 368 + #[diesel(table_name = crate::schema::pds::invite_code_use)] 369 + #[diesel(check_for_backend(Sqlite))] 370 + pub struct InviteCodeUse { 371 + pub code: String, 372 + #[diesel(column_name = usedBy)] 373 + #[serde(rename = "usedBy")] 374 + pub used_by: String, 375 + #[diesel(column_name = usedAt)] 376 + #[serde(rename = "usedAt")] 377 + pub used_at: String, 378 + } 379 + 380 + #[derive( 381 + Queryable, 382 + Identifiable, 383 + Selectable, 384 + Clone, 385 + Debug, 386 + PartialEq, 387 + Default, 388 + Serialize, 389 + Deserialize, 390 + )] 391 + #[diesel(table_name = crate::schema::pds::refresh_token)] 392 + #[diesel(check_for_backend(Sqlite))] 393 + pub struct RefreshToken { 394 + pub id: String, 395 + pub did: String, 396 + #[diesel(column_name = expiresAt)] 397 + #[serde(rename = "expiresAt")] 398 + pub expires_at: String, 399 + #[diesel(column_name = nextId)] 400 + #[serde(rename = "nextId")] 401 + pub next_id: Option<String>, 402 + #[diesel(column_name = appPasswordName)] 403 + #[serde(rename = "appPasswordName")] 404 + pub app_password_name: Option<String>, 405 + } 406 + 407 + #[derive( 408 + Queryable, 409 + Identifiable, 410 + Selectable, 411 + Insertable, 412 + Clone, 413 + Debug, 414 + PartialEq, 415 + Default, 416 + Serialize, 417 + Deserialize, 418 + )] 419 + #[diesel(primary_key(seq))] 420 + #[diesel(table_name = crate::schema::pds::repo_seq)] 421 + #[diesel(check_for_backend(Sqlite))] 422 + pub struct RepoSeq { 423 + #[diesel(deserialize_as = i64)] 424 + pub seq: Option<i64>, 425 + pub did: String, 426 + #[diesel(column_name = eventType)] 427 + #[serde(rename = "eventType")] 428 + pub event_type: String, 429 + #[diesel(sql_type = Bytea)] 430 + pub event: Vec<u8>, 431 + #[diesel(deserialize_as = i16)] 432 + pub invalidated: Option<i16>, 433 + #[diesel(column_name = sequencedAt)] 434 + #[serde(rename = "sequencedAt")] 435 + pub sequenced_at: String, 436 + } 437 + 438 + impl RepoSeq { 439 + pub fn new(did: String, event_type: String, event: Vec<u8>, sequenced_at: String) -> Self { 440 + RepoSeq { 441 + did, 442 + event_type, 443 + event, 444 + sequenced_at, 445 + invalidated: None, // default values used on insert 446 + seq: None, // default values used on insert 447 + } 448 + } 449 + } 450 + 451 + #[derive( 452 + Queryable, 453 + Identifiable, 454 + Insertable, 455 + Selectable, 456 + Clone, 457 + Debug, 458 + PartialEq, 459 + Default, 460 + Serialize, 461 + Deserialize, 462 + )] 463 + #[diesel(primary_key(id))] 464 + #[diesel(table_name = crate::schema::pds::token)] 465 + #[diesel(check_for_backend(Sqlite))] 466 + pub struct Token { 467 + pub id: String, 468 + pub did: String, 469 + #[diesel(column_name = tokenId)] 470 + #[serde(rename = "tokenId")] 471 + pub token_id: String, 472 + #[diesel(column_name = createdAt)] 473 + #[serde(rename = "createdAt")] 474 + pub created_at: DateTime<Utc>, 475 + #[diesel(column_name = updatedAt)] 476 + #[serde(rename = "updatedAt")] 477 + pub updated_at: DateTime<Utc>, 478 + #[diesel(column_name = expiresAt)] 479 + #[serde(rename = "expiresAt")] 480 + pub expires_at: DateTime<Utc>, 481 + #[diesel(column_name = clientId)] 482 + #[serde(rename = "clientId")] 483 + pub client_id: String, 484 + #[diesel(column_name = clientAuth)] 485 + #[serde(rename = "clientAuth")] 486 + pub client_auth: String, 487 + #[diesel(column_name = deviceId)] 488 + #[serde(rename = "deviceId")] 489 + pub device_id: Option<String>, 490 + pub parameters: String, 491 + pub details: Option<String>, 492 + pub code: Option<String>, 493 + #[diesel(column_name = currentRefreshToken)] 494 + #[serde(rename = "currentRefreshToken")] 495 + pub current_refresh_token: Option<String>, 496 + } 497 + 498 + #[derive( 499 + Queryable, 500 + Identifiable, 501 + Insertable, 502 + Selectable, 503 + Clone, 504 + Debug, 505 + PartialEq, 506 + Default, 507 + Serialize, 508 + Deserialize, 509 + )] 510 + #[diesel(primary_key(id))] 511 + #[diesel(table_name = crate::schema::pds::device)] 512 + #[diesel(check_for_backend(Sqlite))] 513 + pub struct Device { 514 + pub id: String, 515 + #[diesel(column_name = sessionId)] 516 + #[serde(rename = "sessionId")] 517 + pub session_id: Option<String>, 518 + #[diesel(column_name = userAgent)] 519 + #[serde(rename = "userAgent")] 520 + pub user_agent: Option<String>, 521 + #[diesel(column_name = ipAddress)] 522 + #[serde(rename = "ipAddress")] 523 + pub ip_address: String, 524 + #[diesel(column_name = lastSeenAt)] 525 + #[serde(rename = "lastSeenAt")] 526 + pub last_seen_at: DateTime<Utc>, 527 + } 528 + 529 + #[derive( 530 + Queryable, 531 + Identifiable, 532 + Insertable, 533 + Selectable, 534 + Clone, 535 + Debug, 536 + PartialEq, 537 + Default, 538 + Serialize, 539 + Deserialize, 540 + )] 541 + #[diesel(primary_key(did))] 542 + #[diesel(table_name = crate::schema::pds::device_account)] 543 + #[diesel(check_for_backend(Sqlite))] 544 + pub struct DeviceAccount { 545 + pub did: String, 546 + #[diesel(column_name = deviceId)] 547 + #[serde(rename = "deviceId")] 548 + pub device_id: String, 549 + #[diesel(column_name = authenticatedAt)] 550 + #[serde(rename = "authenticatedAt")] 551 + pub authenticated_at: DateTime<Utc>, 552 + pub remember: bool, 553 + #[diesel(column_name = authorizedClients)] 554 + #[serde(rename = "authorizedClients")] 555 + pub authorized_clients: String, 556 + } 557 + 558 + #[derive( 559 + Queryable, 560 + Identifiable, 561 + Insertable, 562 + Selectable, 563 + Clone, 564 + Debug, 565 + PartialEq, 566 + Default, 567 + Serialize, 568 + Deserialize, 569 + )] 570 + #[diesel(primary_key(id))] 571 + #[diesel(table_name = crate::schema::pds::authorization_request)] 572 + #[diesel(check_for_backend(Sqlite))] 573 + pub struct AuthorizationRequest { 574 + pub id: String, 575 + pub did: Option<String>, 576 + #[diesel(column_name = deviceId)] 577 + #[serde(rename = "deviceId")] 578 + pub device_id: Option<String>, 579 + #[diesel(column_name = clientId)] 580 + #[serde(rename = "clientId")] 581 + pub client_id: String, 582 + #[diesel(column_name = clientAuth)] 583 + #[serde(rename = "clientAuth")] 584 + pub client_auth: String, 585 + pub parameters: String, 586 + #[diesel(column_name = expiresAt)] 587 + #[serde(rename = "expiresAt")] 588 + pub expires_at: DateTime<Utc>, 589 + pub code: Option<String>, 590 + } 591 + 592 + #[derive( 593 + Queryable, Insertable, Selectable, Clone, Debug, PartialEq, Default, Serialize, Deserialize, 594 + )] 595 + #[diesel(table_name = crate::schema::pds::used_refresh_token)] 596 + #[diesel(check_for_backend(Sqlite))] 597 + pub struct UsedRefreshToken { 598 + #[diesel(column_name = tokenId)] 599 + #[serde(rename = "tokenId")] 600 + pub token_id: String, 601 + #[diesel(column_name = refreshToken)] 602 + #[serde(rename = "refreshToken")] 603 + pub refresh_token: String, 604 + } 605 + } 606 + 607 + pub mod actor_store { 608 + 609 + #![allow(unnameable_types, unused_qualifications)] 610 + use anyhow::{Result, bail}; 611 + use chrono::DateTime; 612 + use chrono::offset::Utc; 613 + use diesel::backend::Backend; 614 + use diesel::deserialize::FromSql; 615 + use diesel::prelude::*; 616 + use diesel::serialize::{Output, ToSql}; 617 + use diesel::sql_types::Text; 618 + use diesel::sqlite::Sqlite; 619 + use diesel::*; 620 + use serde::{Deserialize, Serialize}; 621 + 622 + #[derive( 623 + Queryable, 624 + Identifiable, 625 + Insertable, 626 + Selectable, 627 + Clone, 628 + Debug, 629 + PartialEq, 630 + Default, 631 + Serialize, 632 + Deserialize, 633 + )] 634 + #[diesel(table_name = crate::schema::actor_store::account_pref)] 635 + #[diesel(check_for_backend(Sqlite))] 636 + pub struct AccountPref { 637 + pub id: i32, 638 + pub name: String, 639 + #[diesel(column_name = valueJson)] 640 + #[serde(rename = "valueJson")] 641 + pub value_json: Option<String>, 642 + } 643 + 644 + #[derive( 645 + Queryable, 646 + Identifiable, 647 + Insertable, 648 + Selectable, 649 + Clone, 650 + Debug, 651 + PartialEq, 652 + Default, 653 + Serialize, 654 + Deserialize, 655 + )] 656 + #[diesel(primary_key(uri, path))] 657 + #[diesel(table_name = crate::schema::actor_store::backlink)] 658 + #[diesel(check_for_backend(Sqlite))] 659 + pub struct Backlink { 660 + pub uri: String, 661 + pub path: String, 662 + #[diesel(column_name = linkTo)] 663 + #[serde(rename = "linkTo")] 664 + pub link_to: String, 665 + } 666 + 667 + #[derive( 668 + Queryable, 669 + Identifiable, 670 + Selectable, 671 + Clone, 672 + Debug, 673 + PartialEq, 674 + Default, 675 + Serialize, 676 + Deserialize, 677 + )] 678 + #[diesel(treat_none_as_null = true)] 679 + #[diesel(primary_key(cid))] 680 + #[diesel(table_name = crate::schema::actor_store::blob)] 681 + #[diesel(check_for_backend(Sqlite))] 682 + pub struct Blob { 683 + pub cid: String, 684 + pub did: String, 685 + #[diesel(column_name = mimeType)] 686 + #[serde(rename = "mimeType")] 687 + pub mime_type: String, 688 + pub size: i32, 689 + #[diesel(column_name = tempKey)] 690 + #[serde(rename = "tempKey")] 691 + pub temp_key: Option<String>, 692 + pub width: Option<i32>, 693 + pub height: Option<i32>, 694 + #[diesel(column_name = createdAt)] 695 + #[serde(rename = "createdAt")] 696 + pub created_at: String, 697 + #[diesel(column_name = takedownRef)] 698 + #[serde(rename = "takedownRef")] 699 + pub takedown_ref: Option<String>, 700 + } 701 + 702 + #[derive( 703 + Queryable, 704 + Identifiable, 705 + Insertable, 706 + Selectable, 707 + Clone, 708 + Debug, 709 + PartialEq, 710 + Default, 711 + Serialize, 712 + Deserialize, 713 + )] 714 + #[diesel(primary_key(uri))] 715 + #[diesel(table_name = crate::schema::actor_store::record)] 716 + #[diesel(check_for_backend(Sqlite))] 717 + pub struct Record { 718 + pub uri: String, 719 + pub cid: String, 720 + pub did: String, 721 + pub collection: String, 722 + pub rkey: String, 723 + #[diesel(column_name = repoRev)] 724 + #[serde(rename = "repoRev")] 725 + pub repo_rev: Option<String>, 726 + #[diesel(column_name = indexedAt)] 727 + #[serde(rename = "indexedAt")] 728 + pub indexed_at: String, 729 + #[diesel(column_name = takedownRef)] 730 + #[serde(rename = "takedownRef")] 731 + pub takedown_ref: Option<String>, 732 + } 733 + 734 + #[derive( 735 + QueryableByName, 736 + Queryable, 737 + Identifiable, 738 + Selectable, 739 + Clone, 740 + Debug, 741 + PartialEq, 742 + Default, 743 + Serialize, 744 + Deserialize, 745 + )] 746 + #[diesel(primary_key(blobCid, recordUri))] 747 + #[diesel(table_name = crate::schema::actor_store::record_blob)] 748 + #[diesel(check_for_backend(Sqlite))] 749 + pub struct RecordBlob { 750 + #[diesel(column_name = blobCid, sql_type = Text)] 751 + #[serde(rename = "blobCid")] 752 + pub blob_cid: String, 753 + #[diesel(column_name = recordUri, sql_type = Text)] 754 + #[serde(rename = "recordUri")] 755 + pub record_uri: String, 756 + #[diesel(sql_type = Text)] 757 + pub did: String, 758 + } 759 + 760 + #[derive( 761 + Queryable, 762 + Identifiable, 763 + Selectable, 764 + Insertable, 765 + Clone, 766 + Debug, 767 + PartialEq, 768 + Default, 769 + Serialize, 770 + Deserialize, 771 + )] 772 + #[diesel(primary_key(cid))] 773 + #[diesel(table_name = crate::schema::actor_store::repo_block)] 774 + #[diesel(check_for_backend(Sqlite))] 775 + pub struct RepoBlock { 776 + #[diesel(sql_type = Text)] 777 + pub cid: String, 778 + pub did: String, 779 + #[diesel(column_name = repoRev)] 780 + #[serde(rename = "repoRev")] 781 + pub repo_rev: String, 782 + pub size: i32, 783 + #[diesel(sql_type = Bytea)] 784 + pub content: Vec<u8>, 785 + } 786 + 787 + #[derive( 788 + Queryable, 789 + Identifiable, 790 + Selectable, 791 + Clone, 792 + Debug, 793 + PartialEq, 794 + Default, 795 + Serialize, 796 + Deserialize, 797 + )] 798 + #[diesel(primary_key(did))] 799 + #[diesel(table_name = crate::schema::actor_store::repo_root)] 800 + #[diesel(check_for_backend(Sqlite))] 801 + pub struct RepoRoot { 802 + pub did: String, 803 + pub cid: String, 804 + pub rev: String, 805 + #[diesel(column_name = indexedAt)] 806 + #[serde(rename = "indexedAt")] 807 + pub indexed_at: String, 808 + } 809 + }
+451 -240
src/oauth.rs
··· 1 1 //! OAuth endpoints 2 - 2 + #![allow(unnameable_types, unused_qualifications)] 3 + use crate::config::AppConfig; 4 + use crate::error::Error; 3 5 use crate::metrics::AUTH_FAILED; 4 - use crate::{AppConfig, AppState, Client, Db, Error, Result, SigningKey}; 6 + use crate::serve::{AppState, Client, Result, SigningKey}; 5 7 use anyhow::{Context as _, anyhow}; 6 8 use argon2::{Argon2, PasswordHash, PasswordVerifier as _}; 7 9 use atrium_crypto::keypair::Did as _; ··· 14 16 routing::{get, post}, 15 17 }; 16 18 use base64::Engine as _; 19 + use deadpool_diesel::sqlite::Pool; 20 + use diesel::*; 17 21 use metrics::counter; 18 22 use rand::distributions::Alphanumeric; 19 23 use rand::{Rng as _, thread_rng}; ··· 252 256 /// POST `/oauth/par` 253 257 #[expect(clippy::too_many_lines)] 254 258 async fn par( 255 - State(db): State<Db>, 259 + State(db): State<Pool>, 256 260 State(client): State<Client>, 257 261 Json(form_data): Json<HashMap<String, String>>, 258 262 ) -> Result<Json<Value>> { ··· 357 361 .context("failed to compute expiration time")? 358 362 .timestamp(); 359 363 360 - _ = sqlx::query!( 361 - r#" 362 - INSERT INTO oauth_par_requests ( 363 - request_uri, client_id, response_type, code_challenge, code_challenge_method, 364 - state, login_hint, scope, redirect_uri, response_mode, display, 365 - created_at, expires_at 366 - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 367 - "#, 368 - request_uri, 369 - client_id, 370 - response_type, 371 - code_challenge, 372 - code_challenge_method, 373 - state, 374 - login_hint, 375 - scope, 376 - redirect_uri, 377 - response_mode, 378 - display, 379 - created_at, 380 - expires_at 381 - ) 382 - .execute(&db) 383 - .await 384 - .context("failed to store PAR request")?; 364 + use crate::schema::pds::oauth_par_requests::dsl as ParRequestSchema; 365 + let client_id = client_id.to_owned(); 366 + let request_uri_cloned = request_uri.to_owned(); 367 + let response_type = response_type.to_owned(); 368 + let code_challenge = code_challenge.to_owned(); 369 + let code_challenge_method = code_challenge_method.to_owned(); 370 + _ = db 371 + .get() 372 + .await 373 + .expect("Failed to get database connection") 374 + .interact(move |conn| { 375 + insert_into(ParRequestSchema::oauth_par_requests) 376 + .values(( 377 + ParRequestSchema::request_uri.eq(&request_uri_cloned), 378 + ParRequestSchema::client_id.eq(client_id), 379 + ParRequestSchema::response_type.eq(response_type), 380 + ParRequestSchema::code_challenge.eq(code_challenge), 381 + ParRequestSchema::code_challenge_method.eq(code_challenge_method), 382 + ParRequestSchema::state.eq(state), 383 + ParRequestSchema::login_hint.eq(login_hint), 384 + ParRequestSchema::scope.eq(scope), 385 + ParRequestSchema::redirect_uri.eq(redirect_uri), 386 + ParRequestSchema::response_mode.eq(response_mode), 387 + ParRequestSchema::display.eq(display), 388 + ParRequestSchema::created_at.eq(created_at), 389 + ParRequestSchema::expires_at.eq(expires_at), 390 + )) 391 + .execute(conn) 392 + }) 393 + .await 394 + .expect("Failed to store PAR request") 395 + .expect("Failed to store PAR request"); 385 396 386 397 Ok(Json(json!({ 387 398 "request_uri": request_uri, ··· 392 403 /// OAuth Authorization endpoint 393 404 /// GET `/oauth/authorize` 394 405 async fn authorize( 395 - State(db): State<Db>, 406 + State(db): State<Pool>, 396 407 State(client): State<Client>, 397 408 Query(params): Query<HashMap<String, String>>, 398 409 ) -> Result<impl IntoResponse> { ··· 407 418 let timestamp = chrono::Utc::now().timestamp(); 408 419 409 420 // Retrieve the PAR request from the database 410 - let par_request = sqlx::query!( 411 - r#" 412 - SELECT * FROM oauth_par_requests 413 - WHERE request_uri = ? AND client_id = ? AND expires_at > ? 414 - "#, 415 - request_uri, 416 - client_id, 417 - timestamp 418 - ) 419 - .fetch_optional(&db) 420 - .await 421 - .context("failed to query PAR request")? 422 - .context("PAR request not found or expired")?; 421 + use crate::schema::pds::oauth_par_requests::dsl as ParRequestSchema; 422 + 423 + let request_uri_clone = request_uri.to_owned(); 424 + let client_id_clone = client_id.to_owned(); 425 + let timestamp_clone = timestamp.clone(); 426 + let login_hint = db 427 + .get() 428 + .await 429 + .expect("Failed to get database connection") 430 + .interact(move |conn| { 431 + ParRequestSchema::oauth_par_requests 432 + .select(ParRequestSchema::login_hint) 433 + .filter(ParRequestSchema::request_uri.eq(request_uri_clone)) 434 + .filter(ParRequestSchema::client_id.eq(client_id_clone)) 435 + .filter(ParRequestSchema::expires_at.gt(timestamp_clone)) 436 + .first::<Option<String>>(conn) 437 + .optional() 438 + }) 439 + .await 440 + .expect("Failed to query PAR request") 441 + .expect("Failed to query PAR request") 442 + .expect("Failed to query PAR request"); 423 443 424 444 // Validate client metadata 425 445 let client_metadata = fetch_client_metadata(&client, client_id).await?; 426 446 427 447 // Authorization page with login form 428 - let login_hint = par_request.login_hint.unwrap_or_default(); 448 + let login_hint = login_hint.unwrap_or_default(); 429 449 let html = format!( 430 450 r#"<!DOCTYPE html> 431 451 <html> ··· 491 511 /// POST `/oauth/authorize/sign-in` 492 512 #[expect(clippy::too_many_lines)] 493 513 async fn authorize_signin( 494 - State(db): State<Db>, 514 + State(db): State<Pool>, 495 515 State(config): State<AppConfig>, 496 516 State(client): State<Client>, 497 517 extract::Form(form_data): extract::Form<HashMap<String, String>>, ··· 511 531 let timestamp = chrono::Utc::now().timestamp(); 512 532 513 533 // Retrieve the PAR request 514 - let par_request = sqlx::query!( 515 - r#" 516 - SELECT * FROM oauth_par_requests 517 - WHERE request_uri = ? AND client_id = ? AND expires_at > ? 518 - "#, 519 - request_uri, 520 - client_id, 521 - timestamp 522 - ) 523 - .fetch_optional(&db) 524 - .await 525 - .context("failed to query PAR request")? 526 - .context("PAR request not found or expired")?; 534 + use crate::schema::pds::oauth_par_requests::dsl as ParRequestSchema; 535 + #[derive(Queryable, Selectable)] 536 + #[diesel(table_name = crate::schema::pds::oauth_par_requests)] 537 + #[diesel(check_for_backend(sqlite::Sqlite))] 538 + struct ParRequest { 539 + request_uri: String, 540 + client_id: String, 541 + response_type: String, 542 + code_challenge: String, 543 + code_challenge_method: String, 544 + state: Option<String>, 545 + login_hint: Option<String>, 546 + scope: Option<String>, 547 + redirect_uri: Option<String>, 548 + response_mode: Option<String>, 549 + display: Option<String>, 550 + created_at: i64, 551 + expires_at: i64, 552 + } 553 + let request_uri_clone = request_uri.to_owned(); 554 + let client_id_clone = client_id.to_owned(); 555 + let timestamp_clone = timestamp.clone(); 556 + let par_request = db 557 + .get() 558 + .await 559 + .expect("Failed to get database connection") 560 + .interact(move |conn| { 561 + ParRequestSchema::oauth_par_requests 562 + .filter(ParRequestSchema::request_uri.eq(request_uri_clone)) 563 + .filter(ParRequestSchema::client_id.eq(client_id_clone)) 564 + .filter(ParRequestSchema::expires_at.gt(timestamp_clone)) 565 + .first::<ParRequest>(conn) 566 + .optional() 567 + }) 568 + .await 569 + .expect("Failed to query PAR request") 570 + .expect("Failed to query PAR request") 571 + .expect("Failed to query PAR request"); 527 572 528 573 // Authenticate the user 529 - let account = sqlx::query!( 530 - r#" 531 - WITH LatestHandles AS ( 532 - SELECT did, handle 533 - FROM handles 534 - WHERE (did, created_at) IN ( 535 - SELECT did, MAX(created_at) AS max_created_at 536 - FROM handles 537 - GROUP BY did 538 - ) 539 - ) 540 - SELECT a.did, a.email, a.password, h.handle 541 - FROM accounts a 542 - LEFT JOIN LatestHandles h ON a.did = h.did 543 - WHERE h.handle = ? 544 - "#, 545 - username 546 - ) 547 - .fetch_optional(&db) 548 - .await 549 - .context("failed to query database")? 550 - .context("user not found")?; 574 + use crate::schema::pds::account::dsl as AccountSchema; 575 + use crate::schema::pds::actor::dsl as ActorSchema; 576 + let username_clone = username.to_owned(); 577 + let account = db 578 + .get() 579 + .await 580 + .expect("Failed to get database connection") 581 + .interact(move |conn| { 582 + AccountSchema::account 583 + .filter(AccountSchema::email.eq(username_clone)) 584 + .first::<crate::models::pds::Account>(conn) 585 + .optional() 586 + }) 587 + .await 588 + .expect("Failed to query account") 589 + .expect("Failed to query account") 590 + .expect("Failed to query account"); 591 + // let actor = db 592 + // .get() 593 + // .await 594 + // .expect("Failed to get database connection") 595 + // .interact(move |conn| { 596 + // ActorSchema::actor 597 + // .filter(ActorSchema::did.eq(did)) 598 + // .first::<rsky_pds::models::Actor>(conn) 599 + // .optional() 600 + // }) 601 + // .await 602 + // .expect("Failed to query actor") 603 + // .expect("Failed to query actor") 604 + // .expect("Failed to query actor"); 551 605 552 606 // Verify password - fixed to use equality check instead of pattern matching 553 607 if Argon2::default().verify_password( ··· 592 646 .context("failed to compute expiration time")? 593 647 .timestamp(); 594 648 595 - _ = sqlx::query!( 596 - r#" 597 - INSERT INTO oauth_authorization_codes ( 598 - code, client_id, subject, code_challenge, code_challenge_method, 599 - redirect_uri, scope, created_at, expires_at, used 600 - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 601 - "#, 602 - code, 603 - client_id, 604 - account.did, 605 - par_request.code_challenge, 606 - par_request.code_challenge_method, 607 - redirect_uri, 608 - par_request.scope, 609 - created_at, 610 - expires_at, 611 - false 612 - ) 613 - .execute(&db) 614 - .await 615 - .context("failed to store authorization code")?; 649 + use crate::schema::pds::oauth_authorization_codes::dsl as AuthCodeSchema; 650 + let code_cloned = code.to_owned(); 651 + let client_id = client_id.to_owned(); 652 + let subject = account.did.to_owned(); 653 + let code_challenge = par_request.code_challenge.to_owned(); 654 + let code_challenge_method = par_request.code_challenge_method.to_owned(); 655 + let redirect_uri_cloned = redirect_uri.to_owned(); 656 + let scope = par_request.scope.to_owned(); 657 + let used = false; 658 + _ = db 659 + .get() 660 + .await 661 + .expect("Failed to get database connection") 662 + .interact(move |conn| { 663 + insert_into(AuthCodeSchema::oauth_authorization_codes) 664 + .values(( 665 + AuthCodeSchema::code.eq(code_cloned), 666 + AuthCodeSchema::client_id.eq(client_id), 667 + AuthCodeSchema::subject.eq(subject), 668 + AuthCodeSchema::code_challenge.eq(code_challenge), 669 + AuthCodeSchema::code_challenge_method.eq(code_challenge_method), 670 + AuthCodeSchema::redirect_uri.eq(redirect_uri_cloned), 671 + AuthCodeSchema::scope.eq(scope), 672 + AuthCodeSchema::created_at.eq(created_at), 673 + AuthCodeSchema::expires_at.eq(expires_at), 674 + AuthCodeSchema::used.eq(used), 675 + )) 676 + .execute(conn) 677 + }) 678 + .await 679 + .expect("Failed to store authorization code") 680 + .expect("Failed to store authorization code"); 616 681 617 682 // Use state from the PAR request or generate one 618 683 let state = par_request.state.unwrap_or_else(|| { ··· 673 738 dpop_token: &str, 674 739 http_method: &str, 675 740 http_uri: &str, 676 - db: &Db, 741 + db: &Pool, 677 742 access_token: Option<&str>, 678 743 bound_key_thumbprint: Option<&str>, 679 744 ) -> Result<String> { ··· 811 876 } 812 877 813 878 // 11. Check for replay attacks via JTI tracking 814 - let jti_used = 815 - sqlx::query_scalar!(r#"SELECT COUNT(*) FROM oauth_used_jtis WHERE jti = ?"#, jti) 816 - .fetch_one(db) 817 - .await 818 - .context("failed to check JTI")?; 879 + use crate::schema::pds::oauth_used_jtis::dsl as JtiSchema; 880 + let jti_clone = jti.to_owned(); 881 + let jti_used = db 882 + .get() 883 + .await 884 + .expect("Failed to get database connection") 885 + .interact(move |conn| { 886 + JtiSchema::oauth_used_jtis 887 + .filter(JtiSchema::jti.eq(jti_clone)) 888 + .count() 889 + .get_result::<i64>(conn) 890 + .optional() 891 + }) 892 + .await 893 + .expect("Failed to check JTI") 894 + .expect("Failed to check JTI") 895 + .unwrap_or(0); 819 896 820 897 if jti_used > 0 { 821 898 return Err(Error::with_status( ··· 825 902 } 826 903 827 904 // 12. Store the JTI to prevent replay attacks 828 - _ = sqlx::query!( 829 - r#" 830 - INSERT INTO oauth_used_jtis (jti, issuer, created_at, expires_at) 831 - VALUES (?, ?, ?, ?) 832 - "#, 833 - jti, 834 - thumbprint, // Use thumbprint as issuer identifier 835 - now, 836 - exp 837 - ) 838 - .execute(db) 839 - .await 840 - .context("failed to store JTI")?; 905 + let jti_cloned = jti.to_owned(); 906 + let issuer = thumbprint.to_owned(); 907 + let created_at = now; 908 + let expires_at = exp; 909 + _ = db 910 + .get() 911 + .await 912 + .expect("Failed to get database connection") 913 + .interact(move |conn| { 914 + insert_into(JtiSchema::oauth_used_jtis) 915 + .values(( 916 + JtiSchema::jti.eq(jti_cloned), 917 + JtiSchema::issuer.eq(issuer), 918 + JtiSchema::created_at.eq(created_at), 919 + JtiSchema::expires_at.eq(expires_at), 920 + )) 921 + .execute(conn) 922 + }) 923 + .await 924 + .expect("Failed to store JTI") 925 + .expect("Failed to store JTI"); 841 926 842 927 // 13. Cleanup expired JTIs periodically (1% chance on each request) 843 928 if thread_rng().gen_range(0_i32..100_i32) == 0_i32 { 844 - _ = sqlx::query!(r#"DELETE FROM oauth_used_jtis WHERE expires_at < ?"#, now) 845 - .execute(db) 929 + let now_clone = now.to_owned(); 930 + _ = db 931 + .get() 932 + .await 933 + .expect("Failed to get database connection") 934 + .interact(move |conn| { 935 + delete(JtiSchema::oauth_used_jtis) 936 + .filter(JtiSchema::expires_at.lt(now_clone)) 937 + .execute(conn) 938 + }) 846 939 .await 847 - .context("failed to clean up expired JTIs")?; 940 + .expect("Failed to clean up expired JTIs") 941 + .expect("Failed to clean up expired JTIs"); 848 942 } 849 943 850 944 Ok(thumbprint) ··· 882 976 /// Handles both `authorization_code` and `refresh_token` grants 883 977 #[expect(clippy::too_many_lines)] 884 978 async fn token( 885 - State(db): State<Db>, 979 + State(db): State<Pool>, 886 980 State(skey): State<SigningKey>, 887 981 State(config): State<AppConfig>, 888 982 State(client): State<Client>, ··· 913 1007 == "private_key_jwt"; 914 1008 915 1009 // Verify DPoP proof 916 - let dpop_thumbprint = verify_dpop_proof( 1010 + let dpop_thumbprint_res = verify_dpop_proof( 917 1011 dpop_token, 918 1012 "POST", 919 1013 &format!("https://{}/oauth/token", config.host_name), ··· 959 1053 // } 960 1054 } else { 961 1055 // Rule 2: For public clients, check if this DPoP key has been used before 962 - let is_key_reused = sqlx::query_scalar!( 963 - r#"SELECT COUNT(*) FROM oauth_refresh_tokens WHERE dpop_thumbprint = ? AND client_id = ?"#, 964 - dpop_thumbprint, 965 - client_id 966 - ) 967 - .fetch_one(&db) 968 - .await 969 - .context("failed to check key usage history")? > 0; 1056 + use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1057 + let dpop_thumbprint_clone = dpop_thumbprint_res.to_owned(); 1058 + let client_id_clone = client_id.to_owned(); 1059 + let is_key_reused = db 1060 + .get() 1061 + .await 1062 + .expect("Failed to get database connection") 1063 + .interact(move |conn| { 1064 + RefreshTokenSchema::oauth_refresh_tokens 1065 + .filter(RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_clone)) 1066 + .filter(RefreshTokenSchema::client_id.eq(client_id_clone)) 1067 + .count() 1068 + .get_result::<i64>(conn) 1069 + .optional() 1070 + }) 1071 + .await 1072 + .expect("Failed to check key usage history") 1073 + .expect("Failed to check key usage history") 1074 + .unwrap_or(0) 1075 + > 0; 970 1076 971 1077 if is_key_reused && grant_type == "authorization_code" { 972 1078 return Err(Error::with_status( ··· 990 1096 let timestamp = chrono::Utc::now().timestamp(); 991 1097 992 1098 // Retrieve and validate the authorization code 993 - let auth_code = sqlx::query!( 994 - r#" 995 - SELECT * FROM oauth_authorization_codes 996 - WHERE code = ? AND client_id = ? AND redirect_uri = ? AND expires_at > ? AND used = FALSE 997 - "#, 998 - code, 999 - client_id, 1000 - redirect_uri, 1001 - timestamp 1002 - ) 1003 - .fetch_optional(&db) 1004 - .await 1005 - .context("failed to query authorization code")? 1006 - .context("authorization code not found, expired, or already used")?; 1099 + use crate::schema::pds::oauth_authorization_codes::dsl as AuthCodeSchema; 1100 + #[derive(Queryable, Selectable, Serialize)] 1101 + #[diesel(table_name = crate::schema::pds::oauth_authorization_codes)] 1102 + #[diesel(check_for_backend(sqlite::Sqlite))] 1103 + struct AuthCode { 1104 + code: String, 1105 + client_id: String, 1106 + subject: String, 1107 + code_challenge: String, 1108 + code_challenge_method: String, 1109 + redirect_uri: String, 1110 + scope: Option<String>, 1111 + created_at: i64, 1112 + expires_at: i64, 1113 + used: bool, 1114 + } 1115 + let code_clone = code.to_owned(); 1116 + let client_id_clone = client_id.to_owned(); 1117 + let redirect_uri_clone = redirect_uri.to_owned(); 1118 + let auth_code = db 1119 + .get() 1120 + .await 1121 + .expect("Failed to get database connection") 1122 + .interact(move |conn| { 1123 + AuthCodeSchema::oauth_authorization_codes 1124 + .filter(AuthCodeSchema::code.eq(code_clone)) 1125 + .filter(AuthCodeSchema::client_id.eq(client_id_clone)) 1126 + .filter(AuthCodeSchema::redirect_uri.eq(redirect_uri_clone)) 1127 + .filter(AuthCodeSchema::expires_at.gt(timestamp)) 1128 + .filter(AuthCodeSchema::used.eq(false)) 1129 + .first::<AuthCode>(conn) 1130 + .optional() 1131 + }) 1132 + .await 1133 + .expect("Failed to query authorization code") 1134 + .expect("Failed to query authorization code") 1135 + .expect("Failed to query authorization code"); 1007 1136 1008 1137 // Verify PKCE code challenge 1009 1138 verify_pkce( ··· 1013 1142 )?; 1014 1143 1015 1144 // Mark the code as used 1016 - _ = sqlx::query!( 1017 - r#"UPDATE oauth_authorization_codes SET used = TRUE WHERE code = ?"#, 1018 - code 1019 - ) 1020 - .execute(&db) 1021 - .await 1022 - .context("failed to mark code as used")?; 1145 + let code_cloned = code.to_owned(); 1146 + _ = db 1147 + .get() 1148 + .await 1149 + .expect("Failed to get database connection") 1150 + .interact(move |conn| { 1151 + update(AuthCodeSchema::oauth_authorization_codes) 1152 + .filter(AuthCodeSchema::code.eq(code_cloned)) 1153 + .set(AuthCodeSchema::used.eq(true)) 1154 + .execute(conn) 1155 + }) 1156 + .await 1157 + .expect("Failed to mark code as used") 1158 + .expect("Failed to mark code as used"); 1023 1159 1024 1160 // Generate tokens with appropriate lifetimes 1025 1161 let now = chrono::Utc::now().timestamp(); ··· 1043 1179 "exp": access_token_expires_at, 1044 1180 "iat": now, 1045 1181 "cnf": { 1046 - "jkt": dpop_thumbprint // Rule 1: Bind to DPoP key 1182 + "jkt": dpop_thumbprint_res // Rule 1: Bind to DPoP key 1047 1183 }, 1048 1184 "scope": auth_code.scope 1049 1185 }); ··· 1059 1195 "exp": refresh_token_expires_at, 1060 1196 "iat": now, 1061 1197 "cnf": { 1062 - "jkt": dpop_thumbprint // Rule 1: Bind to DPoP key 1198 + "jkt": dpop_thumbprint_res // Rule 1: Bind to DPoP key 1063 1199 }, 1064 1200 "scope": auth_code.scope 1065 1201 }); ··· 1068 1204 .context("failed to sign refresh token")?; 1069 1205 1070 1206 // Store the refresh token with DPoP binding 1071 - _ = sqlx::query!( 1072 - r#" 1073 - INSERT INTO oauth_refresh_tokens ( 1074 - token, client_id, subject, dpop_thumbprint, scope, created_at, expires_at, revoked 1075 - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 1076 - "#, 1077 - refresh_token, 1078 - client_id, 1079 - auth_code.subject, 1080 - dpop_thumbprint, 1081 - auth_code.scope, 1082 - now, 1083 - refresh_token_expires_at, 1084 - false 1085 - ) 1086 - .execute(&db) 1087 - .await 1088 - .context("failed to store refresh token")?; 1207 + use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1208 + let refresh_token_cloned = refresh_token.to_owned(); 1209 + let client_id_cloned = client_id.to_owned(); 1210 + let subject = auth_code.subject.to_owned(); 1211 + let dpop_thumbprint_cloned = dpop_thumbprint_res.to_owned(); 1212 + let scope = auth_code.scope.to_owned(); 1213 + let created_at = now; 1214 + let expires_at = refresh_token_expires_at; 1215 + _ = db 1216 + .get() 1217 + .await 1218 + .expect("Failed to get database connection") 1219 + .interact(move |conn| { 1220 + insert_into(RefreshTokenSchema::oauth_refresh_tokens) 1221 + .values(( 1222 + RefreshTokenSchema::token.eq(refresh_token_cloned), 1223 + RefreshTokenSchema::client_id.eq(client_id_cloned), 1224 + RefreshTokenSchema::subject.eq(subject), 1225 + RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_cloned), 1226 + RefreshTokenSchema::scope.eq(scope), 1227 + RefreshTokenSchema::created_at.eq(created_at), 1228 + RefreshTokenSchema::expires_at.eq(expires_at), 1229 + RefreshTokenSchema::revoked.eq(false), 1230 + )) 1231 + .execute(conn) 1232 + }) 1233 + .await 1234 + .expect("Failed to store refresh token") 1235 + .expect("Failed to store refresh token"); 1089 1236 1090 1237 // Return token response with the subject claim 1091 1238 Ok(Json(json!({ ··· 1107 1254 1108 1255 // Rules 7 & 8: Verify refresh token and DPoP consistency 1109 1256 // Retrieve the refresh token 1110 - let token_data = sqlx::query!( 1111 - r#" 1112 - SELECT * FROM oauth_refresh_tokens 1113 - WHERE token = ? AND client_id = ? AND expires_at > ? AND revoked = FALSE AND dpop_thumbprint = ? 1114 - "#, 1115 - refresh_token, 1116 - client_id, 1117 - timestamp, 1118 - dpop_thumbprint // Rule 8: Must use same DPoP key 1119 - ) 1120 - .fetch_optional(&db) 1121 - .await 1122 - .context("failed to query refresh token")? 1123 - .context("refresh token not found, expired, revoked, or invalid for this DPoP key")?; 1257 + use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1258 + #[derive(Queryable, Selectable, Serialize)] 1259 + #[diesel(table_name = crate::schema::pds::oauth_refresh_tokens)] 1260 + #[diesel(check_for_backend(sqlite::Sqlite))] 1261 + struct TokenData { 1262 + token: String, 1263 + client_id: String, 1264 + subject: String, 1265 + dpop_thumbprint: String, 1266 + scope: Option<String>, 1267 + created_at: i64, 1268 + expires_at: i64, 1269 + revoked: bool, 1270 + } 1271 + let dpop_thumbprint_clone = dpop_thumbprint_res.to_owned(); 1272 + let refresh_token_clone = refresh_token.to_owned(); 1273 + let client_id_clone = client_id.to_owned(); 1274 + let token_data = db 1275 + .get() 1276 + .await 1277 + .expect("Failed to get database connection") 1278 + .interact(move |conn| { 1279 + RefreshTokenSchema::oauth_refresh_tokens 1280 + .filter(RefreshTokenSchema::token.eq(refresh_token_clone)) 1281 + .filter(RefreshTokenSchema::client_id.eq(client_id_clone)) 1282 + .filter(RefreshTokenSchema::expires_at.gt(timestamp)) 1283 + .filter(RefreshTokenSchema::revoked.eq(false)) 1284 + .filter(RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_clone)) 1285 + .first::<TokenData>(conn) 1286 + .optional() 1287 + }) 1288 + .await 1289 + .expect("Failed to query refresh token") 1290 + .expect("Failed to query refresh token") 1291 + .expect("Failed to query refresh token"); 1124 1292 1125 1293 // Rule 10: For confidential clients, verify key is still advertised in their jwks 1126 1294 if is_confidential_client { 1127 1295 let client_still_advertises_key = true; // Implement actual check against client jwks 1128 1296 if !client_still_advertises_key { 1129 1297 // Revoke all tokens bound to this key 1130 - _ = sqlx::query!( 1131 - r#"UPDATE oauth_refresh_tokens SET revoked = TRUE 1132 - WHERE client_id = ? AND dpop_thumbprint = ?"#, 1133 - client_id, 1134 - dpop_thumbprint 1135 - ) 1136 - .execute(&db) 1137 - .await 1138 - .context("failed to revoke tokens")?; 1298 + let client_id_cloned = client_id.to_owned(); 1299 + let dpop_thumbprint_cloned = dpop_thumbprint_res.to_owned(); 1300 + _ = db 1301 + .get() 1302 + .await 1303 + .expect("Failed to get database connection") 1304 + .interact(move |conn| { 1305 + update(RefreshTokenSchema::oauth_refresh_tokens) 1306 + .filter(RefreshTokenSchema::client_id.eq(client_id_cloned)) 1307 + .filter( 1308 + RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_cloned), 1309 + ) 1310 + .set(RefreshTokenSchema::revoked.eq(true)) 1311 + .execute(conn) 1312 + }) 1313 + .await 1314 + .expect("Failed to revoke tokens") 1315 + .expect("Failed to revoke tokens"); 1139 1316 1140 1317 return Err(Error::with_status( 1141 1318 StatusCode::BAD_REQUEST, ··· 1145 1322 } 1146 1323 1147 1324 // Rotate the refresh token 1148 - _ = sqlx::query!( 1149 - r#"UPDATE oauth_refresh_tokens SET revoked = TRUE WHERE token = ?"#, 1150 - refresh_token 1151 - ) 1152 - .execute(&db) 1153 - .await 1154 - .context("failed to revoke old refresh token")?; 1325 + let refresh_token_cloned = refresh_token.to_owned(); 1326 + _ = db 1327 + .get() 1328 + .await 1329 + .expect("Failed to get database connection") 1330 + .interact(move |conn| { 1331 + update(RefreshTokenSchema::oauth_refresh_tokens) 1332 + .filter(RefreshTokenSchema::token.eq(refresh_token_cloned)) 1333 + .set(RefreshTokenSchema::revoked.eq(true)) 1334 + .execute(conn) 1335 + }) 1336 + .await 1337 + .expect("Failed to revoke old refresh token") 1338 + .expect("Failed to revoke old refresh token"); 1155 1339 1156 1340 // Generate new tokens 1157 1341 let now = chrono::Utc::now().timestamp(); ··· 1170 1354 "exp": access_token_expires_at, 1171 1355 "iat": now, 1172 1356 "cnf": { 1173 - "jkt": dpop_thumbprint 1357 + "jkt": dpop_thumbprint_res 1174 1358 }, 1175 1359 "scope": token_data.scope 1176 1360 }); ··· 1186 1370 "exp": refresh_token_expires_at, 1187 1371 "iat": now, 1188 1372 "cnf": { 1189 - "jkt": dpop_thumbprint 1373 + "jkt": dpop_thumbprint_res 1190 1374 }, 1191 1375 "scope": token_data.scope 1192 1376 }); ··· 1195 1379 .context("failed to sign refresh token")?; 1196 1380 1197 1381 // Store the new refresh token 1198 - _ = sqlx::query!( 1199 - r#" 1200 - INSERT INTO oauth_refresh_tokens ( 1201 - token, client_id, subject, dpop_thumbprint, scope, created_at, expires_at, revoked 1202 - ) VALUES (?, ?, ?, ?, ?, ?, ?, ?) 1203 - "#, 1204 - new_refresh_token, 1205 - client_id, 1206 - token_data.subject, 1207 - dpop_thumbprint, 1208 - token_data.scope, 1209 - now, 1210 - refresh_token_expires_at, 1211 - false 1212 - ) 1213 - .execute(&db) 1214 - .await 1215 - .context("failed to store refresh token")?; 1382 + let new_refresh_token_cloned = new_refresh_token.to_owned(); 1383 + let client_id_cloned = client_id.to_owned(); 1384 + let subject = token_data.subject.to_owned(); 1385 + let dpop_thumbprint_cloned = dpop_thumbprint_res.to_owned(); 1386 + let scope = token_data.scope.to_owned(); 1387 + let created_at = now; 1388 + let expires_at = refresh_token_expires_at; 1389 + _ = db 1390 + .get() 1391 + .await 1392 + .expect("Failed to get database connection") 1393 + .interact(move |conn| { 1394 + insert_into(RefreshTokenSchema::oauth_refresh_tokens) 1395 + .values(( 1396 + RefreshTokenSchema::token.eq(new_refresh_token_cloned), 1397 + RefreshTokenSchema::client_id.eq(client_id_cloned), 1398 + RefreshTokenSchema::subject.eq(subject), 1399 + RefreshTokenSchema::dpop_thumbprint.eq(dpop_thumbprint_cloned), 1400 + RefreshTokenSchema::scope.eq(scope), 1401 + RefreshTokenSchema::created_at.eq(created_at), 1402 + RefreshTokenSchema::expires_at.eq(expires_at), 1403 + RefreshTokenSchema::revoked.eq(false), 1404 + )) 1405 + .execute(conn) 1406 + }) 1407 + .await 1408 + .expect("Failed to store refresh token") 1409 + .expect("Failed to store refresh token"); 1216 1410 1217 1411 // Return token response 1218 1412 Ok(Json(json!({ ··· 1289 1483 /// 1290 1484 /// Implements RFC7009 for revoking refresh tokens 1291 1485 async fn revoke( 1292 - State(db): State<Db>, 1486 + State(db): State<Pool>, 1293 1487 Json(form_data): Json<HashMap<String, String>>, 1294 1488 ) -> Result<Json<Value>> { 1295 1489 // Extract required parameters ··· 1308 1502 } 1309 1503 1310 1504 // Revoke the token 1311 - _ = sqlx::query!( 1312 - r#"UPDATE oauth_refresh_tokens SET revoked = TRUE WHERE token = ?"#, 1313 - token 1314 - ) 1315 - .execute(&db) 1316 - .await 1317 - .context("failed to revoke token")?; 1505 + use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1506 + let token_cloned = token.to_owned(); 1507 + _ = db 1508 + .get() 1509 + .await 1510 + .expect("Failed to get database connection") 1511 + .interact(move |conn| { 1512 + update(RefreshTokenSchema::oauth_refresh_tokens) 1513 + .filter(RefreshTokenSchema::token.eq(token_cloned)) 1514 + .set(RefreshTokenSchema::revoked.eq(true)) 1515 + .execute(conn) 1516 + }) 1517 + .await 1518 + .expect("Failed to revoke token") 1519 + .expect("Failed to revoke token"); 1318 1520 1319 1521 // RFC7009 requires a 200 OK with an empty response 1320 1522 Ok(Json(json!({}))) ··· 1325 1527 /// 1326 1528 /// Implements RFC7662 for introspecting tokens 1327 1529 async fn introspect( 1328 - State(db): State<Db>, 1530 + State(db): State<Pool>, 1329 1531 State(skey): State<SigningKey>, 1330 1532 Json(form_data): Json<HashMap<String, String>>, 1331 1533 ) -> Result<Json<Value>> { ··· 1368 1570 1369 1571 // For refresh tokens, check if it's been revoked 1370 1572 if is_refresh_token { 1371 - let is_revoked = sqlx::query_scalar!( 1372 - r#"SELECT revoked FROM oauth_refresh_tokens WHERE token = ?"#, 1373 - token 1374 - ) 1375 - .fetch_optional(&db) 1376 - .await 1377 - .context("failed to query token")? 1378 - .unwrap_or(true); 1573 + use crate::schema::pds::oauth_refresh_tokens::dsl as RefreshTokenSchema; 1574 + let token_cloned = token.to_owned(); 1575 + let is_revoked = db 1576 + .get() 1577 + .await 1578 + .expect("Failed to get database connection") 1579 + .interact(move |conn| { 1580 + RefreshTokenSchema::oauth_refresh_tokens 1581 + .filter(RefreshTokenSchema::token.eq(token_cloned)) 1582 + .select(RefreshTokenSchema::revoked) 1583 + .first::<bool>(conn) 1584 + .optional() 1585 + }) 1586 + .await 1587 + .expect("Failed to query token") 1588 + .expect("Failed to query token") 1589 + .unwrap_or(true); 1379 1590 1380 1591 if is_revoked { 1381 1592 return Ok(Json(json!({"active": false})));
+606
src/pipethrough.rs
··· 1 + //! Based on https://github.com/blacksky-algorithms/rsky/blob/main/rsky-pds/src/pipethrough.rs 2 + //! blacksky-algorithms/rsky is licensed under the Apache License 2.0 3 + //! 4 + //! Modified for Axum instead of Rocket 5 + 6 + use anyhow::{Result, bail}; 7 + use axum::extract::{FromRequestParts, State}; 8 + use rsky_identity::IdResolver; 9 + use rsky_pds::apis::ApiError; 10 + use rsky_pds::auth_verifier::{AccessOutput, AccessStandard}; 11 + use rsky_pds::config::{ServerConfig, ServiceConfig, env_to_cfg}; 12 + use rsky_pds::pipethrough::{OverrideOpts, ProxyHeader, UrlAndAud}; 13 + use rsky_pds::xrpc_server::types::{HandlerPipeThrough, InvalidRequestError, XRPCError}; 14 + use rsky_pds::{APP_USER_AGENT, SharedIdResolver, context}; 15 + // use lazy_static::lazy_static; 16 + use reqwest::header::{CONTENT_TYPE, HeaderValue}; 17 + use reqwest::{Client, Method, RequestBuilder, Response}; 18 + // use rocket::data::ToByteUnit; 19 + // use rocket::http::{Method, Status}; 20 + // use rocket::request::{FromRequest, Outcome, Request}; 21 + // use rocket::{Data, State}; 22 + use axum::{ 23 + body::Bytes, 24 + http::{self, HeaderMap}, 25 + }; 26 + use rsky_common::{GetServiceEndpointOpts, get_service_endpoint}; 27 + use rsky_repo::types::Ids; 28 + use serde::de::DeserializeOwned; 29 + use serde_json::Value as JsonValue; 30 + use std::collections::{BTreeMap, HashSet}; 31 + use std::str::FromStr; 32 + use std::sync::Arc; 33 + use std::time::Duration; 34 + use ubyte::ToByteUnit as _; 35 + use url::Url; 36 + 37 + use crate::serve::AppState; 38 + 39 + // pub struct OverrideOpts { 40 + // pub aud: Option<String>, 41 + // pub lxm: Option<String>, 42 + // } 43 + 44 + // pub struct UrlAndAud { 45 + // pub url: Url, 46 + // pub aud: String, 47 + // pub lxm: String, 48 + // } 49 + 50 + // pub struct ProxyHeader { 51 + // pub did: String, 52 + // pub service_url: String, 53 + // } 54 + 55 + pub struct ProxyRequest { 56 + pub headers: BTreeMap<String, String>, 57 + pub query: Option<String>, 58 + pub path: String, 59 + pub method: Method, 60 + pub id_resolver: Arc<tokio::sync::RwLock<rsky_identity::IdResolver>>, 61 + pub cfg: ServerConfig, 62 + } 63 + impl FromRequestParts<AppState> for ProxyRequest { 64 + // type Rejection = ApiError; 65 + type Rejection = axum::response::Response; 66 + 67 + async fn from_request_parts( 68 + parts: &mut axum::http::request::Parts, 69 + state: &AppState, 70 + ) -> Result<Self, Self::Rejection> { 71 + let headers = parts 72 + .headers 73 + .iter() 74 + .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string())) 75 + .collect::<BTreeMap<String, String>>(); 76 + let query = parts.uri.query().map(|s| s.to_string()); 77 + let path = parts.uri.path().to_string(); 78 + let method = parts.method.clone(); 79 + let id_resolver = state.id_resolver.clone(); 80 + // let cfg = state.cfg.clone(); 81 + let cfg = env_to_cfg(); // TODO: use state.cfg.clone(); 82 + 83 + Ok(Self { 84 + headers, 85 + query, 86 + path, 87 + method, 88 + id_resolver, 89 + cfg, 90 + }) 91 + } 92 + } 93 + 94 + // #[rocket::async_trait] 95 + // impl<'r> FromRequest<'r> for HandlerPipeThrough { 96 + // type Error = anyhow::Error; 97 + 98 + // #[tracing::instrument(skip_all)] 99 + // async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> { 100 + // match AccessStandard::from_request(req).await { 101 + // Outcome::Success(output) => { 102 + // let AccessOutput { credentials, .. } = output.access; 103 + // let requester: Option<String> = match credentials { 104 + // None => None, 105 + // Some(credentials) => credentials.did, 106 + // }; 107 + // let headers = req.headers().clone().into_iter().fold( 108 + // BTreeMap::new(), 109 + // |mut acc: BTreeMap<String, String>, cur| { 110 + // let _ = acc.insert(cur.name().to_string(), cur.value().to_string()); 111 + // acc 112 + // }, 113 + // ); 114 + // let proxy_req = ProxyRequest { 115 + // headers, 116 + // query: match req.uri().query() { 117 + // None => None, 118 + // Some(query) => Some(query.to_string()), 119 + // }, 120 + // path: req.uri().path().to_string(), 121 + // method: req.method(), 122 + // id_resolver: req.guard::<&State<SharedIdResolver>>().await.unwrap(), 123 + // cfg: req.guard::<&State<ServerConfig>>().await.unwrap(), 124 + // }; 125 + // match pipethrough( 126 + // &proxy_req, 127 + // requester, 128 + // OverrideOpts { 129 + // aud: None, 130 + // lxm: None, 131 + // }, 132 + // ) 133 + // .await 134 + // { 135 + // Ok(res) => Outcome::Success(res), 136 + // Err(error) => match error.downcast_ref() { 137 + // Some(InvalidRequestError::XRPCError(xrpc)) => { 138 + // if let XRPCError::FailedResponse { 139 + // status, 140 + // error, 141 + // message, 142 + // headers, 143 + // } = xrpc 144 + // { 145 + // tracing::error!( 146 + // "@LOG: XRPC ERROR Status:{status}; Message: {message:?}; Error: {error:?}; Headers: {headers:?}" 147 + // ); 148 + // } 149 + // req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string()))); 150 + // Outcome::Error((Status::BadRequest, error)) 151 + // } 152 + // _ => { 153 + // req.local_cache(|| Some(ApiError::InvalidRequest(error.to_string()))); 154 + // Outcome::Error((Status::BadRequest, error)) 155 + // } 156 + // }, 157 + // } 158 + // } 159 + // Outcome::Error(err) => { 160 + // req.local_cache(|| Some(ApiError::RuntimeError)); 161 + // Outcome::Error(( 162 + // Status::BadRequest, 163 + // anyhow::Error::new(InvalidRequestError::AuthError(err.1)), 164 + // )) 165 + // } 166 + // _ => panic!("Unexpected outcome during Pipethrough"), 167 + // } 168 + // } 169 + // } 170 + 171 + // #[rocket::async_trait] 172 + // impl<'r> FromRequest<'r> for ProxyRequest<'r> { 173 + // type Error = anyhow::Error; 174 + 175 + // async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> { 176 + // let headers = req.headers().clone().into_iter().fold( 177 + // BTreeMap::new(), 178 + // |mut acc: BTreeMap<String, String>, cur| { 179 + // let _ = acc.insert(cur.name().to_string(), cur.value().to_string()); 180 + // acc 181 + // }, 182 + // ); 183 + // Outcome::Success(Self { 184 + // headers, 185 + // query: match req.uri().query() { 186 + // None => None, 187 + // Some(query) => Some(query.to_string()), 188 + // }, 189 + // path: req.uri().path().to_string(), 190 + // method: req.method(), 191 + // id_resolver: req.guard::<&State<SharedIdResolver>>().await.unwrap(), 192 + // cfg: req.guard::<&State<ServerConfig>>().await.unwrap(), 193 + // }) 194 + // } 195 + // } 196 + 197 + pub async fn pipethrough( 198 + req: &ProxyRequest, 199 + requester: Option<String>, 200 + override_opts: OverrideOpts, 201 + ) -> Result<HandlerPipeThrough> { 202 + let UrlAndAud { 203 + url, 204 + aud, 205 + lxm: nsid, 206 + } = format_url_and_aud(req, override_opts.aud).await?; 207 + let lxm = override_opts.lxm.unwrap_or(nsid); 208 + let headers = format_headers(req, aud, lxm, requester).await?; 209 + let req_init = format_req_init(req, url, headers, None)?; 210 + let res = make_request(req_init).await?; 211 + parse_proxy_res(res).await 212 + } 213 + 214 + pub async fn pipethrough_procedure<T: serde::Serialize>( 215 + req: &ProxyRequest, 216 + requester: Option<String>, 217 + body: Option<T>, 218 + ) -> Result<HandlerPipeThrough> { 219 + let UrlAndAud { 220 + url, 221 + aud, 222 + lxm: nsid, 223 + } = format_url_and_aud(req, None).await?; 224 + let headers = format_headers(req, aud, nsid, requester).await?; 225 + let encoded_body: Option<Vec<u8>> = match body { 226 + None => None, 227 + Some(body) => Some(serde_json::to_string(&body)?.into_bytes()), 228 + }; 229 + let req_init = format_req_init(req, url, headers, encoded_body)?; 230 + let res = make_request(req_init).await?; 231 + parse_proxy_res(res).await 232 + } 233 + 234 + #[tracing::instrument(skip_all)] 235 + pub async fn pipethrough_procedure_post( 236 + req: &ProxyRequest, 237 + requester: Option<String>, 238 + body: Option<Bytes>, 239 + ) -> Result<HandlerPipeThrough, ApiError> { 240 + let UrlAndAud { 241 + url, 242 + aud, 243 + lxm: nsid, 244 + } = format_url_and_aud(req, None).await?; 245 + let headers = format_headers(req, aud, nsid, requester).await?; 246 + let encoded_body: Option<JsonValue>; 247 + match body { 248 + None => encoded_body = None, 249 + Some(body) => { 250 + // let res = match body.open(50.megabytes()).into_string().await { 251 + // Ok(res1) => { 252 + // tracing::info!(res1.value); 253 + // res1.value 254 + // } 255 + // Err(error) => { 256 + // tracing::error!("{error}"); 257 + // return Err(ApiError::RuntimeError); 258 + // } 259 + // }; 260 + let res = String::from_utf8(body.to_vec()).expect("Invalid UTF-8"); 261 + 262 + match serde_json::from_str(res.as_str()) { 263 + Ok(res) => { 264 + encoded_body = Some(res); 265 + } 266 + Err(error) => { 267 + tracing::error!("{error}"); 268 + return Err(ApiError::RuntimeError); 269 + } 270 + } 271 + } 272 + }; 273 + let req_init = format_req_init_with_value(req, url, headers, encoded_body)?; 274 + let res = make_request(req_init).await?; 275 + Ok(parse_proxy_res(res).await?) 276 + } 277 + 278 + // Request setup/formatting 279 + // ------------------- 280 + 281 + const REQ_HEADERS_TO_FORWARD: [&str; 4] = [ 282 + "accept-language", 283 + "content-type", 284 + "atproto-accept-labelers", 285 + "x-bsky-topics", 286 + ]; 287 + 288 + #[tracing::instrument(skip_all)] 289 + pub async fn format_url_and_aud( 290 + req: &ProxyRequest, 291 + aud_override: Option<String>, 292 + ) -> Result<UrlAndAud> { 293 + let proxy_to = parse_proxy_header(req).await?; 294 + let nsid = parse_req_nsid(req); 295 + let default_proxy = default_service(req, &nsid).await; 296 + let service_url = match proxy_to { 297 + Some(ref proxy_to) => { 298 + tracing::info!( 299 + "@LOG: format_url_and_aud() proxy_to: {:?}", 300 + proxy_to.service_url 301 + ); 302 + Some(proxy_to.service_url.clone()) 303 + } 304 + None => match default_proxy { 305 + Some(ref default_proxy) => Some(default_proxy.url.clone()), 306 + None => None, 307 + }, 308 + }; 309 + let aud = match aud_override { 310 + Some(_) => aud_override, 311 + None => match proxy_to { 312 + Some(proxy_to) => Some(proxy_to.did), 313 + None => match default_proxy { 314 + Some(default_proxy) => Some(default_proxy.did), 315 + None => None, 316 + }, 317 + }, 318 + }; 319 + match (service_url, aud) { 320 + (Some(service_url), Some(aud)) => { 321 + let mut url = Url::parse(format!("{0}{1}", service_url, req.path).as_str())?; 322 + if let Some(ref params) = req.query { 323 + url.set_query(Some(params.as_str())); 324 + } 325 + if !req.cfg.service.dev_mode && !is_safe_url(url.clone()) { 326 + bail!(InvalidRequestError::InvalidServiceUrl(url.to_string())); 327 + } 328 + Ok(UrlAndAud { 329 + url, 330 + aud, 331 + lxm: nsid, 332 + }) 333 + } 334 + _ => bail!(InvalidRequestError::NoServiceConfigured(req.path.clone())), 335 + } 336 + } 337 + 338 + pub async fn format_headers( 339 + req: &ProxyRequest, 340 + aud: String, 341 + lxm: String, 342 + requester: Option<String>, 343 + ) -> Result<HeaderMap> { 344 + let mut headers: HeaderMap = match requester { 345 + Some(requester) => context::service_auth_headers(&requester, &aud, &lxm).await?, 346 + None => HeaderMap::new(), 347 + }; 348 + // forward select headers to upstream services 349 + for header in REQ_HEADERS_TO_FORWARD { 350 + let val = req.headers.get(header); 351 + if let Some(val) = val { 352 + headers.insert(header, HeaderValue::from_str(val)?); 353 + } 354 + } 355 + Ok(headers) 356 + } 357 + 358 + pub fn format_req_init( 359 + req: &ProxyRequest, 360 + url: Url, 361 + headers: HeaderMap, 362 + body: Option<Vec<u8>>, 363 + ) -> Result<RequestBuilder> { 364 + match req.method { 365 + Method::GET => { 366 + let client = Client::builder() 367 + .user_agent(APP_USER_AGENT) 368 + .http2_keep_alive_while_idle(true) 369 + .http2_keep_alive_timeout(Duration::from_secs(5)) 370 + .default_headers(headers) 371 + .build()?; 372 + Ok(client.get(url)) 373 + } 374 + Method::HEAD => { 375 + let client = Client::builder() 376 + .user_agent(APP_USER_AGENT) 377 + .http2_keep_alive_while_idle(true) 378 + .http2_keep_alive_timeout(Duration::from_secs(5)) 379 + .default_headers(headers) 380 + .build()?; 381 + Ok(client.head(url)) 382 + } 383 + Method::POST => { 384 + let client = Client::builder() 385 + .user_agent(APP_USER_AGENT) 386 + .http2_keep_alive_while_idle(true) 387 + .http2_keep_alive_timeout(Duration::from_secs(5)) 388 + .default_headers(headers) 389 + .build()?; 390 + Ok(client.post(url).body(body.unwrap())) 391 + } 392 + _ => bail!(InvalidRequestError::MethodNotFound), 393 + } 394 + } 395 + 396 + pub fn format_req_init_with_value( 397 + req: &ProxyRequest, 398 + url: Url, 399 + headers: HeaderMap, 400 + body: Option<JsonValue>, 401 + ) -> Result<RequestBuilder> { 402 + match req.method { 403 + Method::GET => { 404 + let client = Client::builder() 405 + .user_agent(APP_USER_AGENT) 406 + .http2_keep_alive_while_idle(true) 407 + .http2_keep_alive_timeout(Duration::from_secs(5)) 408 + .default_headers(headers) 409 + .build()?; 410 + Ok(client.get(url)) 411 + } 412 + Method::HEAD => { 413 + let client = Client::builder() 414 + .user_agent(APP_USER_AGENT) 415 + .http2_keep_alive_while_idle(true) 416 + .http2_keep_alive_timeout(Duration::from_secs(5)) 417 + .default_headers(headers) 418 + .build()?; 419 + Ok(client.head(url)) 420 + } 421 + Method::POST => { 422 + let client = Client::builder() 423 + .user_agent(APP_USER_AGENT) 424 + .http2_keep_alive_while_idle(true) 425 + .http2_keep_alive_timeout(Duration::from_secs(5)) 426 + .default_headers(headers) 427 + .build()?; 428 + Ok(client.post(url).json(&body.unwrap())) 429 + } 430 + _ => bail!(InvalidRequestError::MethodNotFound), 431 + } 432 + } 433 + 434 + pub async fn parse_proxy_header(req: &ProxyRequest) -> Result<Option<ProxyHeader>> { 435 + let headers = &req.headers; 436 + let proxy_to: Option<&String> = headers.get("atproto-proxy"); 437 + match proxy_to { 438 + None => Ok(None), 439 + Some(proxy_to) => { 440 + let parts: Vec<&str> = proxy_to.split("#").collect::<Vec<&str>>(); 441 + match (parts.get(0), parts.get(1), parts.get(2)) { 442 + (Some(did), Some(service_id), None) => { 443 + let did = did.to_string(); 444 + let mut lock = req.id_resolver.write().await; 445 + match lock.did.resolve(did.clone(), None).await? { 446 + None => bail!(InvalidRequestError::CannotResolveProxyDid), 447 + Some(did_doc) => { 448 + match get_service_endpoint( 449 + did_doc, 450 + GetServiceEndpointOpts { 451 + id: format!("#{service_id}"), 452 + r#type: None, 453 + }, 454 + ) { 455 + None => bail!(InvalidRequestError::CannotResolveServiceUrl), 456 + Some(service_url) => Ok(Some(ProxyHeader { did, service_url })), 457 + } 458 + } 459 + } 460 + } 461 + (_, None, _) => bail!(InvalidRequestError::NoServiceId), 462 + _ => bail!("error parsing atproto-proxy header"), 463 + } 464 + } 465 + } 466 + } 467 + 468 + pub fn parse_req_nsid(req: &ProxyRequest) -> String { 469 + let nsid = req.path.as_str().replace("/xrpc/", ""); 470 + match nsid.ends_with("/") { 471 + false => nsid, 472 + true => nsid 473 + .trim_end_matches(|c| c == nsid.chars().last().unwrap()) 474 + .to_string(), 475 + } 476 + } 477 + 478 + // Sending request 479 + // ------------------- 480 + #[tracing::instrument(skip_all)] 481 + pub async fn make_request(req_init: RequestBuilder) -> Result<Response> { 482 + let res = req_init.send().await; 483 + match res { 484 + Err(e) => { 485 + tracing::error!("@LOG WARN: pipethrough network error {}", e.to_string()); 486 + bail!(InvalidRequestError::XRPCError(XRPCError::UpstreamFailure)) 487 + } 488 + Ok(res) => match res.error_for_status_ref() { 489 + Ok(_) => Ok(res), 490 + Err(_) => { 491 + let status = res.status().to_string(); 492 + let headers = res.headers().clone(); 493 + let error_body = res.json::<JsonValue>().await?; 494 + bail!(InvalidRequestError::XRPCError(XRPCError::FailedResponse { 495 + status, 496 + headers, 497 + error: match error_body["error"].as_str() { 498 + None => None, 499 + Some(error_body_error) => Some(error_body_error.to_string()), 500 + }, 501 + message: match error_body["message"].as_str() { 502 + None => None, 503 + Some(error_body_message) => Some(error_body_message.to_string()), 504 + } 505 + })) 506 + } 507 + }, 508 + } 509 + } 510 + 511 + // Response parsing/forwarding 512 + // ------------------- 513 + 514 + const RES_HEADERS_TO_FORWARD: [&str; 4] = [ 515 + "content-type", 516 + "content-language", 517 + "atproto-repo-rev", 518 + "atproto-content-labelers", 519 + ]; 520 + 521 + pub async fn parse_proxy_res(res: Response) -> Result<HandlerPipeThrough> { 522 + let encoding = match res.headers().get(CONTENT_TYPE) { 523 + Some(content_type) => content_type.to_str()?, 524 + None => "application/json", 525 + }; 526 + // Release borrow 527 + let encoding = encoding.to_string(); 528 + let res_headers = RES_HEADERS_TO_FORWARD.into_iter().fold( 529 + BTreeMap::new(), 530 + |mut acc: BTreeMap<String, String>, cur| { 531 + let _ = match res.headers().get(cur) { 532 + Some(res_header_val) => acc.insert( 533 + cur.to_string(), 534 + res_header_val.clone().to_str().unwrap().to_string(), 535 + ), 536 + None => None, 537 + }; 538 + acc 539 + }, 540 + ); 541 + let buffer = read_array_buffer_res(res).await?; 542 + Ok(HandlerPipeThrough { 543 + encoding, 544 + buffer, 545 + headers: Some(res_headers), 546 + }) 547 + } 548 + 549 + // Utils 550 + // ------------------- 551 + 552 + pub async fn default_service(req: &ProxyRequest, nsid: &str) -> Option<ServiceConfig> { 553 + let cfg = req.cfg.clone(); 554 + match Ids::from_str(nsid) { 555 + Ok(Ids::ToolsOzoneTeamAddMember) => cfg.mod_service, 556 + Ok(Ids::ToolsOzoneTeamDeleteMember) => cfg.mod_service, 557 + Ok(Ids::ToolsOzoneTeamUpdateMember) => cfg.mod_service, 558 + Ok(Ids::ToolsOzoneTeamListMembers) => cfg.mod_service, 559 + Ok(Ids::ToolsOzoneCommunicationCreateTemplate) => cfg.mod_service, 560 + Ok(Ids::ToolsOzoneCommunicationDeleteTemplate) => cfg.mod_service, 561 + Ok(Ids::ToolsOzoneCommunicationUpdateTemplate) => cfg.mod_service, 562 + Ok(Ids::ToolsOzoneCommunicationListTemplates) => cfg.mod_service, 563 + Ok(Ids::ToolsOzoneModerationEmitEvent) => cfg.mod_service, 564 + Ok(Ids::ToolsOzoneModerationGetEvent) => cfg.mod_service, 565 + Ok(Ids::ToolsOzoneModerationGetRecord) => cfg.mod_service, 566 + Ok(Ids::ToolsOzoneModerationGetRepo) => cfg.mod_service, 567 + Ok(Ids::ToolsOzoneModerationQueryEvents) => cfg.mod_service, 568 + Ok(Ids::ToolsOzoneModerationQueryStatuses) => cfg.mod_service, 569 + Ok(Ids::ToolsOzoneModerationSearchRepos) => cfg.mod_service, 570 + Ok(Ids::ComAtprotoModerationCreateReport) => cfg.report_service, 571 + _ => cfg.bsky_app_view, 572 + } 573 + } 574 + 575 + pub fn parse_res<T: DeserializeOwned>(_nsid: String, res: HandlerPipeThrough) -> Result<T> { 576 + let buffer = res.buffer; 577 + let record = serde_json::from_slice::<T>(buffer.as_slice())?; 578 + Ok(record) 579 + } 580 + 581 + #[tracing::instrument(skip_all)] 582 + pub async fn read_array_buffer_res(res: Response) -> Result<Vec<u8>> { 583 + match res.bytes().await { 584 + Ok(bytes) => Ok(bytes.to_vec()), 585 + Err(err) => { 586 + tracing::error!("@LOG WARN: pipethrough network error {}", err.to_string()); 587 + bail!("UpstreamFailure") 588 + } 589 + } 590 + } 591 + 592 + pub fn is_safe_url(url: Url) -> bool { 593 + if url.scheme() != "https" { 594 + return false; 595 + } 596 + match url.host_str() { 597 + None => false, 598 + Some(hostname) if hostname == "localhost" => false, 599 + Some(hostname) => { 600 + if std::net::IpAddr::from_str(hostname).is_ok() { 601 + return false; 602 + } 603 + true 604 + } 605 + } 606 + }
-114
src/plc.rs
··· 1 - //! PLC operations. 2 - use std::collections::HashMap; 3 - 4 - use anyhow::{Context as _, bail}; 5 - use base64::Engine as _; 6 - use serde::{Deserialize, Serialize}; 7 - use tracing::debug; 8 - 9 - use crate::{Client, RotationKey}; 10 - 11 - /// The URL of the public PLC directory. 12 - const PLC_DIRECTORY: &str = "https://plc.directory/"; 13 - 14 - #[derive(Debug, Deserialize, Serialize, Clone)] 15 - #[serde(rename_all = "camelCase", tag = "type")] 16 - /// A PLC service. 17 - pub(crate) enum PlcService { 18 - #[serde(rename = "AtprotoPersonalDataServer")] 19 - /// A personal data server. 20 - Pds { 21 - /// The URL of the PDS. 22 - endpoint: String, 23 - }, 24 - } 25 - 26 - #[expect( 27 - clippy::arbitrary_source_item_ordering, 28 - reason = "serialized data might be structured" 29 - )] 30 - #[derive(Debug, Deserialize, Serialize, Clone)] 31 - #[serde(rename_all = "camelCase")] 32 - pub(crate) struct PlcOperation { 33 - #[serde(rename = "type")] 34 - pub typ: String, 35 - pub rotation_keys: Vec<String>, 36 - pub verification_methods: HashMap<String, String>, 37 - pub also_known_as: Vec<String>, 38 - pub services: HashMap<String, PlcService>, 39 - pub prev: Option<String>, 40 - } 41 - 42 - impl PlcOperation { 43 - /// Sign an operation with the provided signature. 44 - pub(crate) fn sign(self, sig: Vec<u8>) -> SignedPlcOperation { 45 - SignedPlcOperation { 46 - typ: self.typ, 47 - rotation_keys: self.rotation_keys, 48 - verification_methods: self.verification_methods, 49 - also_known_as: self.also_known_as, 50 - services: self.services, 51 - prev: self.prev, 52 - sig: base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(sig), 53 - } 54 - } 55 - } 56 - 57 - #[expect( 58 - clippy::arbitrary_source_item_ordering, 59 - reason = "serialized data might be structured" 60 - )] 61 - #[derive(Debug, Deserialize, Serialize, Clone)] 62 - #[serde(rename_all = "camelCase")] 63 - /// A signed PLC operation. 64 - pub(crate) struct SignedPlcOperation { 65 - #[serde(rename = "type")] 66 - pub typ: String, 67 - pub rotation_keys: Vec<String>, 68 - pub verification_methods: HashMap<String, String>, 69 - pub also_known_as: Vec<String>, 70 - pub services: HashMap<String, PlcService>, 71 - pub prev: Option<String>, 72 - pub sig: String, 73 - } 74 - 75 - pub(crate) fn sign_op(rkey: &RotationKey, op: PlcOperation) -> anyhow::Result<SignedPlcOperation> { 76 - let bytes = serde_ipld_dagcbor::to_vec(&op).context("failed to encode op")?; 77 - let bytes = rkey.sign(&bytes).context("failed to sign op")?; 78 - 79 - Ok(op.sign(bytes)) 80 - } 81 - 82 - /// Submit a PLC operation to the public directory. 83 - pub(crate) async fn submit( 84 - client: &Client, 85 - did: &str, 86 - op: &SignedPlcOperation, 87 - ) -> anyhow::Result<()> { 88 - debug!( 89 - "submitting {} {}", 90 - did, 91 - serde_json::to_string(&op).context("should serialize")? 92 - ); 93 - 94 - let res = client 95 - .post(format!("{PLC_DIRECTORY}{did}")) 96 - .json(&op) 97 - .send() 98 - .await 99 - .context("failed to send directory request")?; 100 - 101 - if res.status().is_success() { 102 - Ok(()) 103 - } else { 104 - let e = res 105 - .json::<serde_json::Value>() 106 - .await 107 - .context("failed to read error response")?; 108 - 109 - bail!( 110 - "error from PLC directory: {}", 111 - serde_json::to_string(&e).context("should serialize")? 112 - ); 113 - } 114 - }
+212 -15
src/schema.rs
··· 1 + #![allow(unnameable_types, unused_qualifications)] 1 2 pub mod pds { 2 3 3 4 // Legacy tables 4 5 5 6 diesel::table! { 6 - pds.oauth_par_requests (request_uri) { 7 + oauth_par_requests (request_uri) { 7 8 request_uri -> Varchar, 8 9 client_id -> Varchar, 9 10 response_type -> Varchar, ··· 20 21 } 21 22 } 22 23 diesel::table! { 23 - pds.oauth_authorization_codes (code) { 24 + oauth_authorization_codes (code) { 24 25 code -> Varchar, 25 26 client_id -> Varchar, 26 27 subject -> Varchar, ··· 34 35 } 35 36 } 36 37 diesel::table! { 37 - pds.oauth_refresh_tokens (token) { 38 + oauth_refresh_tokens (token) { 38 39 token -> Varchar, 39 40 client_id -> Varchar, 40 41 subject -> Varchar, ··· 46 47 } 47 48 } 48 49 diesel::table! { 49 - pds.oauth_used_jtis (jti) { 50 + oauth_used_jtis (jti) { 50 51 jti -> Varchar, 51 52 issuer -> Varchar, 52 53 created_at -> Int8, ··· 57 58 // Upcoming tables 58 59 59 60 diesel::table! { 60 - pds.authorization_request (id) { 61 + account (did) { 62 + did -> Varchar, 63 + email -> Varchar, 64 + recoveryKey -> Nullable<Varchar>, 65 + password -> Varchar, 66 + createdAt -> Varchar, 67 + invitesDisabled -> Int2, 68 + emailConfirmedAt -> Nullable<Varchar>, 69 + } 70 + } 71 + 72 + diesel::table! { 73 + actor (did) { 74 + did -> Varchar, 75 + handle -> Nullable<Varchar>, 76 + createdAt -> Varchar, 77 + takedownRef -> Nullable<Varchar>, 78 + deactivatedAt -> Nullable<Varchar>, 79 + deleteAfter -> Nullable<Varchar>, 80 + } 81 + } 82 + 83 + diesel::table! { 84 + app_password (did, name) { 85 + did -> Varchar, 86 + name -> Varchar, 87 + password -> Varchar, 88 + createdAt -> Varchar, 89 + } 90 + } 91 + 92 + diesel::table! { 93 + authorization_request (id) { 61 94 id -> Varchar, 62 95 did -> Nullable<Varchar>, 63 96 deviceId -> Nullable<Varchar>, 64 97 clientId -> Varchar, 65 98 clientAuth -> Varchar, 66 99 parameters -> Varchar, 67 - expiresAt -> Timestamptz, 100 + expiresAt -> TimestamptzSqlite, 68 101 code -> Nullable<Varchar>, 69 102 } 70 103 } 71 104 72 105 diesel::table! { 73 - pds.device (id) { 106 + device (id) { 74 107 id -> Varchar, 75 108 sessionId -> Nullable<Varchar>, 76 109 userAgent -> Nullable<Varchar>, 77 110 ipAddress -> Varchar, 78 - lastSeenAt -> Timestamptz, 111 + lastSeenAt -> TimestamptzSqlite, 79 112 } 80 113 } 81 114 82 115 diesel::table! { 83 - pds.device_account (deviceId, did) { 116 + device_account (deviceId, did) { 84 117 did -> Varchar, 85 118 deviceId -> Varchar, 86 - authenticatedAt -> Timestamptz, 119 + authenticatedAt -> TimestamptzSqlite, 87 120 remember -> Bool, 88 121 authorizedClients -> Varchar, 89 122 } 90 123 } 91 124 92 125 diesel::table! { 93 - pds.token (id) { 126 + did_doc (did) { 127 + did -> Varchar, 128 + doc -> Text, 129 + updatedAt -> Int8, 130 + } 131 + } 132 + 133 + diesel::table! { 134 + email_token (purpose, did) { 135 + purpose -> Varchar, 136 + did -> Varchar, 137 + token -> Varchar, 138 + requestedAt -> Varchar, 139 + } 140 + } 141 + 142 + diesel::table! { 143 + invite_code (code) { 144 + code -> Varchar, 145 + availableUses -> Int4, 146 + disabled -> Int2, 147 + forAccount -> Varchar, 148 + createdBy -> Varchar, 149 + createdAt -> Varchar, 150 + } 151 + } 152 + 153 + diesel::table! { 154 + invite_code_use (code, usedBy) { 155 + code -> Varchar, 156 + usedBy -> Varchar, 157 + usedAt -> Varchar, 158 + } 159 + } 160 + 161 + diesel::table! { 162 + refresh_token (id) { 163 + id -> Varchar, 164 + did -> Varchar, 165 + expiresAt -> Varchar, 166 + nextId -> Nullable<Varchar>, 167 + appPasswordName -> Nullable<Varchar>, 168 + } 169 + } 170 + 171 + diesel::table! { 172 + repo_seq (seq) { 173 + seq -> Int8, 174 + did -> Varchar, 175 + eventType -> Varchar, 176 + event -> Bytea, 177 + invalidated -> Int2, 178 + sequencedAt -> Varchar, 179 + } 180 + } 181 + 182 + diesel::table! { 183 + token (id) { 94 184 id -> Varchar, 95 185 did -> Varchar, 96 186 tokenId -> Varchar, 97 - createdAt -> Timestamptz, 98 - updatedAt -> Timestamptz, 99 - expiresAt -> Timestamptz, 187 + createdAt -> TimestamptzSqlite, 188 + updatedAt -> TimestamptzSqlite, 189 + expiresAt -> TimestamptzSqlite, 100 190 clientId -> Varchar, 101 191 clientAuth -> Varchar, 102 192 deviceId -> Nullable<Varchar>, ··· 108 198 } 109 199 110 200 diesel::table! { 111 - pds.used_refresh_token (refreshToken) { 201 + used_refresh_token (refreshToken) { 112 202 refreshToken -> Varchar, 113 203 tokenId -> Varchar, 114 204 } 115 205 } 206 + 207 + diesel::allow_tables_to_appear_in_same_query!( 208 + account, 209 + actor, 210 + app_password, 211 + authorization_request, 212 + device, 213 + device_account, 214 + did_doc, 215 + email_token, 216 + invite_code, 217 + invite_code_use, 218 + refresh_token, 219 + repo_seq, 220 + token, 221 + used_refresh_token, 222 + ); 223 + } 224 + 225 + pub mod actor_store { 226 + // Actor Store 227 + 228 + // Blob 229 + diesel::table! { 230 + blob (cid, did) { 231 + cid -> Varchar, 232 + did -> Varchar, 233 + mimeType -> Varchar, 234 + size -> Int4, 235 + tempKey -> Nullable<Varchar>, 236 + width -> Nullable<Int4>, 237 + height -> Nullable<Int4>, 238 + createdAt -> Varchar, 239 + takedownRef -> Nullable<Varchar>, 240 + } 241 + } 242 + 243 + diesel::table! { 244 + record_blob (blobCid, recordUri) { 245 + blobCid -> Varchar, 246 + recordUri -> Varchar, 247 + did -> Varchar, 248 + } 249 + } 250 + 251 + // Preference 252 + 253 + diesel::table! { 254 + account_pref (id) { 255 + id -> Int4, 256 + did -> Varchar, 257 + name -> Varchar, 258 + valueJson -> Nullable<Text>, 259 + } 260 + } 261 + // Record 262 + 263 + diesel::table! { 264 + record (uri) { 265 + uri -> Varchar, 266 + cid -> Varchar, 267 + did -> Varchar, 268 + collection -> Varchar, 269 + rkey -> Varchar, 270 + repoRev -> Nullable<Varchar>, 271 + indexedAt -> Varchar, 272 + takedownRef -> Nullable<Varchar>, 273 + } 274 + } 275 + 276 + diesel::table! { 277 + repo_block (cid, did) { 278 + cid -> Varchar, 279 + did -> Varchar, 280 + repoRev -> Varchar, 281 + size -> Int4, 282 + content -> Bytea, 283 + } 284 + } 285 + 286 + diesel::table! { 287 + backlink (uri, path) { 288 + uri -> Varchar, 289 + path -> Varchar, 290 + linkTo -> Varchar, 291 + } 292 + } 293 + // sql_repo 294 + 295 + diesel::table! { 296 + repo_root (did) { 297 + did -> Varchar, 298 + cid -> Varchar, 299 + rev -> Varchar, 300 + indexedAt -> Varchar, 301 + } 302 + } 303 + 304 + diesel::allow_tables_to_appear_in_same_query!( 305 + account_pref, 306 + backlink, 307 + blob, 308 + record, 309 + record_blob, 310 + repo_block, 311 + repo_root, 312 + ); 116 313 }
+429
src/serve.rs
··· 1 + use super::account_manager::AccountManager; 2 + use super::config::AppConfig; 3 + use super::db::establish_pool; 4 + pub use super::error::Error; 5 + use super::service_proxy::service_proxy; 6 + use anyhow::Context as _; 7 + use atrium_api::types::string::Did; 8 + use atrium_crypto::keypair::{Export as _, Secp256k1Keypair}; 9 + use axum::{Router, extract::FromRef, routing::get}; 10 + use clap::Parser; 11 + use clap_verbosity_flag::{InfoLevel, Verbosity, log::LevelFilter}; 12 + use deadpool_diesel::sqlite::Pool; 13 + use diesel::prelude::*; 14 + use diesel_migrations::{EmbeddedMigrations, embed_migrations}; 15 + use figment::{Figment, providers::Format as _}; 16 + use http_cache_reqwest::{CacheMode, HttpCacheOptions, MokaManager}; 17 + use rsky_common::env::env_list; 18 + use rsky_identity::IdResolver; 19 + use rsky_identity::types::{DidCache, IdentityResolverOpts}; 20 + use rsky_pds::{crawlers::Crawlers, sequencer::Sequencer}; 21 + use serde::{Deserialize, Serialize}; 22 + use std::env; 23 + use std::{ 24 + net::{IpAddr, Ipv4Addr, SocketAddr}, 25 + path::PathBuf, 26 + str::FromStr as _, 27 + sync::Arc, 28 + }; 29 + use tokio::{net::TcpListener, sync::RwLock}; 30 + use tower_http::{cors::CorsLayer, trace::TraceLayer}; 31 + use tracing::{info, warn}; 32 + use uuid::Uuid; 33 + 34 + /// The application user agent. Concatenates the package name and version. e.g. `bluepds/0.0.0`. 35 + pub const APP_USER_AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"),); 36 + 37 + /// Embedded migrations 38 + pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations"); 39 + pub const MIGRATIONS_ACTOR: EmbeddedMigrations = embed_migrations!("./migrations_actor"); 40 + 41 + /// The application-wide result type. 42 + pub type Result<T> = std::result::Result<T, Error>; 43 + /// The reqwest client type with middleware. 44 + pub type Client = reqwest_middleware::ClientWithMiddleware; 45 + 46 + #[expect( 47 + clippy::arbitrary_source_item_ordering, 48 + reason = "serialized data might be structured" 49 + )] 50 + #[derive(Serialize, Deserialize, Debug, Clone)] 51 + /// The key data structure. 52 + struct KeyData { 53 + /// Primary signing key for all repo operations. 54 + skey: Vec<u8>, 55 + /// Primary signing (rotation) key for all PLC operations. 56 + rkey: Vec<u8>, 57 + } 58 + 59 + // FIXME: We should use P256Keypair instead. SecP256K1 is primarily used for cryptocurrencies, 60 + // and the implementations of this algorithm are much more limited as compared to P256. 61 + // 62 + // Reference: https://soatok.blog/2022/05/19/guidance-for-choosing-an-elliptic-curve-signature-algorithm-in-2022/ 63 + #[derive(Clone)] 64 + /// The signing key for PLC/DID operations. 65 + pub struct SigningKey(Arc<Secp256k1Keypair>); 66 + #[derive(Clone)] 67 + /// The rotation key for PLC operations. 68 + pub struct RotationKey(Arc<Secp256k1Keypair>); 69 + 70 + impl std::ops::Deref for SigningKey { 71 + type Target = Secp256k1Keypair; 72 + 73 + fn deref(&self) -> &Self::Target { 74 + &self.0 75 + } 76 + } 77 + 78 + impl SigningKey { 79 + /// Import from a private key. 80 + pub fn import(key: &[u8]) -> Result<Self> { 81 + let key = Secp256k1Keypair::import(key).context("failed to import signing key")?; 82 + Ok(Self(Arc::new(key))) 83 + } 84 + } 85 + 86 + impl std::ops::Deref for RotationKey { 87 + type Target = Secp256k1Keypair; 88 + 89 + fn deref(&self) -> &Self::Target { 90 + &self.0 91 + } 92 + } 93 + 94 + #[derive(Parser, Debug, Clone)] 95 + /// Command line arguments. 96 + pub struct Args { 97 + /// Path to the configuration file 98 + #[arg(short, long, default_value = "default.toml")] 99 + pub config: PathBuf, 100 + /// The verbosity level. 101 + #[command(flatten)] 102 + pub verbosity: Verbosity<InfoLevel>, 103 + } 104 + 105 + /// The actor pools for the database connections. 106 + pub struct ActorStorage { 107 + /// The database connection pool for the actor's repository. 108 + pub repo: Pool, 109 + /// The file storage path for the actor's blobs. 110 + pub blob: PathBuf, 111 + } 112 + 113 + impl Clone for ActorStorage { 114 + fn clone(&self) -> Self { 115 + Self { 116 + repo: self.repo.clone(), 117 + blob: self.blob.clone(), 118 + } 119 + } 120 + } 121 + 122 + #[expect(clippy::arbitrary_source_item_ordering, reason = "arbitrary")] 123 + #[derive(Clone, FromRef)] 124 + /// The application state, shared across all routes. 125 + pub struct AppState { 126 + /// The application configuration. 127 + pub(crate) config: AppConfig, 128 + /// The main database connection pool. Used for common PDS data, like invite codes. 129 + pub db: Pool, 130 + /// Actor-specific database connection pools. Hashed by DID. 131 + pub db_actors: std::collections::HashMap<String, ActorStorage>, 132 + 133 + /// The HTTP client with middleware. 134 + pub client: Client, 135 + /// The simple HTTP client. 136 + pub simple_client: reqwest::Client, 137 + /// The firehose producer. 138 + pub sequencer: Arc<RwLock<Sequencer>>, 139 + /// The account manager. 140 + pub account_manager: Arc<RwLock<AccountManager>>, 141 + /// The ID resolver. 142 + pub id_resolver: Arc<RwLock<IdResolver>>, 143 + 144 + /// The signing key. 145 + pub signing_key: SigningKey, 146 + /// The rotation key. 147 + pub rotation_key: RotationKey, 148 + } 149 + 150 + /// The main application entry point. 151 + #[expect( 152 + clippy::cognitive_complexity, 153 + clippy::too_many_lines, 154 + unused_qualifications, 155 + reason = "main function has high complexity" 156 + )] 157 + pub async fn run() -> anyhow::Result<()> { 158 + let args = Args::parse(); 159 + 160 + // Set up trace logging to console and account for the user-provided verbosity flag. 161 + if args.verbosity.log_level_filter() != LevelFilter::Off { 162 + let lvl = match args.verbosity.log_level_filter() { 163 + LevelFilter::Error => tracing::Level::ERROR, 164 + LevelFilter::Warn => tracing::Level::WARN, 165 + LevelFilter::Info | LevelFilter::Off => tracing::Level::INFO, 166 + LevelFilter::Debug => tracing::Level::DEBUG, 167 + LevelFilter::Trace => tracing::Level::TRACE, 168 + }; 169 + tracing_subscriber::fmt().with_max_level(lvl).init(); 170 + } 171 + 172 + if !args.config.exists() { 173 + // Throw up a warning if the config file does not exist. 174 + // 175 + // This is not fatal because users can specify all configuration settings via 176 + // the environment, but the most likely scenario here is that a user accidentally 177 + // omitted the config file for some reason (e.g. forgot to mount it into Docker). 178 + warn!( 179 + "configuration file {} does not exist", 180 + args.config.display() 181 + ); 182 + } 183 + 184 + // Read and parse the user-provided configuration. 185 + let config: AppConfig = Figment::new() 186 + .admerge(figment::providers::Toml::file(args.config)) 187 + .admerge(figment::providers::Env::prefixed("BLUEPDS_")) 188 + .extract() 189 + .context("failed to load configuration")?; 190 + 191 + if config.test { 192 + warn!("BluePDS starting up in TEST mode."); 193 + warn!("This means the application will not federate with the rest of the network."); 194 + warn!( 195 + "If you want to turn this off, either set `test` to false in the config or define `BLUEPDS_TEST = false`" 196 + ); 197 + } 198 + 199 + // Initialize metrics reporting. 200 + super::metrics::setup(config.metrics.as_ref()).context("failed to set up metrics exporter")?; 201 + 202 + // Create a reqwest client that will be used for all outbound requests. 203 + let simple_client = reqwest::Client::builder() 204 + .user_agent(APP_USER_AGENT) 205 + .build() 206 + .context("failed to build requester client")?; 207 + let client = reqwest_middleware::ClientBuilder::new(simple_client.clone()) 208 + .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache { 209 + mode: CacheMode::Default, 210 + manager: MokaManager::default(), 211 + options: HttpCacheOptions::default(), 212 + })) 213 + .build(); 214 + 215 + tokio::fs::create_dir_all(&config.key.parent().context("should have parent")?) 216 + .await 217 + .context("failed to create key directory")?; 218 + 219 + // Check if crypto keys exist. If not, create new ones. 220 + let (skey, rkey) = if let Ok(f) = std::fs::File::open(&config.key) { 221 + let keys: KeyData = serde_ipld_dagcbor::from_reader(std::io::BufReader::new(f)) 222 + .context("failed to deserialize crypto keys")?; 223 + 224 + let skey = Secp256k1Keypair::import(&keys.skey).context("failed to import signing key")?; 225 + let rkey = Secp256k1Keypair::import(&keys.rkey).context("failed to import rotation key")?; 226 + 227 + (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 228 + } else { 229 + info!("signing keys not found, generating new ones"); 230 + 231 + let skey = Secp256k1Keypair::create(&mut rand::thread_rng()); 232 + let rkey = Secp256k1Keypair::create(&mut rand::thread_rng()); 233 + 234 + let keys = KeyData { 235 + skey: skey.export(), 236 + rkey: rkey.export(), 237 + }; 238 + 239 + let mut f = std::fs::File::create(&config.key).context("failed to create key file")?; 240 + serde_ipld_dagcbor::to_writer(&mut f, &keys).context("failed to serialize crypto keys")?; 241 + 242 + (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 243 + }; 244 + 245 + tokio::fs::create_dir_all(&config.repo.path).await?; 246 + tokio::fs::create_dir_all(&config.plc.path).await?; 247 + tokio::fs::create_dir_all(&config.blob.path).await?; 248 + 249 + // Create a database connection manager and pool for the main database. 250 + let pool = 251 + establish_pool(&config.db).context("failed to establish database connection pool")?; 252 + 253 + // Create a dictionary of database connection pools for each actor. 254 + let mut actor_pools = std::collections::HashMap::new(); 255 + // We'll determine actors by looking in the data/repo dir for .db files. 256 + let mut actor_dbs = tokio::fs::read_dir(&config.repo.path) 257 + .await 258 + .context("failed to read repo directory")?; 259 + while let Some(entry) = actor_dbs 260 + .next_entry() 261 + .await 262 + .context("failed to read repo dir")? 263 + { 264 + let path = entry.path(); 265 + if path.extension().and_then(|s| s.to_str()) == Some("db") { 266 + let actor_repo_pool = establish_pool(&format!("sqlite://{}", path.display())) 267 + .context("failed to create database connection pool")?; 268 + 269 + let did = Did::from_str(&format!( 270 + "did:plc:{}", 271 + path.file_stem() 272 + .and_then(|s| s.to_str()) 273 + .context("failed to get actor DID")? 274 + )) 275 + .expect("should be able to parse actor DID") 276 + .to_string(); 277 + let blob_path = config.blob.path.to_path_buf(); 278 + let actor_storage = ActorStorage { 279 + repo: actor_repo_pool, 280 + blob: blob_path.clone(), 281 + }; 282 + drop(actor_pools.insert(did, actor_storage)); 283 + } 284 + } 285 + // Apply pending migrations 286 + // let conn = pool.get().await?; 287 + // conn.run_pending_migrations(MIGRATIONS) 288 + // .expect("should be able to run migrations"); 289 + 290 + let hostname = config.host_name.clone(); 291 + let crawlers: Vec<String> = config 292 + .firehose 293 + .relays 294 + .iter() 295 + .map(|s| s.to_string()) 296 + .collect(); 297 + let sequencer = Arc::new(RwLock::new(Sequencer::new( 298 + Crawlers::new(hostname, crawlers.clone()), 299 + None, 300 + ))); 301 + let account_manager = Arc::new(RwLock::new(AccountManager::new(pool.clone()))); 302 + let plc_url = if cfg!(debug_assertions) { 303 + "http://localhost:8000".to_owned() // dummy for debug 304 + } else { 305 + env::var("PDS_DID_PLC_URL").unwrap_or("https://plc.directory".to_owned()) // TODO: toml config 306 + }; 307 + let id_resolver = Arc::new(RwLock::new(IdResolver::new(IdentityResolverOpts { 308 + timeout: None, 309 + plc_url: Some(plc_url), 310 + did_cache: Some(DidCache::new(None, None)), 311 + backup_nameservers: Some(env_list("PDS_HANDLE_BACKUP_NAMESERVERS")), 312 + }))); 313 + 314 + let addr = config 315 + .listen_address 316 + .unwrap_or(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 8000)); 317 + 318 + let app = Router::new() 319 + .route("/", get(super::index)) 320 + .merge(super::oauth::routes()) 321 + .nest( 322 + "/xrpc", 323 + super::apis::routes() 324 + .merge(super::actor_endpoints::routes()) 325 + .fallback(service_proxy), 326 + ) 327 + // .layer(RateLimitLayer::new(30, Duration::from_secs(30))) 328 + .layer(CorsLayer::permissive()) 329 + .layer(TraceLayer::new_for_http()) 330 + .with_state(AppState { 331 + config: config.clone(), 332 + db: pool.clone(), 333 + db_actors: actor_pools.clone(), 334 + client: client.clone(), 335 + simple_client, 336 + sequencer: sequencer.clone(), 337 + account_manager, 338 + id_resolver, 339 + signing_key: skey, 340 + rotation_key: rkey, 341 + }); 342 + 343 + info!("listening on {addr}"); 344 + info!("connect to: http://127.0.0.1:{}", addr.port()); 345 + 346 + // Determine whether or not this was the first startup (i.e. no accounts exist and no invite codes were created). 347 + // If so, create an invite code and share it via the console. 348 + let conn = pool.get().await.context("failed to get db connection")?; 349 + 350 + #[derive(QueryableByName)] 351 + struct TotalCount { 352 + #[diesel(sql_type = diesel::sql_types::Integer)] 353 + total_count: i32, 354 + } 355 + 356 + let result = conn.interact(move |conn| { 357 + diesel::sql_query( 358 + "SELECT (SELECT COUNT(*) FROM account) + (SELECT COUNT(*) FROM invite_code) AS total_count", 359 + ) 360 + .get_result::<TotalCount>(conn) 361 + }) 362 + .await 363 + .expect("should be able to query database")?; 364 + 365 + let c = result.total_count; 366 + 367 + #[expect(clippy::print_stdout)] 368 + if c == 0 { 369 + let uuid = Uuid::new_v4().to_string(); 370 + 371 + use crate::models::pds as models; 372 + use crate::schema::pds::invite_code::dsl as InviteCode; 373 + let uuid_clone = uuid.clone(); 374 + drop( 375 + conn.interact(move |conn| { 376 + diesel::insert_into(InviteCode::invite_code) 377 + .values(models::InviteCode { 378 + code: uuid_clone, 379 + available_uses: 1, 380 + disabled: 0, 381 + for_account: "None".to_owned(), 382 + created_by: "None".to_owned(), 383 + created_at: "None".to_owned(), 384 + }) 385 + .execute(conn) 386 + .context("failed to create new invite code") 387 + }) 388 + .await 389 + .expect("should be able to create invite code"), 390 + ); 391 + 392 + // N.B: This is a sensitive message, so we're bypassing `tracing` here and 393 + // logging it directly to console. 394 + println!("====================================="); 395 + println!(" FIRST STARTUP "); 396 + println!("====================================="); 397 + println!("Use this code to create an account:"); 398 + println!("{uuid}"); 399 + println!("====================================="); 400 + } 401 + 402 + let listener = TcpListener::bind(&addr) 403 + .await 404 + .context("failed to bind address")?; 405 + 406 + // Serve the app, and request crawling from upstream relays. 407 + let serve = tokio::spawn(async move { 408 + axum::serve(listener, app.into_make_service()) 409 + .await 410 + .context("failed to serve app") 411 + }); 412 + 413 + // Now that the app is live, request a crawl from upstream relays. 414 + if cfg!(debug_assertions) { 415 + info!("debug mode: not requesting crawl"); 416 + } else { 417 + info!("requesting crawl from upstream relays"); 418 + let mut background_sequencer = sequencer.write().await.clone(); 419 + drop(tokio::spawn( 420 + async move { background_sequencer.start().await }, 421 + )); 422 + } 423 + 424 + serve 425 + .await 426 + .map_err(Into::into) 427 + .and_then(|r| r) 428 + .context("failed to serve app") 429 + }
+123
src/service_proxy.rs
··· 1 + /// Service proxy. 2 + /// 3 + /// Reference: <https://atproto.com/specs/xrpc#service-proxying> 4 + use anyhow::{Context as _, anyhow}; 5 + use atrium_api::types::string::Did; 6 + use axum::{ 7 + body::Body, 8 + extract::{Request, State}, 9 + http::{self, HeaderMap, Response, StatusCode, Uri}, 10 + }; 11 + use rand::Rng as _; 12 + use std::str::FromStr as _; 13 + 14 + use super::{ 15 + auth::AuthenticatedUser, 16 + serve::{Client, Error, Result, SigningKey}, 17 + }; 18 + 19 + pub(super) async fn service_proxy( 20 + uri: Uri, 21 + user: AuthenticatedUser, 22 + State(skey): State<SigningKey>, 23 + State(client): State<reqwest::Client>, 24 + headers: HeaderMap, 25 + request: Request<Body>, 26 + ) -> Result<Response<Body>> { 27 + let url_path = uri.path_and_query().context("invalid service proxy url")?; 28 + let lxm = url_path 29 + .path() 30 + .strip_prefix("/") 31 + .with_context(|| format!("invalid service proxy url prefix: {}", url_path.path()))?; 32 + 33 + let user_did = user.did(); 34 + let (did, id) = match headers.get("atproto-proxy") { 35 + Some(val) => { 36 + let val = 37 + std::str::from_utf8(val.as_bytes()).context("proxy header not valid utf-8")?; 38 + 39 + let (did, id) = val.split_once('#').context("invalid proxy header")?; 40 + 41 + let did = 42 + Did::from_str(did).map_err(|e| anyhow!("atproto proxy not a valid DID: {e}"))?; 43 + 44 + (did, format!("#{id}")) 45 + } 46 + // HACK: Assume the bluesky appview by default. 47 + None => ( 48 + Did::new("did:web:api.bsky.app".to_owned()) 49 + .expect("service proxy should be a valid DID"), 50 + "#bsky_appview".to_owned(), 51 + ), 52 + }; 53 + 54 + let did_doc = super::did::resolve(&Client::new(client.clone(), []), did.clone()) 55 + .await 56 + .with_context(|| format!("failed to resolve did document {}", did.as_str()))?; 57 + 58 + let Some(service) = did_doc.service.iter().find(|s| s.id == id) else { 59 + return Err(Error::with_status( 60 + StatusCode::BAD_REQUEST, 61 + anyhow!("could not find resolve service #{id}"), 62 + )); 63 + }; 64 + 65 + let target_url: url::Url = service 66 + .service_endpoint 67 + .join(&format!("/xrpc{url_path}")) 68 + .context("failed to construct target url")?; 69 + 70 + let exp = (chrono::Utc::now().checked_add_signed(chrono::Duration::minutes(1))) 71 + .context("should be valid expiration datetime")? 72 + .timestamp(); 73 + let jti = rand::thread_rng() 74 + .sample_iter(rand::distributions::Alphanumeric) 75 + .take(10) 76 + .map(char::from) 77 + .collect::<String>(); 78 + 79 + // Mint a bearer token by signing a JSON web token. 80 + // https://github.com/DavidBuchanan314/millipds/blob/5c7529a739d394e223c0347764f1cf4e8fd69f94/src/millipds/appview_proxy.py#L47-L59 81 + let token = super::auth::sign( 82 + &skey, 83 + "JWT", 84 + &serde_json::json!({ 85 + "iss": user_did.as_str(), 86 + "aud": did.as_str(), 87 + "lxm": lxm, 88 + "exp": exp, 89 + "jti": jti, 90 + }), 91 + ) 92 + .context("failed to sign jwt")?; 93 + 94 + let mut h = HeaderMap::new(); 95 + if let Some(hdr) = request.headers().get("atproto-accept-labelers") { 96 + drop(h.insert("atproto-accept-labelers", hdr.clone())); 97 + } 98 + if let Some(hdr) = request.headers().get(http::header::CONTENT_TYPE) { 99 + drop(h.insert(http::header::CONTENT_TYPE, hdr.clone())); 100 + } 101 + 102 + let r = client 103 + .request(request.method().clone(), target_url) 104 + .headers(h) 105 + .header(http::header::AUTHORIZATION, format!("Bearer {token}")) 106 + .body(reqwest::Body::wrap_stream( 107 + request.into_body().into_data_stream(), 108 + )) 109 + .send() 110 + .await 111 + .context("failed to send request")?; 112 + 113 + let mut resp = Response::builder().status(r.status()); 114 + if let Some(hdrs) = resp.headers_mut() { 115 + *hdrs = r.headers().clone(); 116 + } 117 + 118 + let resp = resp 119 + .body(Body::from_stream(r.bytes_stream())) 120 + .context("failed to construct response")?; 121 + 122 + Ok(resp) 123 + }
-459
src/tests.rs
··· 1 - //! Testing utilities for the PDS. 2 - #![expect(clippy::arbitrary_source_item_ordering)] 3 - use std::{ 4 - net::{IpAddr, Ipv4Addr, SocketAddr, TcpListener}, 5 - path::PathBuf, 6 - time::{Duration, Instant}, 7 - }; 8 - 9 - use anyhow::Result; 10 - use atrium_api::{ 11 - com::atproto::server, 12 - types::string::{AtIdentifier, Did, Handle, Nsid, RecordKey}, 13 - }; 14 - use figment::{Figment, providers::Format as _}; 15 - use futures::future::join_all; 16 - use serde::{Deserialize, Serialize}; 17 - use tokio::sync::OnceCell; 18 - use uuid::Uuid; 19 - 20 - use crate::config::AppConfig; 21 - 22 - /// Global test state, created once for all tests. 23 - pub(crate) static TEST_STATE: OnceCell<TestState> = OnceCell::const_new(); 24 - 25 - /// A temporary test directory that will be cleaned up when the struct is dropped. 26 - struct TempDir { 27 - /// The path to the directory. 28 - path: PathBuf, 29 - } 30 - 31 - impl TempDir { 32 - /// Create a new temporary directory. 33 - fn new() -> Result<Self> { 34 - let path = std::env::temp_dir().join(format!("bluepds-test-{}", Uuid::new_v4())); 35 - std::fs::create_dir_all(&path)?; 36 - Ok(Self { path }) 37 - } 38 - 39 - /// Get the path to the directory. 40 - fn path(&self) -> &PathBuf { 41 - &self.path 42 - } 43 - } 44 - 45 - impl Drop for TempDir { 46 - fn drop(&mut self) { 47 - drop(std::fs::remove_dir_all(&self.path)); 48 - } 49 - } 50 - 51 - /// Test state for the application. 52 - pub(crate) struct TestState { 53 - /// The address the test server is listening on. 54 - address: SocketAddr, 55 - /// The HTTP client. 56 - client: reqwest::Client, 57 - /// The application configuration. 58 - config: AppConfig, 59 - /// The temporary directory for test data. 60 - #[expect(dead_code)] 61 - temp_dir: TempDir, 62 - } 63 - 64 - impl TestState { 65 - /// Get a base URL for the test server. 66 - pub(crate) fn base_url(&self) -> String { 67 - format!("http://{}", self.address) 68 - } 69 - 70 - /// Create a test account. 71 - pub(crate) async fn create_test_account(&self) -> Result<TestAccount> { 72 - // Create the account 73 - let handle = "test.handle"; 74 - let response = self 75 - .client 76 - .post(format!( 77 - "http://{}/xrpc/com.atproto.server.createAccount", 78 - self.address 79 - )) 80 - .json(&server::create_account::InputData { 81 - did: None, 82 - verification_code: None, 83 - verification_phone: None, 84 - email: Some(format!("{}@example.com", &handle)), 85 - handle: Handle::new(handle.to_owned()).expect("should be able to create handle"), 86 - password: Some("password123".to_owned()), 87 - invite_code: None, 88 - recovery_key: None, 89 - plc_op: None, 90 - }) 91 - .send() 92 - .await?; 93 - 94 - let account: server::create_account::Output = response.json().await?; 95 - 96 - Ok(TestAccount { 97 - handle: handle.to_owned(), 98 - did: account.did.to_string(), 99 - access_token: account.access_jwt.clone(), 100 - refresh_token: account.refresh_jwt.clone(), 101 - }) 102 - } 103 - 104 - /// Create a new test state. 105 - #[expect(clippy::unused_async)] 106 - async fn new() -> Result<Self> { 107 - // Configure the test app 108 - #[derive(Serialize, Deserialize)] 109 - struct TestConfigInput { 110 - db: Option<String>, 111 - host_name: Option<String>, 112 - key: Option<PathBuf>, 113 - listen_address: Option<SocketAddr>, 114 - test: Option<bool>, 115 - } 116 - // Create a temporary directory for test data 117 - let temp_dir = TempDir::new()?; 118 - 119 - // Find a free port 120 - let listener = TcpListener::bind(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), 0))?; 121 - let address = listener.local_addr()?; 122 - drop(listener); 123 - 124 - let test_config = TestConfigInput { 125 - db: Some(format!("sqlite://{}/test.db", temp_dir.path().display())), 126 - host_name: Some(format!("localhost:{}", address.port())), 127 - key: Some(temp_dir.path().join("test.key")), 128 - listen_address: Some(address), 129 - test: Some(true), 130 - }; 131 - 132 - let config: AppConfig = Figment::new() 133 - .admerge(figment::providers::Toml::file("default.toml")) 134 - .admerge(figment::providers::Env::prefixed("BLUEPDS_")) 135 - .merge(figment::providers::Serialized::defaults(test_config)) 136 - .merge( 137 - figment::providers::Toml::string( 138 - r#" 139 - [firehose] 140 - relays = [] 141 - 142 - [repo] 143 - path = "repo" 144 - 145 - [plc] 146 - path = "plc" 147 - 148 - [blob] 149 - path = "blob" 150 - limit = 10485760 # 10 MB 151 - "#, 152 - ) 153 - .nested(), 154 - ) 155 - .extract()?; 156 - 157 - // Create directories 158 - std::fs::create_dir_all(temp_dir.path().join("repo"))?; 159 - std::fs::create_dir_all(temp_dir.path().join("plc"))?; 160 - std::fs::create_dir_all(temp_dir.path().join("blob"))?; 161 - 162 - // Create client 163 - let client = reqwest::Client::builder() 164 - .timeout(Duration::from_secs(30)) 165 - .build()?; 166 - 167 - Ok(Self { 168 - address, 169 - client, 170 - config, 171 - temp_dir, 172 - }) 173 - } 174 - 175 - /// Start the application in a background task. 176 - async fn start_app(&self) -> Result<()> { 177 - // // Get a reference to the config that can be moved into the task 178 - // let config = self.config.clone(); 179 - // let address = self.address; 180 - 181 - // // Start the application in a background task 182 - // let _handle = tokio::spawn(async move { 183 - // // Set up the application 184 - // use crate::*; 185 - 186 - // // Initialize metrics (noop in test mode) 187 - // drop(metrics::setup(None)); 188 - 189 - // // Create client 190 - // let simple_client = reqwest::Client::builder() 191 - // .user_agent(APP_USER_AGENT) 192 - // .build() 193 - // .context("failed to build requester client")?; 194 - // let client = reqwest_middleware::ClientBuilder::new(simple_client.clone()) 195 - // .with(http_cache_reqwest::Cache(http_cache_reqwest::HttpCache { 196 - // mode: CacheMode::Default, 197 - // manager: MokaManager::default(), 198 - // options: HttpCacheOptions::default(), 199 - // })) 200 - // .build(); 201 - 202 - // // Create a test keypair 203 - // std::fs::create_dir_all(config.key.parent().context("should have parent")?)?; 204 - // let (skey, rkey) = { 205 - // let skey = Secp256k1Keypair::create(&mut rand::thread_rng()); 206 - // let rkey = Secp256k1Keypair::create(&mut rand::thread_rng()); 207 - 208 - // let keys = KeyData { 209 - // skey: skey.export(), 210 - // rkey: rkey.export(), 211 - // }; 212 - 213 - // let mut f = 214 - // std::fs::File::create(&config.key).context("failed to create key file")?; 215 - // serde_ipld_dagcbor::to_writer(&mut f, &keys) 216 - // .context("failed to serialize crypto keys")?; 217 - 218 - // (SigningKey(Arc::new(skey)), RotationKey(Arc::new(rkey))) 219 - // }; 220 - 221 - // // Set up database 222 - // let opts = SqliteConnectOptions::from_str(&config.db) 223 - // .context("failed to parse database options")? 224 - // .create_if_missing(true); 225 - // let db = SqliteDbConn::connect_with(opts).await?; 226 - 227 - // sqlx::migrate!() 228 - // .run(&db) 229 - // .await 230 - // .context("failed to apply migrations")?; 231 - 232 - // // Create firehose 233 - // let (_fh, fhp) = firehose::spawn(client.clone(), config.clone()); 234 - 235 - // // Create the application state 236 - // let app_state = AppState { 237 - // cred: azure_identity::DefaultAzureCredential::new()?, 238 - // config: config.clone(), 239 - // db: db.clone(), 240 - // client: client.clone(), 241 - // simple_client, 242 - // firehose: fhp, 243 - // signing_key: skey, 244 - // rotation_key: rkey, 245 - // }; 246 - 247 - // // Create the router 248 - // let app = Router::new() 249 - // .route("/", get(index)) 250 - // .merge(oauth::routes()) 251 - // .nest( 252 - // "/xrpc", 253 - // endpoints::routes() 254 - // .merge(actor_endpoints::routes()) 255 - // .fallback(service_proxy), 256 - // ) 257 - // .layer(CorsLayer::permissive()) 258 - // .layer(TraceLayer::new_for_http()) 259 - // .with_state(app_state); 260 - 261 - // // Listen for connections 262 - // let listener = TcpListener::bind(&address) 263 - // .await 264 - // .context("failed to bind address")?; 265 - 266 - // axum::serve(listener, app.into_make_service()) 267 - // .await 268 - // .context("failed to serve app") 269 - // }); 270 - 271 - // // Give the server a moment to start 272 - // tokio::time::sleep(Duration::from_millis(500)).await; 273 - 274 - Ok(()) 275 - } 276 - } 277 - 278 - /// A test account that can be used for testing. 279 - pub(crate) struct TestAccount { 280 - /// The access token for the account. 281 - pub(crate) access_token: String, 282 - /// The account DID. 283 - pub(crate) did: String, 284 - /// The account handle. 285 - pub(crate) handle: String, 286 - /// The refresh token for the account. 287 - #[expect(dead_code)] 288 - pub(crate) refresh_token: String, 289 - } 290 - 291 - /// Initialize the test state. 292 - pub(crate) async fn init_test_state() -> Result<&'static TestState> { 293 - async fn init_test_state() -> std::result::Result<TestState, anyhow::Error> { 294 - let state = TestState::new().await?; 295 - state.start_app().await?; 296 - Ok(state) 297 - } 298 - TEST_STATE.get_or_try_init(init_test_state).await 299 - } 300 - 301 - /// Create a record benchmark that creates records and measures the time it takes. 302 - #[expect( 303 - clippy::arithmetic_side_effects, 304 - clippy::integer_division, 305 - clippy::integer_division_remainder_used, 306 - clippy::use_debug, 307 - clippy::print_stdout 308 - )] 309 - pub(crate) async fn create_record_benchmark(count: usize, concurrent: usize) -> Result<Duration> { 310 - // Initialize the test state 311 - let state = init_test_state().await?; 312 - 313 - // Create a test account 314 - let account = state.create_test_account().await?; 315 - 316 - // Create the client with authorization 317 - let client = reqwest::Client::builder() 318 - .timeout(Duration::from_secs(30)) 319 - .build()?; 320 - 321 - let start = Instant::now(); 322 - 323 - // Split the work into batches 324 - let mut handles = Vec::new(); 325 - for batch_idx in 0..concurrent { 326 - let batch_size = count / concurrent; 327 - let client = client.clone(); 328 - let base_url = state.base_url(); 329 - let account_did = account.did.clone(); 330 - let account_handle = account.handle.clone(); 331 - let access_token = account.access_token.clone(); 332 - 333 - let handle = tokio::spawn(async move { 334 - let mut results = Vec::new(); 335 - 336 - for i in 0..batch_size { 337 - let request_start = Instant::now(); 338 - let record_idx = batch_idx * batch_size + i; 339 - 340 - let result = client 341 - .post(format!("{base_url}/xrpc/com.atproto.repo.createRecord")) 342 - .header("Authorization", format!("Bearer {access_token}")) 343 - .json(&atrium_api::com::atproto::repo::create_record::InputData { 344 - repo: AtIdentifier::Did(Did::new(account_did.clone()).expect("valid DID")), 345 - collection: Nsid::new("app.bsky.feed.post".to_owned()).expect("valid NSID"), 346 - rkey: Some( 347 - RecordKey::new(format!("test-{record_idx}")).expect("valid record key"), 348 - ), 349 - validate: None, 350 - record: serde_json::from_str( 351 - &serde_json::json!({ 352 - "$type": "app.bsky.feed.post", 353 - "text": format!("Test post {record_idx} from {account_handle}"), 354 - "createdAt": chrono::Utc::now().to_rfc3339(), 355 - }) 356 - .to_string(), 357 - ) 358 - .expect("valid JSON record"), 359 - swap_commit: None, 360 - }) 361 - .send() 362 - .await; 363 - 364 - // Fetch the record we just created 365 - let get_response = client 366 - .get(format!( 367 - "{base_url}/xrpc/com.atproto.sync.getRecord?did={account_did}&collection=app.bsky.feed.post&rkey={record_idx}" 368 - )) 369 - .header("Authorization", format!("Bearer {access_token}")) 370 - .send() 371 - .await; 372 - if get_response.is_err() { 373 - println!("Failed to fetch record {record_idx}: {get_response:?}"); 374 - results.push(get_response); 375 - continue; 376 - } 377 - 378 - let request_duration = request_start.elapsed(); 379 - if record_idx % 10 == 0 { 380 - println!("Created record {record_idx} in {request_duration:?}"); 381 - } 382 - results.push(result); 383 - } 384 - 385 - results 386 - }); 387 - 388 - handles.push(handle); 389 - } 390 - 391 - // Wait for all batches to complete 392 - let results = join_all(handles).await; 393 - 394 - // Check for errors 395 - for batch_result in results { 396 - let batch_responses = batch_result?; 397 - for response_result in batch_responses { 398 - match response_result { 399 - Ok(response) => { 400 - if !response.status().is_success() { 401 - return Err(anyhow::anyhow!( 402 - "Failed to create record: {}", 403 - response.status() 404 - )); 405 - } 406 - } 407 - Err(err) => { 408 - return Err(anyhow::anyhow!("Failed to create record: {}", err)); 409 - } 410 - } 411 - } 412 - } 413 - 414 - let duration = start.elapsed(); 415 - Ok(duration) 416 - } 417 - 418 - #[cfg(test)] 419 - #[expect(clippy::module_inception, clippy::use_debug, clippy::print_stdout)] 420 - mod tests { 421 - use super::*; 422 - use anyhow::anyhow; 423 - 424 - #[tokio::test] 425 - async fn test_create_account() -> Result<()> { 426 - return Ok(()); 427 - #[expect(unreachable_code, reason = "Disabled")] 428 - let state = init_test_state().await?; 429 - let account = state.create_test_account().await?; 430 - 431 - println!("Created test account: {}", account.handle); 432 - if account.handle.is_empty() { 433 - return Err(anyhow::anyhow!("Account handle is empty")); 434 - } 435 - if account.did.is_empty() { 436 - return Err(anyhow::anyhow!("Account DID is empty")); 437 - } 438 - if account.access_token.is_empty() { 439 - return Err(anyhow::anyhow!("Account access token is empty")); 440 - } 441 - 442 - Ok(()) 443 - } 444 - 445 - #[tokio::test] 446 - async fn test_create_record_benchmark() -> Result<()> { 447 - return Ok(()); 448 - #[expect(unreachable_code, reason = "Disabled")] 449 - let duration = create_record_benchmark(100, 1).await?; 450 - 451 - println!("Created 100 records in {duration:?}"); 452 - 453 - if duration.as_secs() >= 10 { 454 - return Err(anyhow!("Benchmark took too long")); 455 - } 456 - 457 - Ok(()) 458 - } 459 - }