···22from typing import Any
33import json
4455-from .atproto2 import PdsUrl, get_record, resolve_did_from_handle, resolve_pds_from_did
66-from .atproto2.atproto_oauth import pds_authed_req
55+from .atproto import PdsUrl, get_record, resolve_did_from_handle, resolve_pds_from_did
66+from .atproto.atproto_oauth import pds_authed_req
77from .db import close_db_connection, get_db, init_db
88from .oauth import oauth
99+from .types import OAuthSession
9101011app = Flask(__name__)
1112_ = app.config.from_prefixed_env()
···20212122@app.before_request
2223def load_user_to_context():
2424+ user: OAuthSession | None = None
2325 did: str | None = session.get("user_did")
2424- if did is None:
2525- g.user = None
2626- else:
2626+ if did is not None:
2727 db = get_db(app)
2828- g.user = db.execute(
2828+ row = db.execute(
2929 "select * from oauth_session where did = ?",
3030 (did,),
3131 ).fetchone()
3232+ user = OAuthSession(**row)
3333+ g.user = user
323433353434-def get_user() -> dict[str, str] | None:
3636+def get_user() -> OAuthSession | None:
3537 return g.user
36383739···9092 if user is None:
9193 return redirect("/login")
92949393- did: str = user["did"]
9494- pds: str = user["pds_url"]
9595- handle: str | None = user["handle"]
9595+ did: str = user.did
9696+ pds: str = user.pds_url
9797+ handle: str | None = user.handle
96989799 profile, from_bluesky = load_profile(pds, did, reload=True)
98100 links = load_links(pds, did, reload=True) or [{"background": "#fa0"}]
···119121120122 put_record(
121123 user=user,
122122- pds=user["pds_url"],
123123- repo=user["did"],
124124+ pds=user.pds_url,
125125+ repo=user.did,
124126 collection=f"{SCHEMA}.actor.profile",
125127 rkey="self",
126128 record={
···159161160162 put_record(
161163 user=user,
162162- pds=user["pds_url"],
163163- repo=user["did"],
164164+ pds=user.pds_url,
165165+ repo=user.did,
164166 collection=f"{SCHEMA}.actor.links",
165167 rkey="self",
166168 record={
···212214213215214216def put_record(
215215- user: dict[str, str],
217217+ user: OAuthSession,
216218 pds: PdsUrl,
217219 repo: str,
218220 collection: str,
···246248 if user is not None:
247249 db = get_db(app)
248250 cursor = db.cursor()
249249- _ = cursor.execute("delete from oauth_session where did = ?", (user["did"],))
251251+ _ = cursor.execute("delete from oauth_session where did = ?", (user.did,))
250252 db.commit()
251253 cursor.close()
252254 session.clear()
+43-28
src/oauth.py
···4455import json
6677-from .atproto2.atproto_identity import is_valid_did, is_valid_handle
88-from .atproto2.atproto_oauth import initial_token_request, send_par_auth_request
99-from .atproto2.atproto_security import is_safe_url
1010-from .atproto2 import (
77+from .atproto.atproto_identity import is_valid_did, is_valid_handle
88+from .atproto.atproto_oauth import initial_token_request, send_par_auth_request
99+from .atproto.atproto_security import is_safe_url
1010+from .atproto import (
1111 pds_endpoint_from_doc,
1212 resolve_authserver_from_pds,
1313 resolve_authserver_meta,
1414 resolve_identity,
1515)
1616+from .types import OAuthAuthRequest
1617from .db import get_db
17181819oauth = Blueprint("oauth", __name__, url_prefix="/oauth")
1919-2020-2121-oauth_auth_requests: dict[str, dict[str, str]] = {}
222023212422@oauth.get("/start")
···87858886 par_request_uri: str = resp.json()["request_uri"]
8987 current_app.logger.debug(f"saving oauth_auth_request to DB state={state}")
9090- oauth_auth_requests[state] = {
9191- "authserver_iss": authserver_meta["issuer"],
9292- "did": did or "", # TODO: use actual typing
9393- "handle": handle or "",
9494- "pds_url": pds_url or "",
9595- "pkce_verifier": pkce_verifier,
9696- "scope": scope,
9797- "dpop_authserver_nonce": dpop_authserver_nonce,
9898- "dpop_private_jwk": dpop_private_jwk.as_json(is_private=True),
9999- }
8888+8989+ db = get_db(current_app)
9090+ cursor = db.cursor()
9191+ _ = cursor.execute(
9292+ "insert or replace into oauth_auth_requests values (?, ?, ?, ?, ?, ?, ?, ?, ?)",
9393+ (
9494+ state,
9595+ authserver_meta["issuer"],
9696+ did,
9797+ handle,
9898+ pds_url,
9999+ pkce_verifier,
100100+ scope,
101101+ dpop_authserver_nonce,
102102+ dpop_private_jwk.as_json(is_private=True),
103103+ ),
104104+ )
105105+ db.commit()
106106+ cursor.close()
100107101108 auth_endpoint = authserver_meta["authorization_endpoint"]
102109 assert is_safe_url(auth_endpoint)
···110117 authserver_iss = request.args["iss"]
111118 authorization_code = request.args["code"]
112119113113- auth_request = oauth_auth_requests.get(state)
114114- if auth_request is None:
120120+ db = get_db(current_app)
121121+ cursor = db.cursor()
122122+123123+ row = cursor.execute(
124124+ "select * from oauth_auth_requests where state = ?", (state,)
125125+ ).fetchone()
126126+ try:
127127+ auth_request = OAuthAuthRequest(**row)
128128+ except TypeError:
115129 return redirect(url_for("page_login"), 303)
116130117131 current_app.logger.debug(f"Deleting auth request for state={state}")
118118- _ = oauth_auth_requests.pop(state)
132132+ _ = cursor.execute("delete from oauth_auth_requests where state = ?", (state,))
133133+ db.commit()
119134120120- assert auth_request["authserver_iss"] == authserver_iss
121121- # assert state ????
135135+ assert auth_request.authserver_iss == authserver_iss
136136+ assert auth_request.state == state
122137123138 app_url = request.url_root.replace("http://", "https://")
124139 CLIENT_SECRET_JWK = JsonWebKey.import_key(current_app.config["CLIENT_SECRET_JWK"])
···131146132147 row = auth_request
133148134134- did = auth_request["did"]
135135- if row["did"]:
149149+ did = auth_request.did
150150+ if row.did:
136151 # If we started with an account identifier, this is simple
137137- did, handle, pds_url = row["did"], row["handle"], row["pds_url"]
152152+ did, handle, pds_url = row.did, row.handle, row.pds_url
138153 assert tokens["sub"] == did
139154 else:
140155 did = tokens["sub"]
···149164 authserver_url = resolve_authserver_from_pds(pds_url)
150165 assert authserver_url == authserver_iss
151166152152- assert row["scope"] == tokens["scope"]
167167+ assert row.scope == tokens["scope"]
153168154169 current_app.logger.debug("storing user did and handle")
155170 db = get_db(current_app)
···165180 tokens["refresh_token"],
166181 dpop_authserver_nonce,
167182 None,
168168- auth_request["dpop_private_jwk"],
183183+ auth_request.dpop_private_jwk,
169184 ),
170185 )
171186 db.commit()
172187 cursor.close()
173188174189 session["user_did"] = did
175175- session["user_handle"] = auth_request["handle"]
190190+ session["user_handle"] = auth_request.handle
176191177192 return redirect(url_for("page_login"))
178193
+13-1
src/schema.sql
···11-create table if not exists oauth_session (
11+create table if not exists oauth_auth_requests (
22+ state text not null primary key,
33+ authserver_iss text not null,
44+ did text,
55+ handle text,
66+ pds_url text,
77+ pkce_verifier text not null,
88+ scope text not null,
99+ dpop_authserver_nonce text not null,
1010+ dpop_private_jwk text not null
1111+) strict, without rowid;
1212+1313+create table if not exists oauth_sessions (
214 did text not null primary key,
315 handle text,
416 pds_url text not null,