A from-scratch atproto PDS implementation in Python (mirrors https://github.com/DavidBuchanan314/millipds)
10
fork

Configure Feed

Select the types of activity you want to include in your feed.

Merge pull request #38 from DavidBuchanan314/token-revocation

refreshSession, deleteSession

authored by retr0.id and committed by

GitHub 0b061d0d 9d467bdc

+304 -74
+8 -2
migration_scripts/v2.py
··· 7 7 8 8 from millipds import static_config 9 9 10 - with apsw.Connection(static_config.MAIN_DB_PATH) as con: 10 + 11 + def migrate(con): 11 12 version_now, *_ = con.execute("SELECT db_version FROM config").fetchone() 12 13 13 14 assert version_now == 1 ··· 36 37 37 38 con.execute("UPDATE config SET db_version=2") 38 39 39 - print("v1 -> v2 Migration successful") 40 + 41 + if __name__ == "__main__": 42 + with apsw.Connection(static_config.MAIN_DB_PATH) as con: 43 + migrate(con) 44 + 45 + print("v1 -> v2 Migration successful")
+34
migration_scripts/v3.py
··· 1 + # TODO: some smarter way of handling migrations 2 + 3 + import apsw 4 + import apsw.bestpractice 5 + 6 + apsw.bestpractice.apply(apsw.bestpractice.recommended) 7 + 8 + from millipds import static_config 9 + 10 + 11 + def migrate(con: apsw.Connection): 12 + version_now, *_ = con.execute("SELECT db_version FROM config").fetchone() 13 + 14 + assert version_now == 2 15 + 16 + con.execute( 17 + """ 18 + CREATE TABLE revoked_token( 19 + did TEXT NOT NULL, 20 + jti TEXT NOT NULL, 21 + expires_at INTEGER NOT NULL, 22 + PRIMARY KEY (did, jti) 23 + ) STRICT, WITHOUT ROWID 24 + """ 25 + ) 26 + 27 + con.execute("UPDATE config SET db_version=3") 28 + 29 + 30 + if __name__ == "__main__": 31 + with apsw.Connection(static_config.MAIN_DB_PATH) as con: 32 + migrate(con) 33 + 34 + print("v2 -> v3 Migration successful")
+7 -5
src/millipds/appview_proxy.py
··· 29 29 ) 30 30 if did_doc is None: 31 31 return web.HTTPInternalServerError( 32 - f"unable to resolve service {service!r}" 32 + text=f"unable to resolve service {service!r}" 33 33 ) 34 - for service in did_doc.get("service", []): 35 - if service.get("id") == fragment: 36 - service_route = service["serviceEndpoint"] 34 + for service_info in did_doc.get("service", []): 35 + if service_info.get("id") == fragment: 36 + service_route = service_info["serviceEndpoint"] 37 37 break 38 38 else: 39 - return web.HTTPBadRequest(f"unable to resolve service {service!r}") 39 + return web.HTTPBadRequest( 40 + text=f"unable to resolve service {service!r}" 41 + ) 40 42 else: # fall thru to assuming bsky appview 41 43 service_did = db.config["bsky_appview_did"] 42 44 service_route = db.config["bsky_appview_pfx"]
+50 -25
src/millipds/auth_bearer.py
··· 11 11 routes = web.RouteTableDef() 12 12 13 13 14 + def verify_symmetric_token( 15 + request: web.Request, token: str, expected_scope: str 16 + ) -> dict: 17 + db = get_db(request) 18 + try: 19 + payload: dict = jwt.decode( 20 + jwt=token, 21 + key=db.config["jwt_access_secret"], 22 + algorithms=["HS256"], 23 + audience=db.config["pds_did"], 24 + options={ 25 + "require": ["exp", "iat", "scope", "jti", "sub"], 26 + "verify_exp": True, 27 + "verify_iat": True, 28 + "strict_aud": True, # may be unnecessary 29 + }, 30 + ) 31 + except jwt.exceptions.PyJWTError: 32 + raise web.HTTPUnauthorized(text="invalid jwt") 33 + 34 + revoked = db.con.execute( 35 + "SELECT COUNT(*) FROM revoked_token WHERE did=? AND jti=?", 36 + (payload["sub"], payload["jti"]), 37 + ).fetchone()[0] 38 + 39 + if revoked: 40 + raise web.HTTPUnauthorized(text="revoked token") 41 + 42 + # if we reached this far, the payload must've been signed by us 43 + if payload.get("scope") != expected_scope: 44 + raise web.HTTPUnauthorized(text="invalid jwt scope") 45 + 46 + if not payload.get("sub", "").startswith("did:"): 47 + raise web.HTTPUnauthorized(text="invalid jwt: invalid subject") 48 + 49 + return payload 50 + 51 + 14 52 def authenticated(handler): 15 53 """ 16 54 There are three types of auth: ··· 39 77 ) 40 78 # logger.info(unverified) 41 79 if unverified["header"]["alg"] == "HS256": # symmetric secret 42 - try: 43 - payload: dict = jwt.decode( 44 - jwt=token, 45 - key=db.config["jwt_access_secret"], 46 - algorithms=["HS256"], 47 - audience=db.config["pds_did"], 48 - options={ 49 - "require": ["exp", "iat", "scope"], # consider iat? 50 - "verify_exp": True, 51 - "verify_iat": True, 52 - "strict_aud": True, # may be unnecessary 53 - }, 54 - ) 55 - except jwt.exceptions.PyJWTError: 56 - raise web.HTTPUnauthorized(text="invalid jwt") 57 - 58 - # if we reached this far, the payload must've been signed by us 59 - if payload.get("scope") != "com.atproto.access": 60 - raise web.HTTPUnauthorized(text="invalid jwt scope") 61 - 62 - subject: str = payload.get("sub", "") 63 - if not subject.startswith("did:"): 64 - raise web.HTTPUnauthorized(text="invalid jwt: invalid subject") 65 - request["authed_did"] = subject 80 + request["authed_did"] = verify_symmetric_token( 81 + request, token, "com.atproto.access" 82 + )["sub"] 66 83 else: # asymmetric service auth (scoped to a specific lxm) 67 84 did: str = unverified["payload"]["iss"] 68 85 if not did.startswith("did:"): ··· 81 98 algorithms=[alg], 82 99 audience=db.config["pds_did"], 83 100 options={ 84 - "require": ["exp", "iat", "lxm"], 101 + "require": ["exp", "iat", "lxm", "jti", "iss"], 85 102 "verify_exp": True, 86 103 "verify_iat": True, 87 104 "strict_aud": True, # may be unnecessary ··· 89 106 ) 90 107 except jwt.exceptions.PyJWTError: 91 108 raise web.HTTPUnauthorized(text="invalid jwt") 109 + 110 + revoked = db.con.execute( 111 + "SELECT COUNT(*) FROM revoked_token WHERE did=? AND jti=?", 112 + (payload["iss"], payload["jti"]), 113 + ).fetchone()[0] 114 + 115 + if revoked: 116 + raise web.HTTPUnauthorized(text="revoked token") 92 117 93 118 request_lxm = request.path.rpartition("/")[2].partition("?")[0] 94 119 if request_lxm != payload.get("lxm"):
+13
src/millipds/database.py
··· 245 245 """ 246 246 ) 247 247 248 + # this is only for the tokens *we* issue, dpop jti will be tracked separately 249 + # there's no point remembering that an expired token was revoked, and we'll garbage-collect these periodically 250 + self.con.execute( 251 + """ 252 + CREATE TABLE revoked_token( 253 + did TEXT NOT NULL, 254 + jti TEXT NOT NULL, 255 + expires_at INTEGER NOT NULL, 256 + PRIMARY KEY (did, jti) 257 + ) STRICT, WITHOUT ROWID 258 + """ 259 + ) 260 + 248 261 def update_config( 249 262 self, 250 263 pds_pfx: Optional[str] = None,
+89 -41
src/millipds/service.py
··· 25 25 from . import crypto 26 26 from . import util 27 27 from .appview_proxy import service_proxy 28 - from .auth_bearer import authenticated 28 + from .auth_bearer import authenticated, verify_symmetric_token 29 29 from .app_util import * 30 30 from .did import DIDResolver 31 31 ··· 203 203 ) 204 204 205 205 206 + def session_info(request: web.Request) -> dict: 207 + return { 208 + "handle": get_db(request).handle_by_did(request["authed_did"]), 209 + "did": request["authed_did"], 210 + "email": "tfw_no@email.invalid", # this and below are just here for testing lol 211 + "emailConfirmed": True, 212 + # "didDoc": {}, # iiuc this is only used for entryway usecase? 213 + } 214 + 215 + 216 + def generate_session_tokens(request: web.Request) -> dict: 217 + db = get_db(request) 218 + unix_seconds_now = int(time.time()) 219 + # use the same jti for both tokens, so revoking one revokes both 220 + jti = str(uuid.uuid4()) 221 + access_jwt = jwt.encode( 222 + { 223 + "scope": "com.atproto.access", 224 + "aud": db.config["pds_did"], 225 + "sub": request["authed_did"], 226 + "iat": unix_seconds_now, 227 + "exp": unix_seconds_now + static_config.ACCESS_EXP, 228 + "jti": jti, 229 + }, 230 + db.config["jwt_access_secret"], 231 + "HS256", 232 + ) 233 + 234 + refresh_jwt = jwt.encode( 235 + { 236 + "scope": "com.atproto.refresh", 237 + "aud": db.config["pds_did"], 238 + "sub": request["authed_did"], 239 + "iat": unix_seconds_now, 240 + "exp": unix_seconds_now + static_config.REFRESH_EXP, 241 + "jti": jti, 242 + }, 243 + db.config["jwt_access_secret"], 244 + "HS256", 245 + ) 246 + 247 + return { 248 + "accessJwt": access_jwt, 249 + "refreshJwt": refresh_jwt, 250 + } 251 + 252 + 206 253 # TODO: ratelimit this!!! 207 254 @routes.post("/xrpc/com.atproto.server.createSession") 208 255 async def server_create_session(request: web.Request): ··· 228 275 except ValueError: 229 276 raise web.HTTPUnauthorized(text="incorrect identifier or password") 230 277 231 - # prepare access tokens 232 - unix_seconds_now = int(time.time()) 233 - access_jwt = jwt.encode( 234 - { 235 - "scope": "com.atproto.access", 236 - "aud": db.config["pds_did"], 237 - "sub": did, 238 - "iat": unix_seconds_now, 239 - "exp": unix_seconds_now + 60 * 60 * 24, # 24h 240 - "jti": str(uuid.uuid4()), 241 - }, 242 - db.config["jwt_access_secret"], 243 - "HS256", 278 + # both generate_session_tokens and session_info need this 279 + request["authed_did"] = did 280 + 281 + return web.json_response( 282 + session_info(request) | generate_session_tokens(request) 244 283 ) 245 284 246 - refresh_jwt = jwt.encode( 247 - { 248 - "scope": "com.atproto.refresh", 249 - "aud": db.config["pds_did"], 250 - "sub": did, 251 - "iat": unix_seconds_now, 252 - "exp": unix_seconds_now + 60 * 60 * 24 * 90, # 90 days! 253 - "jti": str(uuid.uuid4()), 254 - }, 255 - db.config["jwt_access_secret"], 256 - "HS256", 285 + 286 + @routes.post("/xrpc/com.atproto.server.refreshSession") 287 + async def server_refresh_session(request: web.Request): 288 + auth = request.headers.get("Authorization", "") 289 + if not auth.startswith("Bearer "): 290 + raise web.HTTPUnauthorized(text="invalid auth type") 291 + token = auth.removeprefix("Bearer ") 292 + token_payload = verify_symmetric_token( 293 + request, token, "com.atproto.refresh" 257 294 ) 295 + request["authed_did"] = token_payload["sub"] 258 296 297 + get_db(request).con.execute( 298 + "INSERT INTO revoked_token (did, jti, expires_at) VALUES (?, ?, ?)", 299 + (token_payload["sub"], token_payload["jti"], token_payload["exp"]), 300 + ) 259 301 return web.json_response( 260 - { 261 - "did": did, 262 - "handle": handle, 263 - "accessJwt": access_jwt, 264 - "refreshJwt": refresh_jwt, 265 - } 302 + session_info(request) | generate_session_tokens(request) 303 + ) 304 + 305 + 306 + # NOTE: deleteSession requires refresh token as auth, not access token 307 + @routes.post("/xrpc/com.atproto.server.deleteSession") 308 + async def server_delete_session(request: web.Request): 309 + auth = request.headers.get("Authorization", "") 310 + if not auth.startswith("Bearer "): 311 + raise web.HTTPUnauthorized(text="invalid auth type") 312 + token = auth.removeprefix("Bearer ") 313 + token_payload = verify_symmetric_token( 314 + request, token, "com.atproto.refresh" 266 315 ) 316 + 317 + get_db(request).con.execute( 318 + "INSERT INTO revoked_token (did, jti, expires_at) VALUES (?, ?, ?)", 319 + (token_payload["sub"], token_payload["jti"], token_payload["exp"]), 320 + ) 321 + 322 + return web.Response() 267 323 268 324 269 325 @routes.get("/xrpc/com.atproto.server.getServiceAuth") ··· 302 358 "lxm": lxm, 303 359 "exp": exp, 304 360 "iat": now, 305 - "jti": str(uuid.uuid4()) 361 + "jti": str(uuid.uuid4()), 306 362 }, 307 363 signing_key, 308 364 algorithm=crypto.jwt_signature_alg_for_pem(signing_key), ··· 381 437 @routes.get("/xrpc/com.atproto.server.getSession") 382 438 @authenticated 383 439 async def server_get_session(request: web.Request): 384 - return web.json_response( 385 - { 386 - "handle": get_db(request).handle_by_did(request["authed_did"]), 387 - "did": request["authed_did"], 388 - "email": "tfw_no@email.invalid", # this and below are just here for testing lol 389 - "emailConfirmed": True, 390 - # "didDoc": {}, # iiuc this is only used for entryway usecase? 391 - } 392 - ) 440 + return web.json_response(session_info(request)) 393 441 394 442 395 443 def construct_app(
+4 -1
src/millipds/static_config.py
··· 11 11 GROUPNAME = "millipds-sock" 12 12 13 13 # this gets bumped if we make breaking changes to the db schema 14 - MILLIPDS_DB_VERSION = 2 14 + MILLIPDS_DB_VERSION = 3 15 15 16 16 ATPROTO_REPO_VERSION_3 = 3 # might get bumped if the atproto spec changes 17 17 CAR_VERSION_1 = 1 ··· 29 29 DID_CACHE_ERROR_TTL = 60 * 5 # 5 mins 30 30 31 31 PLC_DIRECTORY_HOST = "https://plc.directory" 32 + 33 + ACCESS_EXP = 60 * 60 * 2 # 2 h 34 + REFRESH_EXP = 60 * 60 * 24 * 90 # 90 days
+99
tests/integration_test.py
··· 394 394 ) as r: 395 395 assert r.status == 200 396 396 await r.json() 397 + 398 + 399 + async def test_refreshSession(s, pds_host): 400 + async with s.post( 401 + pds_host + "/xrpc/com.atproto.server.createSession", 402 + json=valid_logins[0], 403 + ) as r: 404 + assert r.status == 200 405 + r = await r.json() 406 + orig_session_token = r["accessJwt"] 407 + orig_refresh_token = r["refreshJwt"] 408 + 409 + # can't refresh using the session token 410 + async with s.post( 411 + pds_host + "/xrpc/com.atproto.server.refreshSession", 412 + headers={"Authorization": "Bearer " + orig_session_token}, 413 + ) as r: 414 + assert r.status != 200 415 + 416 + # correctly refresh using the refresh token 417 + async with s.post( 418 + pds_host + "/xrpc/com.atproto.server.refreshSession", 419 + headers={"Authorization": "Bearer " + orig_refresh_token}, 420 + ) as r: 421 + assert r.status == 200 422 + r = await r.json() 423 + new_session_token = r["accessJwt"] 424 + new_refresh_token = r["refreshJwt"] 425 + 426 + # test if the new session token works 427 + async with s.get( 428 + pds_host + "/xrpc/com.atproto.server.getSession", 429 + headers={"Authorization": "Bearer " + new_session_token}, 430 + ) as r: 431 + assert r.status == 200 432 + await r.json() 433 + 434 + # test that the old session token is invalid 435 + # XXX: in the future we might relax this behaviour 436 + async with s.get( 437 + pds_host + "/xrpc/com.atproto.server.getSession", 438 + headers={"Authorization": "Bearer " + orig_session_token}, 439 + ) as r: 440 + assert r.status != 200 441 + 442 + # test that the old refresh token is invalid 443 + async with s.post( 444 + pds_host + "/xrpc/com.atproto.server.refreshSession", 445 + headers={"Authorization": "Bearer " + orig_refresh_token}, 446 + ) as r: 447 + assert r.status != 200 448 + 449 + 450 + async def test_deleteSession(s, pds_host): 451 + async with s.post( 452 + pds_host + "/xrpc/com.atproto.server.createSession", 453 + json=valid_logins[0], 454 + ) as r: 455 + assert r.status == 200 456 + r = await r.json() 457 + session_token = r["accessJwt"] 458 + refresh_token = r["refreshJwt"] 459 + 460 + # sanity-check that the session token currently works 461 + async with s.get( 462 + pds_host + "/xrpc/com.atproto.server.getSession", 463 + headers={"Authorization": "Bearer " + session_token}, 464 + ) as r: 465 + assert r.status == 200 466 + await r.json() 467 + 468 + # can't delete using the session token 469 + async with s.post( 470 + pds_host + "/xrpc/com.atproto.server.deleteSession", 471 + headers={"Authorization": "Bearer " + session_token}, 472 + ) as r: 473 + assert r.status != 200 474 + 475 + # can delete using the refresh token 476 + async with s.post( 477 + pds_host + "/xrpc/com.atproto.server.deleteSession", 478 + headers={"Authorization": "Bearer " + refresh_token}, 479 + ) as r: 480 + assert r.status == 200 481 + 482 + # test that the session token is invalid now 483 + # XXX: in the future we might relax this behaviour 484 + async with s.get( 485 + pds_host + "/xrpc/com.atproto.server.getSession", 486 + headers={"Authorization": "Bearer " + session_token}, 487 + ) as r: 488 + assert r.status != 200 489 + 490 + # test that the refresh token is invalid too 491 + async with s.post( 492 + pds_host + "/xrpc/com.atproto.server.refreshSession", 493 + headers={"Authorization": "Bearer " + refresh_token}, 494 + ) as r: 495 + assert r.status != 200