+5
-2
.vscode/settings.json
+5
-2
.vscode/settings.json
+3
pyproject.toml
+3
pyproject.toml
+11
-8
src/millipds/__main__.py
+11
-8
src/millipds/__main__.py
···
67
67
from getpass import getpass
68
68
69
69
from docopt import docopt
70
+
import aiohttp
70
71
71
72
import cbrrr
72
73
···
215
216
else:
216
217
print("invalid account subcommand")
217
218
elif args["run"]:
218
-
asyncio.run(
219
-
service.run(
220
-
db=db,
221
-
sock_path=args["--sock_path"],
222
-
host=args["--listen_host"],
223
-
port=int(args["--listen_port"]),
224
-
)
225
-
)
219
+
async def run_with_client():
220
+
async with aiohttp.ClientSession() as client:
221
+
await service.run(
222
+
db=db,
223
+
client=client,
224
+
sock_path=args["--sock_path"],
225
+
host=args["--listen_host"],
226
+
port=int(args["--listen_port"]),
227
+
)
228
+
asyncio.run(run_with_client())
226
229
else:
227
230
print("CLI arg parse error?!")
228
231
+12
-8
src/millipds/app_util.py
+12
-8
src/millipds/app_util.py
···
6
6
7
7
from . import database
8
8
9
+
MILLIPDS_DB = web.AppKey("MILLIPDS_DB", database.Database)
10
+
MILLIPDS_AIOHTTP_CLIENT = web.AppKey("MILLIPDS_AIOHTTP_CLIENT", aiohttp.ClientSession)
11
+
MILLIPDS_FIREHOSE_QUEUES = web.AppKey("MILLIPDS_FIREHOSE_QUEUES", Set[asyncio.Queue[Optional[Tuple[int, bytes]]]])
12
+
MILLIPDS_FIREHOSE_QUEUES_LOCK = web.AppKey("MILLIPDS_FIREHOSE_QUEUES_LOCK", asyncio.Lock)
9
13
10
14
# these helpers are useful for conciseness and type hinting
11
-
def get_db(req: web.Request) -> database.Database:
12
-
return req.app["MILLIPDS_DB"]
15
+
def get_db(req: web.Request):
16
+
return req.app[MILLIPDS_DB]
13
17
14
-
def get_client(req: web.Request) -> aiohttp.ClientSession:
15
-
return req.app["MILLIPDS_AIOHTTP_CLIENT"]
18
+
def get_client(req: web.Request):
19
+
return req.app[MILLIPDS_AIOHTTP_CLIENT]
16
20
17
-
def get_firehose_queues(req: web.Request) -> Set[asyncio.Queue[Optional[Tuple[int, bytes]]]]:
18
-
return req.app["MILLIPDS_FIREHOSE_QUEUES"]
21
+
def get_firehose_queues(req: web.Request):
22
+
return req.app[MILLIPDS_FIREHOSE_QUEUES]
19
23
20
-
def get_firehose_queues_lock(req: web.Request) -> asyncio.Lock:
21
-
return req.app["MILLIPDS_FIREHOSE_QUEUES_LOCK"]
24
+
def get_firehose_queues_lock(req: web.Request):
25
+
return req.app[MILLIPDS_FIREHOSE_QUEUES_LOCK]
+9
-9
src/millipds/atproto_repo.py
+9
-9
src/millipds/atproto_repo.py
···
115
115
@routes.get("/xrpc/com.atproto.repo.describeRepo")
116
116
async def repo_describe_repo(request: web.Request):
117
117
if "repo" not in request.query:
118
-
return web.HTTPBadRequest(text="missing repo")
118
+
raise web.HTTPBadRequest(text="missing repo")
119
119
did_or_handle = request.query["repo"]
120
120
with get_db(request).new_con(readonly=True) as con:
121
121
user_id, did, handle = con.execute(
···
141
141
@routes.get("/xrpc/com.atproto.repo.getRecord")
142
142
async def repo_get_record(request: web.Request):
143
143
if "repo" not in request.query:
144
-
return web.HTTPBadRequest(text="missing repo")
144
+
raise web.HTTPBadRequest(text="missing repo")
145
145
if "collection" not in request.query:
146
-
return web.HTTPBadRequest(text="missing collection")
146
+
raise web.HTTPBadRequest(text="missing collection")
147
147
if "rkey" not in request.query:
148
-
return web.HTTPBadRequest(text="missing rkey")
148
+
raise web.HTTPBadRequest(text="missing rkey")
149
149
did_or_handle = request.query["repo"]
150
150
collection = request.query["collection"]
151
151
rkey = request.query["rkey"]
···
157
157
).fetchone()
158
158
if row is None:
159
159
return await service_proxy(request) # forward to appview
160
-
#return web.HTTPNotFound(text="record not found")
160
+
#raise web.HTTPNotFound(text="record not found")
161
161
cid_out, value = row
162
162
cid_out = cbrrr.CID(cid_out)
163
163
if cid_in is not None:
164
164
if cbrrr.CID.decode(cid_in) != cid_out:
165
-
return web.HTTPNotFound(text="record not found with matching CID")
165
+
raise web.HTTPNotFound(text="record not found with matching CID")
166
166
return web.json_response({
167
167
"uri": f"at://{did_or_handle}/{collection}/{rkey}", # TODO rejig query to get the did out always,
168
168
"cid": cid_out.encode(),
···
173
173
@routes.get("/xrpc/com.atproto.repo.listRecords")
174
174
async def repo_list_records(request: web.Request):
175
175
if "repo" not in request.query:
176
-
return web.HTTPBadRequest(text="missing repo")
176
+
raise web.HTTPBadRequest(text="missing repo")
177
177
if "collection" not in request.query:
178
-
return web.HTTPBadRequest(text="missing collection")
178
+
raise web.HTTPBadRequest(text="missing collection")
179
179
limit = int(request.query.get("limit", 50))
180
180
if limit < 1 or limit > 100:
181
-
return web.HTTPBadRequest(text="limit out of range")
181
+
raise web.HTTPBadRequest(text="limit out of range")
182
182
reverse = request.query.get("reverse") == "true"
183
183
cursor = request.query.get("cursor", "" if reverse else "\xff")
184
184
did_or_handle = request.query["repo"]
+16
-16
src/millipds/atproto_sync.py
+16
-16
src/millipds/atproto_sync.py
···
20
20
(request.query["did"], bytes(cbrrr.CID.decode(request.query["cid"]))) # TODO: check params exist first, give nicer error
21
21
).fetchone()
22
22
if blob_id is None:
23
-
return web.HTTPNotFound(text="blob not found")
23
+
raise web.HTTPNotFound(text="blob not found")
24
24
res = web.StreamResponse(headers={"Content-Disposition": f'attachment; filename="{request.query["cid"]}.bin"'})
25
25
res.content_type = "application/octet-stream"
26
26
await res.prepare(request)
···
38
38
async def sync_get_blocks(request: web.Request):
39
39
did = request.query.get("did")
40
40
if did is None:
41
-
return web.HTTPBadRequest(text="no did specified")
41
+
raise web.HTTPBadRequest(text="no did specified")
42
42
try:
43
43
cids = [bytes(cbrrr.CID.decode(cid)) for cid in request.query.getall("cids")]
44
44
except ValueError:
45
-
return web.HTTPBadRequest(text="invalid cid")
45
+
raise web.HTTPBadRequest(text="invalid cid")
46
46
db = get_db(request)
47
47
row = db.con.execute("SELECT id FROM user WHERE did=?", (did,)).fetchone()
48
48
if row is None:
49
-
return web.HTTPNotFound(text="did not found")
49
+
raise web.HTTPNotFound(text="did not found")
50
50
user_id = row[0]
51
51
res = web.StreamResponse()
52
52
res.content_type = "application/vnd.ipld.car"
···
71
71
async def sync_get_latest_commit(request: web.Request):
72
72
did = request.query.get("did")
73
73
if did is None:
74
-
return web.HTTPBadRequest(text="no did specified")
74
+
raise web.HTTPBadRequest(text="no did specified")
75
75
row = get_db(request).con.execute(
76
76
"SELECT rev, head FROM user WHERE did=?",
77
77
(did,)
78
78
).fetchone()
79
79
if row is None:
80
-
return web.HTTPNotFound(text="did not found")
80
+
raise web.HTTPNotFound(text="did not found")
81
81
rev, head = row
82
82
return web.json_response({
83
83
"cid": cbrrr.CID(head).encode(),
···
88
88
@routes.get("/xrpc/com.atproto.sync.getRecord")
89
89
async def sync_get_record(request: web.Request):
90
90
if "did" not in request.query:
91
-
return web.HTTPBadRequest(text="missing did")
91
+
raise web.HTTPBadRequest(text="missing did")
92
92
if "collection" not in request.query:
93
-
return web.HTTPBadRequest(text="missing collection")
93
+
raise web.HTTPBadRequest(text="missing collection")
94
94
if "rkey" not in request.query:
95
-
return web.HTTPBadRequest(text="missing rkey")
95
+
raise web.HTTPBadRequest(text="missing rkey")
96
96
97
97
# we don't stream the response because it should be compact-ish
98
98
car = repo_ops.get_record(
···
102
102
)
103
103
104
104
if car is None:
105
-
return web.HTTPNotFound(text="did or record not found")
105
+
raise web.HTTPNotFound(text="did or record not found")
106
106
107
107
return web.Response(
108
108
body=car,
···
114
114
async def sync_get_repo_status(request: web.Request):
115
115
did = request.query.get("did")
116
116
if did is None:
117
-
return web.HTTPBadRequest(text="no did specified")
117
+
raise web.HTTPBadRequest(text="no did specified")
118
118
row = get_db(request).con.execute(
119
119
"SELECT rev FROM user WHERE did=?",
120
120
(did,)
121
121
).fetchone()
122
122
if row is None:
123
-
return web.HTTPNotFound(text="did not found")
123
+
raise web.HTTPNotFound(text="did not found")
124
124
return web.json_response({
125
125
"did": did,
126
126
"active": True,
···
132
132
async def sync_get_repo(request: web.Request):
133
133
did = request.query.get("did")
134
134
if did is None:
135
-
return web.HTTPBadRequest(text="no did specified")
135
+
raise web.HTTPBadRequest(text="no did specified")
136
136
since = request.query.get("since", "") # empty string is "lowest" possible value wrt string comparison
137
137
138
138
# TODO: do bad things happen if a client holds the connection open for a long time?
···
143
143
(did,)
144
144
).fetchone()
145
145
except TypeError: # from trying to unpack None
146
-
return web.HTTPNotFound(text="repo not found")
146
+
raise web.HTTPNotFound(text="repo not found")
147
147
148
148
res = web.StreamResponse()
149
149
res.content_type = "application/vnd.ipld.car"
···
170
170
async def sync_list_blobs(request: web.Request):
171
171
did = request.query.get("did")
172
172
if did is None:
173
-
return web.HTTPBadRequest(text="no did specified")
173
+
raise web.HTTPBadRequest(text="no did specified")
174
174
since = request.query.get("since", "") # empty string is "lowest" possible value wrt string comparison
175
175
limit = int(request.query.get("limit", 500))
176
176
if limit < 1 or limit > 1000:
177
-
return web.HTTPBadRequest(text="limit out of range")
177
+
raise web.HTTPBadRequest(text="limit out of range")
178
178
cursor = int(request.query.get("cursor", 0))
179
179
180
180
cids = []
+8
-4
src/millipds/auth_bearer.py
+8
-4
src/millipds/auth_bearer.py
···
9
9
10
10
routes = web.RouteTableDef()
11
11
def authenticated(handler):
12
-
def authentication_handler(request: web.Request, *args, **kwargs):
12
+
async def authentication_handler(request: web.Request, *args, **kwargs):
13
13
# extract the auth token
14
14
auth = request.headers.get("Authorization")
15
15
if auth is None:
···
29
29
key=db.config["jwt_access_secret"],
30
30
algorithms=["HS256"],
31
31
audience=db.config["pds_did"],
32
-
require=["exp", "scope"], # consider iat?
33
-
strict_aud=True,
32
+
options={
33
+
"require": ["exp", "iat", "scope"], # consider iat?
34
+
"verify_exp": True,
35
+
"verify_iat": True,
36
+
"strict_aud": True, # may be unnecessary
37
+
}
34
38
)
35
39
except jwt.exceptions.PyJWTError:
36
40
raise web.HTTPUnauthorized(text="invalid jwt")
···
43
47
if not subject.startswith("did:"):
44
48
raise web.HTTPUnauthorized(text="invalid jwt: invalid subject")
45
49
request["authed_did"] = subject
46
-
return handler(request, *args, **kwargs)
50
+
return await handler(request, *args, **kwargs)
47
51
48
52
return authentication_handler
+3
-3
src/millipds/auth_oauth.py
+3
-3
src/millipds/auth_oauth.py
···
114
114
# TODO: verify iat?, iss?
115
115
116
116
if request.method != decoded["htm"]:
117
-
return web.HTTPBadRequest(
117
+
raise web.HTTPBadRequest(
118
118
text="dpop: bad htm"
119
119
)
120
120
121
121
if str(request.url) != decoded["htu"]:
122
122
logger.info(f"{request.url!r} != {decoded['htu']!r}")
123
-
return web.HTTPBadRequest(
123
+
raise web.HTTPBadRequest(
124
124
text="dpop: bad htu (if your application is reverse-proxied, make sure the Host header is getting set properly)"
125
125
)
126
126
127
127
if decoded.get("nonce") != DPOP_NONCE:
128
-
return web.HTTPBadRequest(
128
+
raise web.HTTPBadRequest(
129
129
body=json.dumps({
130
130
"error": "use_dpop_nonce",
131
131
"error_description": "Authorization server requires nonce in DPoP proof"
-3
src/millipds/database.py
-3
src/millipds/database.py
···
314
314
"INSERT INTO mst(repo, cid, since, value) VALUES (?, ?, ?, ?)",
315
315
(user_id, bytes(empty_mst.cid), tid, empty_mst.serialised),
316
316
)
317
-
#util.mkdirs_for_file(repo_path)
318
-
#UserDatabase.init_tables(self.con, did, repo_path, tid)
319
-
#self.con.execute("DETACH spoke")
320
317
321
318
def verify_account_login(
322
319
self, did_or_handle: str, password: str
+14
-14
src/millipds/service.py
+14
-14
src/millipds/service.py
···
144
144
# TODO: forward to appview(?) if we can't answer?
145
145
handle = request.query.get("handle")
146
146
if handle is None:
147
-
return web.HTTPBadRequest(text="missing or invalid handle")
147
+
raise web.HTTPBadRequest(text="missing or invalid handle")
148
148
did = get_db(request).did_by_handle(handle)
149
149
if not did:
150
-
return web.HTTPNotFound(text="no user by that handle exists on this PDS")
150
+
raise web.HTTPNotFound(text="no user by that handle exists on this PDS")
151
151
return web.json_response({"did": did})
152
152
153
153
···
171
171
try:
172
172
req_json: dict = await request.json()
173
173
except json.JSONDecodeError:
174
-
return web.HTTPBadRequest(text="expected JSON")
174
+
raise web.HTTPBadRequest(text="expected JSON")
175
175
176
176
identifier = req_json.get("identifier")
177
177
password = req_json.get("password")
178
178
if not (isinstance(identifier, str) and isinstance(password, str)):
179
-
return web.HTTPBadRequest(text="invalid identifier or password")
179
+
raise web.HTTPBadRequest(text="invalid identifier or password")
180
180
181
181
# do authentication
182
182
db = get_db(request)
···
234
234
req_json: dict = await request.json()
235
235
handle = req_json.get("handle")
236
236
if handle is None:
237
-
return web.HTTPBadRequest(text="missing or invalid handle")
237
+
raise web.HTTPBadRequest(text="missing or invalid handle")
238
238
# TODO: actually validate it, and update the db!!!
239
239
# (I'm writing this half-baked version just so I can send firehose #identity events)
240
240
with get_db(request).new_con() as con:
···
291
291
)
292
292
293
293
294
-
def construct_app(routes, db: database.Database) -> web.Application:
294
+
def construct_app(routes, db: database.Database, client: aiohttp.ClientSession) -> web.Application:
295
295
cors = cors_middleware( # TODO: review and reduce scope - and maybe just /xrpc/*?
296
296
allow_all=True,
297
297
expose_headers=["*"],
···
300
300
allow_credentials=True,
301
301
max_age=100_000_000
302
302
)
303
+
304
+
client.headers.update({"User-Agent": importlib.metadata.version("millipds")})
303
305
304
306
app = web.Application(middlewares=[cors, atproto_service_proxy_middleware])
305
-
app["MILLIPDS_DB"] = db
306
-
app["MILLIPDS_AIOHTTP_CLIENT"] = (
307
-
aiohttp.ClientSession()
308
-
) # should this be dependency-injected?
309
-
app["MILLIPDS_FIREHOSE_QUEUES"] = set()
310
-
app["MILLIPDS_FIREHOSE_QUEUES_LOCK"] = asyncio.Lock()
307
+
app[MILLIPDS_DB] = db
308
+
app[MILLIPDS_AIOHTTP_CLIENT] = client
309
+
app[MILLIPDS_FIREHOSE_QUEUES] = set()
310
+
app[MILLIPDS_FIREHOSE_QUEUES_LOCK] = asyncio.Lock()
311
311
app.add_routes(routes)
312
312
app.add_routes(auth_oauth.routes)
313
313
app.add_routes(atproto_sync.routes)
···
358
358
return app
359
359
360
360
361
-
async def run(db: database.Database, sock_path: Optional[str], host: str, port: int):
361
+
async def run(db: database.Database, client: aiohttp.ClientSession, sock_path: Optional[str], host: str, port: int):
362
362
"""
363
363
This gets invoked via millipds.__main__.py
364
364
"""
365
365
366
-
app = construct_app(routes, db)
366
+
app = construct_app(routes, db, client)
367
367
runner = web.AppRunner(app, access_log_format=static_config.HTTP_LOG_FMT)
368
368
await runner.setup()
369
369
+57
-53
tests/integration_test.py
+57
-53
tests/integration_test.py
···
6
6
import pytest
7
7
import dataclasses
8
8
import aiohttp
9
+
import aiohttp.web
9
10
10
11
from millipds import service
11
12
from millipds import database
12
13
from millipds import crypto
13
14
14
15
@dataclasses.dataclass
15
-
class TestPDS:
16
+
class PDSInfo:
16
17
endpoint: str
17
18
db: database.Database
18
19
19
20
old_web_tcpsite_start = aiohttp.web.TCPSite.start
20
21
21
-
def make_capture_random_bound_port_web_tcpsite_startstart(queue):
22
-
async def mock_start(site, *args, **kwargs):
22
+
def make_capture_random_bound_port_web_tcpsite_startstart(queue: asyncio.Queue):
23
+
async def mock_start(site: aiohttp.web.TCPSite, *args, **kwargs):
23
24
nonlocal queue
24
25
await old_web_tcpsite_start(site, *args, **kwargs)
25
26
await queue.put(site._server.sockets[0].getsockname()[1])
26
27
return mock_start
27
28
28
-
async def service_run_and_capture_port(queue, **kwargs):
29
+
async def service_run_and_capture_port(queue: asyncio.Queue, **kwargs):
29
30
mock_start = make_capture_random_bound_port_web_tcpsite_startstart(queue)
30
31
with unittest.mock.patch.object(aiohttp.web.TCPSite, "start", new=mock_start):
31
32
await service.run(**kwargs)
···
44
45
async def test_pds(aiolib):
45
46
queue = asyncio.Queue()
46
47
with tempfile.TemporaryDirectory() as tempdir:
47
-
db_path = f"{tempdir}/millipds-0000.db"
48
-
db = database.Database(path=db_path)
49
-
50
-
hostname = "localhost:0"
51
-
db.update_config(
52
-
pds_pfx=f'http://{hostname}',
53
-
pds_did=f'did:web:{urllib.parse.quote(hostname)}',
54
-
bsky_appview_pfx="https://api.bsky.app",
55
-
bsky_appview_did="did:web:api.bsky.app",
56
-
)
48
+
async with aiohttp.ClientSession() as client:
49
+
db_path = f"{tempdir}/millipds-0000.db"
50
+
db = database.Database(path=db_path)
57
51
58
-
service_run_task = asyncio.create_task(
59
-
service_run_and_capture_port(
60
-
queue,
61
-
db=db,
62
-
sock_path=None,
63
-
host="localhost",
64
-
port=0,
52
+
hostname = "localhost:0"
53
+
db.update_config(
54
+
pds_pfx=f'http://{hostname}',
55
+
pds_did=f'did:web:{urllib.parse.quote(hostname)}',
56
+
bsky_appview_pfx="https://api.bsky.app",
57
+
bsky_appview_did="did:web:api.bsky.app",
65
58
)
66
-
)
67
-
queue_get_task = asyncio.create_task(queue.get())
68
-
done, pending = await asyncio.wait(
69
-
(queue_get_task, service_run_task),
70
-
return_when=asyncio.FIRST_COMPLETED,
71
-
)
72
-
if done == service_run_task:
73
-
raise service_run_task.execption()
74
-
else:
75
-
port = queue_get_task.result()
76
59
77
-
hostname = f"localhost:{port}"
78
-
db.update_config(
79
-
pds_pfx=f'http://{hostname}',
80
-
pds_did=f'did:web:{urllib.parse.quote(hostname)}',
81
-
bsky_appview_pfx="https://api.bsky.app",
82
-
bsky_appview_did="did:web:api.bsky.app",
83
-
)
84
-
db.create_account(
85
-
did=TEST_DID,
86
-
handle=TEST_HANDLE,
87
-
password=TEST_PASSWORD,
88
-
privkey=TEST_PRIVKEY,
89
-
)
60
+
service_run_task = asyncio.create_task(
61
+
service_run_and_capture_port(
62
+
queue,
63
+
db=db,
64
+
client=client,
65
+
sock_path=None,
66
+
host="localhost",
67
+
port=0,
68
+
)
69
+
)
70
+
queue_get_task = asyncio.create_task(queue.get())
71
+
done, pending = await asyncio.wait(
72
+
(queue_get_task, service_run_task),
73
+
return_when=asyncio.FIRST_COMPLETED,
74
+
)
75
+
if done == service_run_task:
76
+
raise service_run_task.execption()
77
+
else:
78
+
port = queue_get_task.result()
90
79
91
-
try:
92
-
yield TestPDS(
93
-
endpoint=f"http://{hostname}",
94
-
db=db,
80
+
hostname = f"localhost:{port}"
81
+
db.update_config(
82
+
pds_pfx=f'http://{hostname}',
83
+
pds_did=f'did:web:{urllib.parse.quote(hostname)}',
84
+
bsky_appview_pfx="https://api.bsky.app",
85
+
bsky_appview_did="did:web:api.bsky.app",
95
86
)
96
-
finally:
97
-
service_run_task.cancel()
87
+
db.create_account(
88
+
did=TEST_DID,
89
+
handle=TEST_HANDLE,
90
+
password=TEST_PASSWORD,
91
+
privkey=TEST_PRIVKEY,
92
+
)
93
+
98
94
try:
99
-
await service_run_task
100
-
except asyncio.CancelledError:
101
-
pass
95
+
yield PDSInfo(
96
+
endpoint=f"http://{hostname}",
97
+
db=db,
98
+
)
99
+
finally:
100
+
db.con.close()
101
+
service_run_task.cancel()
102
+
try:
103
+
await service_run_task
104
+
except asyncio.CancelledError:
105
+
pass
102
106
103
107
@pytest.fixture
104
108
async def s(aiolib):