A from-scratch atproto PDS implementation in Python (mirrors https://github.com/DavidBuchanan314/millipds)

refactoring to fix a bunch of test warnings

+5 -2
.vscode/settings.json
···
6 6             "-p",
7 7             "test_*.py"
8 8         ],
9 -         "python.testing.pytestEnabled": false,
10 -        "python.testing.unittestEnabled": true
9 +         "python.testing.pytestEnabled": true,
10 +        "python.testing.unittestEnabled": false,
11 +        "python.testing.pytestArgs": [
12 +            "."
13 +        ]
11 14   }
+3
pyproject.toml
···
55 55   indent-style = "tab"
56 56
57 57   [tool.setuptools_scm]
58 +
59 +    [tool.pytest.ini_options]
60 +    asyncio_default_fixture_loop_scope = "session"
+11 -8
src/millipds/__main__.py
···
67 67   from getpass import getpass
68 68
69 69   from docopt import docopt
70 +    import aiohttp
70 71
71 72   import cbrrr
72 73
···
215 216         else:
216 217             print("invalid account subcommand")
217 218     elif args["run"]:
218 -           asyncio.run(
219 -               service.run(
220 -                   db=db,
221 -                   sock_path=args["--sock_path"],
222 -                   host=args["--listen_host"],
223 -                   port=int(args["--listen_port"]),
224 -               )
225 -           )
219 +           async def run_with_client():
220 +               async with aiohttp.ClientSession() as client:
221 +                   await service.run(
222 +                       db=db,
223 +                       client=client,
224 +                       sock_path=args["--sock_path"],
225 +                       host=args["--listen_host"],
226 +                       port=int(args["--listen_port"]),
227 +                   )
228 +           asyncio.run(run_with_client())
226 229     else:
227 230         print("CLI arg parse error?!")
228 231
+12 -8
src/millipds/app_util.py
···
6 6
7 7     from . import database
8 8
9 +     MILLIPDS_DB = web.AppKey("MILLIPDS_DB", database.Database)
10 +    MILLIPDS_AIOHTTP_CLIENT = web.AppKey("MILLIPDS_AIOHTTP_CLIENT", aiohttp.ClientSession)
11 +    MILLIPDS_FIREHOSE_QUEUES = web.AppKey("MILLIPDS_FIREHOSE_QUEUES", Set[asyncio.Queue[Optional[Tuple[int, bytes]]]])
12 +    MILLIPDS_FIREHOSE_QUEUES_LOCK = web.AppKey("MILLIPDS_FIREHOSE_QUEUES_LOCK", asyncio.Lock)
9 13
10 14   # these helpers are useful for conciseness and type hinting
11 -    def get_db(req: web.Request) -> database.Database:
12 -        return req.app["MILLIPDS_DB"]
15 +    def get_db(req: web.Request):
16 +        return req.app[MILLIPDS_DB]
13 17
14 -    def get_client(req: web.Request) -> aiohttp.ClientSession:
15 -        return req.app["MILLIPDS_AIOHTTP_CLIENT"]
18 +    def get_client(req: web.Request):
19 +        return req.app[MILLIPDS_AIOHTTP_CLIENT]
16 20
17 -    def get_firehose_queues(req: web.Request) -> Set[asyncio.Queue[Optional[Tuple[int, bytes]]]]:
18 -        return req.app["MILLIPDS_FIREHOSE_QUEUES"]
21 +    def get_firehose_queues(req: web.Request):
22 +        return req.app[MILLIPDS_FIREHOSE_QUEUES]
19 23
20 -    def get_firehose_queues_lock(req: web.Request) -> asyncio.Lock:
21 -        return req.app["MILLIPDS_FIREHOSE_QUEUES_LOCK"]
24 +    def get_firehose_queues_lock(req: web.Request):
25 +        return req.app[MILLIPDS_FIREHOSE_QUEUES_LOCK]
+9 -9
src/millipds/atproto_repo.py
···
115 115 @routes.get("/xrpc/com.atproto.repo.describeRepo")
116 116 async def repo_describe_repo(request: web.Request):
117 117     if "repo" not in request.query:
118 -           return web.HTTPBadRequest(text="missing repo")
118 +           raise web.HTTPBadRequest(text="missing repo")
119 119     did_or_handle = request.query["repo"]
120 120     with get_db(request).new_con(readonly=True) as con:
121 121         user_id, did, handle = con.execute(
···
141 141 @routes.get("/xrpc/com.atproto.repo.getRecord")
142 142 async def repo_get_record(request: web.Request):
143 143     if "repo" not in request.query:
144 -           return web.HTTPBadRequest(text="missing repo")
144 +           raise web.HTTPBadRequest(text="missing repo")
145 145     if "collection" not in request.query:
146 -           return web.HTTPBadRequest(text="missing collection")
146 +           raise web.HTTPBadRequest(text="missing collection")
147 147     if "rkey" not in request.query:
148 -           return web.HTTPBadRequest(text="missing rkey")
148 +           raise web.HTTPBadRequest(text="missing rkey")
149 149     did_or_handle = request.query["repo"]
150 150     collection = request.query["collection"]
151 151     rkey = request.query["rkey"]
···
157 157     ).fetchone()
158 158     if row is None:
159 159         return await service_proxy(request) # forward to appview
160 -           #return web.HTTPNotFound(text="record not found")
160 +           #raise web.HTTPNotFound(text="record not found")
161 161     cid_out, value = row
162 162     cid_out = cbrrr.CID(cid_out)
163 163     if cid_in is not None:
164 164         if cbrrr.CID.decode(cid_in) != cid_out:
165 -               return web.HTTPNotFound(text="record not found with matching CID")
165 +               raise web.HTTPNotFound(text="record not found with matching CID")
166 166     return web.json_response({
167 167         "uri": f"at://{did_or_handle}/{collection}/{rkey}", # TODO rejig query to get the did out always,
168 168         "cid": cid_out.encode(),
···
173 173 @routes.get("/xrpc/com.atproto.repo.listRecords")
174 174 async def repo_list_records(request: web.Request):
175 175     if "repo" not in request.query:
176 -           return web.HTTPBadRequest(text="missing repo")
176 +           raise web.HTTPBadRequest(text="missing repo")
177 177     if "collection" not in request.query:
178 -           return web.HTTPBadRequest(text="missing collection")
178 +           raise web.HTTPBadRequest(text="missing collection")
179 179     limit = int(request.query.get("limit", 50))
180 180     if limit < 1 or limit > 100:
181 -           return web.HTTPBadRequest(text="limit out of range")
181 +           raise web.HTTPBadRequest(text="limit out of range")
182 182     reverse = request.query.get("reverse") == "true"
183 183     cursor = request.query.get("cursor", "" if reverse else "\xff")
184 184     did_or_handle = request.query["repo"]
+16 -16
src/millipds/atproto_sync.py
···
20 20           (request.query["did"], bytes(cbrrr.CID.decode(request.query["cid"]))) # TODO: check params exist first, give nicer error
21 21       ).fetchone()
22 22       if blob_id is None:
23 -            return web.HTTPNotFound(text="blob not found")
23 +            raise web.HTTPNotFound(text="blob not found")
24 24       res = web.StreamResponse(headers={"Content-Disposition": f'attachment; filename="{request.query["cid"]}.bin"'})
25 25       res.content_type = "application/octet-stream"
26 26       await res.prepare(request)
···
38 38   async def sync_get_blocks(request: web.Request):
39 39       did = request.query.get("did")
40 40       if did is None:
41 -            return web.HTTPBadRequest(text="no did specified")
41 +            raise web.HTTPBadRequest(text="no did specified")
42 42       try:
43 43           cids = [bytes(cbrrr.CID.decode(cid)) for cid in request.query.getall("cids")]
44 44       except ValueError:
45 -            return web.HTTPBadRequest(text="invalid cid")
45 +            raise web.HTTPBadRequest(text="invalid cid")
46 46       db = get_db(request)
47 47       row = db.con.execute("SELECT id FROM user WHERE did=?", (did,)).fetchone()
48 48       if row is None:
49 -            return web.HTTPNotFound(text="did not found")
49 +            raise web.HTTPNotFound(text="did not found")
50 50       user_id = row[0]
51 51       res = web.StreamResponse()
52 52       res.content_type = "application/vnd.ipld.car"
···
71 71   async def sync_get_latest_commit(request: web.Request):
72 72       did = request.query.get("did")
73 73       if did is None:
74 -            return web.HTTPBadRequest(text="no did specified")
74 +            raise web.HTTPBadRequest(text="no did specified")
75 75       row = get_db(request).con.execute(
76 76           "SELECT rev, head FROM user WHERE did=?",
77 77           (did,)
78 78       ).fetchone()
79 79       if row is None:
80 -            return web.HTTPNotFound(text="did not found")
80 +            raise web.HTTPNotFound(text="did not found")
81 81       rev, head = row
82 82       return web.json_response({
83 83           "cid": cbrrr.CID(head).encode(),
···
88 88   @routes.get("/xrpc/com.atproto.sync.getRecord")
89 89   async def sync_get_record(request: web.Request):
90 90       if "did" not in request.query:
91 -            return web.HTTPBadRequest(text="missing did")
91 +            raise web.HTTPBadRequest(text="missing did")
92 92       if "collection" not in request.query:
93 -            return web.HTTPBadRequest(text="missing collection")
93 +            raise web.HTTPBadRequest(text="missing collection")
94 94       if "rkey" not in request.query:
95 -            return web.HTTPBadRequest(text="missing rkey")
95 +            raise web.HTTPBadRequest(text="missing rkey")
96 96
97 97       # we don't stream the response because it should be compact-ish
98 98       car = repo_ops.get_record(
···
102 102     )
103 103
104 104     if car is None:
105 -           return web.HTTPNotFound(text="did or record not found")
105 +           raise web.HTTPNotFound(text="did or record not found")
106 106
107 107     return web.Response(
108 108         body=car,
···
114 114 async def sync_get_repo_status(request: web.Request):
115 115     did = request.query.get("did")
116 116     if did is None:
117 -           return web.HTTPBadRequest(text="no did specified")
117 +           raise web.HTTPBadRequest(text="no did specified")
118 118     row = get_db(request).con.execute(
119 119         "SELECT rev FROM user WHERE did=?",
120 120         (did,)
121 121     ).fetchone()
122 122     if row is None:
123 -           return web.HTTPNotFound(text="did not found")
123 +           raise web.HTTPNotFound(text="did not found")
124 124     return web.json_response({
125 125         "did": did,
126 126         "active": True,
···
132 132 async def sync_get_repo(request: web.Request):
133 133     did = request.query.get("did")
134 134     if did is None:
135 -           return web.HTTPBadRequest(text="no did specified")
135 +           raise web.HTTPBadRequest(text="no did specified")
136 136     since = request.query.get("since", "") # empty string is "lowest" possible value wrt string comparison
137 137
138 138     # TODO: do bad things happen if a client holds the connection open for a long time?
···
143 143                 (did,)
144 144             ).fetchone()
145 145         except TypeError: # from trying to unpack None
146 -               raise_placeholder
+8 -4
src/millipds/auth_bearer.py
···
9 9
10 10   routes = web.RouteTableDef()
11 11   def authenticated(handler):
12 -        def authentication_handler(request: web.Request, *args, **kwargs):
12 +        async def authentication_handler(request: web.Request, *args, **kwargs):
13 13           # extract the auth token
14 14           auth = request.headers.get("Authorization")
15 15           if auth is None:
···
29 29                   key=db.config["jwt_access_secret"],
30 30                   algorithms=["HS256"],
31 31                   audience=db.config["pds_did"],
32 -                    require=["exp", "scope"], # consider iat?
33 -                    strict_aud=True,
32 +                    options={
33 +                        "require": ["exp", "iat", "scope"], # consider iat?
34 +                        "verify_exp": True,
35 +                        "verify_iat": True,
36 +                        "strict_aud": True, # may be unnecessary
37 +                    }
34 38               )
35 39           except jwt.exceptions.PyJWTError:
36 40               raise web.HTTPUnauthorized(text="invalid jwt")
···
43 47           if not subject.startswith("did:"):
44 48               raise web.HTTPUnauthorized(text="invalid jwt: invalid subject")
45 49           request["authed_did"] = subject
46 -            return handler(request, *args, **kwargs)
50 +            return await handler(request, *args, **kwargs)
47 51
48 52       return authentication_handler
+3 -3
src/millipds/auth_oauth.py
···
114 114         # TODO: verify iat?, iss?
115 115
116 116         if request.method != decoded["htm"]:
117 -               return web.HTTPBadRequest(
117 +               raise web.HTTPBadRequest(
118 118                 text="dpop: bad htm"
119 119             )
120 120
121 121         if str(request.url) != decoded["htu"]:
122 122             logger.info(f"{request.url!r} != {decoded['htu']!r}")
123 -               return web.HTTPBadRequest(
123 +               raise web.HTTPBadRequest(
124 124                 text="dpop: bad htu (if your application is reverse-proxied, make sure the Host header is getting set properly)"
125 125             )
126 126
127 127         if decoded.get("nonce") != DPOP_NONCE:
128 -               return web.HTTPBadRequest(
128 +               raise web.HTTPBadRequest(
129 129                 body=json.dumps({
130 130                     "error": "use_dpop_nonce",
131 131                     "error_description": "Authorization server requires nonce in DPoP proof"
-3
src/millipds/database.py
···
314 314                 "INSERT INTO mst(repo, cid, since, value) VALUES (?, ?, ?, ?)",
315 315                 (user_id, bytes(empty_mst.cid), tid, empty_mst.serialised),
316 316             )
317 -               #util.mkdirs_for_file(repo_path)
318 -               #UserDatabase.init_tables(self.con, did, repo_path, tid)
319 -               #self.con.execute("DETACH spoke")
320 317
321 318     def verify_account_login(
322 319         self, did_or_handle: str, password: str
+14 -14
src/millipds/service.py
···
144 144     # TODO: forward to appview(?) if we can't answer?
145 145     handle = request.query.get("handle")
146 146     if handle is None:
147 -           return web.HTTPBadRequest(text="missing or invalid handle")
147 +           raise web.HTTPBadRequest(text="missing or invalid handle")
148 148     did = get_db(request).did_by_handle(handle)
149 149     if not did:
150 -           return web.HTTPNotFound(text="no user by that handle exists on this PDS")
150 +           raise web.HTTPNotFound(text="no user by that handle exists on this PDS")
151 151     return web.json_response({"did": did})
152 152
153 153
···
171 171     try:
172 172         req_json: dict = await request.json()
173 173     except json.JSONDecodeError:
174 -           return web.HTTPBadRequest(text="expected JSON")
174 +           raise web.HTTPBadRequest(text="expected JSON")
175 175
176 176     identifier = req_json.get("identifier")
177 177     password = req_json.get("password")
178 178     if not (isinstance(identifier, str) and isinstance(password, str)):
179 -           return web.HTTPBadRequest(text="invalid identifier or password")
179 +           raise web.HTTPBadRequest(text="invalid identifier or password")
180 180
181 181     # do authentication
182 182     db = get_db(request)
···
234 234     req_json: dict = await request.json()
235 235     handle = req_json.get("handle")
236 236     if handle is None:
237 -           return web.HTTPBadRequest(text="missing or invalid handle")
237 +           raise web.HTTPBadRequest(text="missing or invalid handle")
238 238     # TODO: actually validate it, and update the db!!!
239 239     # (I'm writing this half-baked version just so I can send firehose #identity events)
240 240     with get_db(request).new_con() as con:
···
291 291     )
292 292
293 293
294 -   def construct_app(routes, db: database.Database) -> web.Application:
294 +   def construct_app(routes, db: database.Database, client: aiohttp.ClientSession) -> web.Application:
295 295     cors = cors_middleware( # TODO: review and reduce scope - and maybe just /xrpc/*?
296 296         allow_all=True,
297 297         expose_headers=["*"],
···
300 300         allow_credentials=True,
301 301         max_age=100_000_000
302 302     )
303 +
304 +       client.headers.update({"User-Agent": importlib.metadata.version("millipds")})
303 305
304 306     app = web.Application(middlewares=[cors, atproto_service_proxy_middleware])
305 -       app["MILLIPDS_DB"] = db
306 -       app["MILLIPDS_AIOHTTP_CLIENT"] = (
307 -           aiohttp.ClientSession()
308 -       ) # should this be dependency-injected?
309 -       app["MILLIPDS_FIREHOSE_QUEUES"] = set()
310 -       app["MILLIPDS_FIREHOSE_QUEUES_LOCK"] = asyncio.Lock()
307 +       app[MILLIPDS_DB] = db
308 +       app[MILLIPDS_AIOHTTP_CLIENT] = client
309 +       app[MILLIPDS_FIREHOSE_QUEUES] = set()
310 +       app[MILLIPDS_FIREHOSE_QUEUES_LOCK] = asyncio.Lock()
311 311     app.add_routes(routes)
312 312     app.add_routes(auth_oauth.routes)
313 313     app.add_routes(atproto_sync.routes)
···
358 358     return app
359 359
360 360
361 -   async def run(db: database.Database, sock_path: Optional[str], host: str, port: int):
361 +   async def run(db: database.Database, client: aiohttp.ClientSession, sock_path: Optional[str], host: str, port: int):
362 362     """
363 363     This gets invoked via millipds.__main__.py
364 364     """
365 365
366 -       app = construct_app(routes, db)
366 +       app = construct_app(routes, db, client)
367 367     runner = web.AppRunner(app, access_log_format=static_config.HTTP_LOG_FMT)
368 368     await runner.setup()
369 369
+57 -53
tests/integration_test.py
···
6 6     import pytest
7 7     import dataclasses
8 8     import aiohttp
9 +     import aiohttp.web
9 10
10 11   from millipds import service
11 12   from millipds import database
12 13   from millipds import crypto
13 14
14 15   @dataclasses.dataclass
15 -    class TestPDS:
16 +    class PDSInfo:
16 17       endpoint: str
17 18       db: database.Database
18 19
19 20   old_web_tcpsite_start = aiohttp.web.TCPSite.start
20 21
21 -    def make_capture_random_bound_port_web_tcpsite_startstart(queue):
22 -        async def mock_start(site, *args, **kwargs):
22 +    def make_capture_random_bound_port_web_tcpsite_startstart(queue: asyncio.Queue):
23 +        async def mock_start(site: aiohttp.web.TCPSite, *args, **kwargs):
23 24           nonlocal queue
24 25           await old_web_tcpsite_start(site, *args, **kwargs)
25 26           await queue.put(site._server.sockets[0].getsockname()[1])
26 27       return mock_start
27 28
28 -    async def service_run_and_capture_port(queue, **kwargs):
29 +    async def service_run_and_capture_port(queue: asyncio.Queue, **kwargs):
29 30       mock_start = make_capture_random_bound_port_web_tcpsite_startstart(queue)
30 31       with unittest.mock.patch.object(aiohttp.web.TCPSite, "start", new=mock_start):
31 32           await service.run(**kwargs)
···
44 45   async def test_pds(aiolib):
45 46       queue = asyncio.Queue()
46 47       with tempfile.TemporaryDirectory() as tempdir:
47 -            db_path = f"{tempdir}/millipds-0000.db"
48 -            db = database.Database(path=db_path)
49 -
50 -            hostname = "localhost:0"
51 -            db.update_config(
52 -                pds_pfx=f'http://{hostname}',
53 -                pds_did=f'did:web:{urllib.parse.quote(hostname)}',
54 -                bsky_appview_pfx="https://api.bsky.app",
55 -                bsky_appview_did="did:web:api.bsky.app",
56 -            )
48 +            async with aiohttp.ClientSession() as client:
49 +                db_path = f"{tempdir}/millipds-0000.db"
50 +                db = database.Database(path=db_path)
57 51
58 -            service_run_task = asyncio.create_task(
59 -                service_run_and_capture_port(
60 -                    queue,
61 -                    db=db,
62 -                    sock_path=None,
63 -                    host="localhost",
64 -                    port=0,
52 +                hostname = "localhost:0"
53 +                db.update_config(
54 +                    pds_pfx=f'http://{hostname}',
55 +                    pds_did=f'did:web:{urllib.parse.quote(hostname)}',
56 +                    bsky_appview_pfx="https://api.bsky.app",
57 +                    bsky_appview_did="did:web:api.bsky.app",
65 58               )
66 -            )
67 -            queue_get_task = asyncio.create_task(queue.get())
68 -            done, pending = await asyncio.wait(
69 -                (queue_get_task, service_run_task),
70 -                return_when=asyncio.FIRST_COMPLETED,
71 -            )
72 -            if done == service_run_task:
73 -                raise service_run_task.execption()
74 -            else:
75 -                port = queue_get_task.result()
76 59
77 -            hostname = f"localhost:{port}"
78 -            db.update_config(
79 -                pds_pfx=f'http://{hostname}',
80 -                pds_did=f'did:web:{urllib.parse.quote(hostname)}',
81 -                bsky_appview_pfx="https://api.bsky.app",
82 -                bsky_appview_did="did:web:api.bsky.app",
83 -            )
84 -            db.create_account(
85 -                did=TEST_DID,
86 -                handle=TEST_HANDLE,
87 -                password=TEST_PASSWORD,
88 -                privkey=TEST_PRIVKEY,
89 -            )
60 +                service_run_task = asyncio.create_task(
61 +                    service_run_and_capture_port(
62 +                        queue,
63 +                        db=db,
64 +                        client=client,
65 +                        sock_path=None,
66 +                        host="localhost",
67 +                        port=0,
68 +                    )
69 +                )
70 +                queue_get_task = asyncio.create_task(queue.get())
71 +                done, pending = await asyncio.wait(
72 +                    (queue_get_task, service_run_task),
73 +                    return_when=asyncio.FIRST_COMPLETED,
74 +                )
75 +                if done == service_run_task:
76 +                    raise service_run_task.execption()
77 +                else:
78 +                    port = queue_get_task.result()
90 79
91 -            try:
92 -                yield TestPDS(
93 -                    endpoint=f"http://{hostname}",
94 -                    db=db,
80 +                hostname = f"localhost:{port}"
81 +                db.update_config(
82 +                    pds_pfx=f'http://{hostname}',
83 +                    pds_did=f'did:web:{urllib.parse.quote(hostname)}',
84 +                    bsky_appview_pfx="https://api.bsky.app",
85 +                    bsky_appview_did="did:web:api.bsky.app",
95 86               )
96 -            finally:
97 -                service_run_task.cancel()
87 +                db.create_account(
88 +                    did=TEST_DID,
89 +                    handle=TEST_HANDLE,
90 +                    password=TEST_PASSWORD,
91 +                    privkey=TEST_PRIVKEY,
92 +                )
93 +
98 94               try:
99 -                    await service_run_task
100 -               except asyncio.CancelledError:
101 -                   pass
95 +                    yield PDSInfo(
96 +                        endpoint=f"http://{hostname}",
97 +                        db=db,
98 +                    )
99 +                finally:
100 +                   db.con.close()
101 +                   service_run_task.cancel()
102 +                   try:
103 +                       await service_run_task
104 +                   except asyncio.CancelledError:
105 +                       pass
102 106
103 107 @pytest.fixture
104 108 async def s(aiolib):