+62
-65
src/atproto/__init__.py
+62
-65
src/atproto/__init__.py
···
1
-
import aiohttp
2
-
import aiodns
1
+
from aiodns import DNSResolver, error as dns_error
2
+
from aiohttp.client import ClientSession
3
+
from os import getenv
3
4
from re import match as regex_match
4
5
from typing import Any
5
6
···
7
8
from .validator import is_valid_authserver_meta
8
9
from ..security import is_safe_url
9
10
10
-
PLC_DIRECTORY = "https://plc.directory"
11
+
PLC_DIRECTORY = getenv("PLC_DIRECTORY_URL") or "https://plc.directory"
11
12
HANDLE_REGEX = r"^([a-zA-Z0-9]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?\.)+[a-zA-Z]([a-zA-Z0-9-]{0,61}[a-zA-Z0-9])?$"
12
13
DID_REGEX = r"^did:[a-z]+:[a-zA-Z0-9._:%-]*[a-zA-Z0-9._-]$"
13
14
14
15
15
-
AuthserverUrl = str
16
-
PdsUrl = str
17
-
DID = str
16
+
type AuthserverUrl = str
17
+
type PdsUrl = str
18
+
type DID = str
18
19
19
20
20
21
def is_valid_handle(handle: str) -> bool:
···
26
27
27
28
28
29
async def resolve_identity(
30
+
client: ClientSession,
29
31
query: str,
30
32
didkv: KV = nokv,
31
33
) -> tuple[str, str, dict[str, Any]] | None:
···
36
38
did = await resolve_did_from_handle(handle, didkv)
37
39
if not did:
38
40
return None
39
-
doc = await resolve_doc_from_did(did)
41
+
doc = await resolve_doc_from_did(client, did)
40
42
if not doc:
41
43
return None
42
-
handles = handles_from_doc(doc)
43
-
if not handles or handle not in handles:
44
+
doc_handle = handle_from_doc(doc)
45
+
if not doc_handle or doc_handle != handle:
44
46
return None
45
47
return (did, handle, doc)
46
48
47
49
if is_valid_did(query):
48
50
did = query
49
-
doc = await resolve_doc_from_did(did)
51
+
doc = await resolve_doc_from_did(client, did)
50
52
if not doc:
51
53
return None
52
54
handle = handle_from_doc(doc)
···
59
61
return None
60
62
61
63
62
-
def handles_from_doc(doc: dict[str, list[str]]) -> list[str]:
64
+
def handle_from_doc(doc: dict[str, list[str]]) -> str | None:
63
65
"""Return all possible handles inside the DID document."""
64
-
handles: list[str] = []
66
+
65
67
for aka in doc.get("alsoKnownAs", []):
66
68
if aka.startswith("at://"):
67
69
handle = aka[5:].lower()
68
70
if is_valid_handle(handle):
69
-
handles.append(handle)
70
-
return handles
71
-
72
-
73
-
def handle_from_doc(doc: dict[str, list[str]]) -> str | None:
74
-
"""Return the first handle inside the DID document."""
75
-
handles = handles_from_doc(doc)
76
-
try:
77
-
return handles[0]
78
-
except IndexError:
79
-
return None
71
+
return handle
72
+
return None
80
73
81
74
82
75
async def resolve_did_from_handle(
···
94
87
print(f"returning cached did for {handle}")
95
88
return did
96
89
97
-
resolver = aiodns.DNSResolver()
90
+
resolver = DNSResolver()
98
91
try:
99
92
result = await resolver.query(f"_atproto.{handle}", "TXT")
100
-
except aiodns.error.DNSError:
93
+
except dns_error.DNSError:
101
94
return None
102
95
103
96
for record in result:
···
122
115
123
116
124
117
async def resolve_pds_from_did(
118
+
client: ClientSession,
125
119
did: DID,
126
120
kv: KV = nokv,
127
121
reload: bool = False,
···
131
125
print(f"returning cached pds for {did}")
132
126
return pds
133
127
134
-
doc = await resolve_doc_from_did(did)
128
+
doc = await resolve_doc_from_did(client, did)
135
129
if doc is None:
136
130
return None
137
131
pds = doc["service"][0]["serviceEndpoint"]
···
143
137
144
138
145
139
async def resolve_doc_from_did(
140
+
client: ClientSession,
146
141
did: DID,
147
-
directory: str = PLC_DIRECTORY,
148
142
) -> dict[str, Any] | None:
149
-
async with aiohttp.ClientSession() as client:
150
-
if did.startswith("did:plc:"):
151
-
response = await client.get(f"{directory}/{did}")
152
-
if response.ok:
153
-
return await response.json()
154
-
return None
143
+
"""Returns the DID document"""
155
144
156
-
if did.startswith("did:web:"):
157
-
# TODO: resolve did:web
158
-
return None
145
+
if did.startswith("did:plc:"):
146
+
response = await client.get(f"{PLC_DIRECTORY}/{did}")
147
+
if response.ok:
148
+
return await response.json()
149
+
return None
150
+
151
+
if did.startswith("did:web:"):
152
+
# TODO: resolve did:web
153
+
raise Exception("resolve did:web")
159
154
160
155
return None
161
156
162
157
163
158
async def resolve_authserver_from_pds(
159
+
client: ClientSession,
164
160
pds_url: PdsUrl,
165
161
kv: KV = nokv,
166
162
reload: bool = False,
···
174
170
175
171
assert is_safe_url(pds_url)
176
172
endpoint = f"{pds_url}/.well-known/oauth-protected-resource"
177
-
async with aiohttp.ClientSession() as client:
178
-
response = await client.get(endpoint)
179
-
if response.status != 200:
180
-
return None
181
-
parsed: dict[str, list[str]] = await response.json()
182
-
authserver_url = parsed["authorization_servers"][0]
183
-
print(f"caching authserver {authserver_url} for PDS {pds_url}")
184
-
kv.set(pds_url, value=authserver_url)
185
-
return authserver_url
173
+
response = await client.get(endpoint)
174
+
if response.status != 200:
175
+
return None
176
+
parsed: dict[str, list[str]] = await response.json()
177
+
authserver_url = parsed["authorization_servers"][0]
178
+
print(f"caching authserver {authserver_url} for PDS {pds_url}")
179
+
kv.set(pds_url, value=authserver_url)
180
+
return authserver_url
186
181
187
182
188
-
async def fetch_authserver_meta(authserver_url: str) -> dict[str, str] | None:
183
+
async def fetch_authserver_meta(
184
+
client: ClientSession,
185
+
authserver_url: str,
186
+
) -> dict[str, str] | None:
189
187
"""Returns metadata from the authserver"""
188
+
190
189
assert is_safe_url(authserver_url)
191
190
endpoint = f"{authserver_url}/.well-known/oauth-authorization-server"
192
-
async with aiohttp.ClientSession() as client:
193
-
response = await client.get(endpoint)
194
-
if not response.ok:
195
-
return None
196
-
meta: dict[str, Any] = await response.json()
197
-
assert is_valid_authserver_meta(meta, authserver_url)
198
-
return meta
191
+
response = await client.get(endpoint)
192
+
if not response.ok:
193
+
return None
194
+
meta: dict[str, Any] = await response.json()
195
+
assert is_valid_authserver_meta(meta, authserver_url)
196
+
return meta
199
197
200
198
201
199
async def get_record(
200
+
client: ClientSession,
202
201
pds: str,
203
202
repo: str,
204
203
collection: str,
···
207
206
) -> dict[str, Any] | None:
208
207
"""Retrieve record from PDS. Verifies type is the same as collection name."""
209
208
210
-
async with aiohttp.ClientSession() as client:
211
-
response = await client.get(
212
-
f"{pds}/xrpc/com.atproto.repo.getRecord?repo={repo}&collection={collection}&rkey={record}"
213
-
)
214
-
if not response.ok:
215
-
return None
216
-
parsed = await response.json()
217
-
value: dict[str, Any] = parsed["value"]
218
-
if value["$type"] != (type or collection):
219
-
return None
220
-
del value["$type"]
209
+
params = {"repo": repo, "collection": collection, "rkey": record}
210
+
response = await client.get(f"{pds}/xrpc/com.atproto.repo.getRecord", params=params)
211
+
if not response.ok:
212
+
return None
213
+
parsed = await response.json()
214
+
value: dict[str, Any] = parsed["value"]
215
+
if value["$type"] != (type or collection):
216
+
return None
217
+
del value["$type"]
221
218
222
-
return value
219
+
return value
+7
-17
src/atproto/oauth.py
+7
-17
src/atproto/oauth.py
···
1
1
from typing import Any, Callable, NamedTuple
2
2
import time
3
3
import json
4
+
from aiohttp.client import ClientSession, ClientResponse
4
5
from authlib.jose import JsonWebKey, Key, jwt
5
6
from authlib.common.security import generate_token
6
7
from authlib.oauth2.rfc7636 import create_s256_code_challenge
7
-
from aiohttp import ClientResponse
8
8
9
9
from . import fetch_authserver_meta
10
10
···
84
84
respjson = await resp.json()
85
85
if resp.status == 400 and respjson["error"] == "use_dpop_nonce":
86
86
dpop_authserver_nonce = resp.headers["DPoP-Nonce"]
87
-
print(f"retrying with new auth server DPoP nonce: {dpop_authserver_nonce}")
88
87
dpop_proof = _authserver_dpop_jwt(
89
88
"POST", par_url, dpop_authserver_nonce, dpop_private_jwk
90
89
)
···
105
104
# Returns token response (OAuthTokens) and DPoP nonce (str)
106
105
# IMPORTANT: the 'tokens.sub' field must be verified against the original request by code calling this function.
107
106
async def initial_token_request(
107
+
client: ClientSession,
108
108
auth_request: OAuthAuthRequest,
109
109
code: str,
110
110
app_url: str,
···
113
113
authserver_url = auth_request.authserver_iss
114
114
115
115
# Re-fetch server metadata
116
-
authserver_meta = await fetch_authserver_meta(authserver_url)
116
+
authserver_meta = await fetch_authserver_meta(client, authserver_url)
117
117
if not authserver_meta:
118
118
raise Exception("missing authserver meta")
119
119
···
146
146
147
147
# IMPORTANT: Token URL is untrusted input, SSRF mitigations are needed
148
148
assert is_safe_url(token_url)
149
-
async with hardened_http.get_session() as session:
150
-
resp = await session.post(token_url, data=params, headers={"DPoP": dpop_proof})
151
149
150
+
resp = await client.post(token_url, data=params, headers={"DPoP": dpop_proof})
152
151
# Handle DPoP missing/invalid nonce error by retrying with server-provided nonce
153
152
respjson = await resp.json()
154
153
if resp.status == 400 and respjson["error"] == "use_dpop_nonce":
155
154
dpop_authserver_nonce = resp.headers["DPoP-Nonce"]
156
-
print(f"retrying with new auth server DPoP nonce: {dpop_authserver_nonce}")
157
-
# print(server_nonce)
158
155
dpop_proof = _authserver_dpop_jwt(
159
156
"POST", token_url, dpop_authserver_nonce, dpop_private_jwk
160
157
)
161
-
async with hardened_http.get_session() as session:
162
-
resp = await session.post(
163
-
token_url,
164
-
data=params,
165
-
headers={"DPoP": dpop_proof},
166
-
)
158
+
resp = await client.post(token_url, data=params, headers={"DPoP": dpop_proof})
167
159
168
160
resp.raise_for_status()
169
161
token_body = await resp.json()
···
174
166
175
167
# Returns token response (OAuthTokens) and DPoP nonce (str)
176
168
async def refresh_token_request(
169
+
client: ClientSession,
177
170
user: OAuthSession,
178
171
app_url: str,
179
172
client_secret_jwk: Key,
···
181
174
authserver_url = user.authserver_iss
182
175
183
176
# Re-fetch server metadata
184
-
authserver_meta = await fetch_authserver_meta(authserver_url)
177
+
authserver_meta = await fetch_authserver_meta(client, authserver_url)
185
178
if not authserver_meta:
186
179
raise Exception("missing authserver meta")
187
180
···
218
211
respjson = await resp.json()
219
212
if resp.status == 400 and respjson["error"] == "use_dpop_nonce":
220
213
dpop_authserver_nonce = resp.headers["DPoP-Nonce"]
221
-
print(f"retrying with new auth server DPoP nonce: {dpop_authserver_nonce}")
222
-
# print(server_nonce)
223
214
dpop_proof = _authserver_dpop_jwt(
224
215
"POST", token_url, dpop_authserver_nonce, dpop_private_jwk
225
216
)
···
278
269
respjson = await response.json()
279
270
if response.status in [400, 401] and respjson["error"] == "use_dpop_nonce":
280
271
dpop_pds_nonce = response.headers["DPoP-Nonce"]
281
-
print(f"retrying with new PDS DPoP nonce: {dpop_pds_nonce}")
282
272
update_dpop_pds_nonce(dpop_pds_nonce)
283
273
continue
284
274
break
+27
-19
src/main.py
+27
-19
src/main.py
···
1
1
import asyncio
2
2
import json
3
3
4
+
from aiohttp.client import ClientSession
4
5
from flask import Flask, g, session, redirect, render_template, request, url_for
5
6
from typing import Any
6
7
···
12
13
resolve_pds_from_did,
13
14
)
14
15
from .atproto.oauth import pds_authed_req
15
-
from .db import KV, close_db_connection, init_db
16
+
from .db import KV, close_db_connection, get_db, init_db
16
17
from .oauth import get_auth_session, oauth, save_auth_session
17
18
from .types import OAuthSession
18
19
···
25
26
26
27
27
28
@app.before_request
28
-
def load_user_to_context():
29
+
async def load_user_to_context():
29
30
g.user = get_auth_session(session)
30
31
31
32
···
34
35
35
36
36
37
@app.teardown_appcontext
37
-
def app_teardown(exception: BaseException | None):
38
+
async def app_teardown(exception: BaseException | None):
38
39
close_db_connection(exception)
39
40
40
41
···
47
48
async def page_profile(atid: str):
48
49
reload = request.args.get("reload") is not None
49
50
51
+
db = get_db(app)
52
+
didkv = KV(db, "did_from_handle")
53
+
pdskv = KV(db, "pds_from_did")
54
+
50
55
if atid.startswith("@"):
51
56
handle = atid[1:].lower()
52
-
did = await resolve_did_from_handle(handle, reload=reload)
57
+
did = await resolve_did_from_handle(handle, kv=didkv, reload=reload)
53
58
if did is None:
54
59
return render_template("error.html", message="did not found"), 404
55
60
elif is_valid_did(atid):
···
60
65
if _is_did_blocked(did):
61
66
return render_template("error.html", message="profile not found"), 404
62
67
63
-
kv = KV(app, "pds_from_did")
64
-
pds = await resolve_pds_from_did(did, kv, reload=reload)
65
-
if pds is None:
66
-
return render_template("error.html", message="pds not found"), 404
67
-
(profile, _), links = await asyncio.gather(
68
-
load_profile(pds, did, reload=reload),
69
-
load_links(pds, did, reload=reload),
70
-
)
68
+
async with ClientSession() as client:
69
+
pds = await resolve_pds_from_did(client, did=did, kv=pdskv, reload=reload)
70
+
if pds is None:
71
+
return render_template("error.html", message="pds not found"), 404
72
+
(profile, _), links = await asyncio.gather(
73
+
load_profile(client, pds, did, reload=reload),
74
+
load_links(client, pds, did, reload=reload),
75
+
)
71
76
if links is None:
72
77
return render_template("error.html", message="profile not found"), 404
73
78
···
112
117
pds: str = user.pds_url
113
118
handle: str | None = user.handle
114
119
115
-
(profile, from_bluesky), links = await asyncio.gather(
116
-
load_profile(pds, did, reload=True),
117
-
load_links(pds, did, reload=True),
118
-
)
120
+
async with ClientSession() as client:
121
+
(profile, from_bluesky), links = await asyncio.gather(
122
+
load_profile(client, pds, did, reload=True),
123
+
load_links(client, pds, did, reload=True),
124
+
)
119
125
120
126
return render_template(
121
127
"editor.html",
···
197
203
198
204
199
205
async def load_links(
206
+
client: ClientSession,
200
207
pds: str,
201
208
did: str,
202
209
reload: bool = False,
···
208
215
app.logger.debug(f"returning cached links for {did}")
209
216
return json.loads(recordstr)["links"]
210
217
211
-
record = await get_record(pds, did, f"{SCHEMA}.actor.links", "self")
218
+
record = await get_record(client, pds, did, f"{SCHEMA}.actor.links", "self")
212
219
if record is None:
213
220
return None
214
221
···
218
225
219
226
220
227
async def load_profile(
228
+
client: ClientSession,
221
229
pds: str,
222
230
did: str,
223
231
fallback_with_bluesky: bool = True,
···
231
239
return json.loads(recordstr), False
232
240
233
241
from_bluesky = False
234
-
record = await get_record(pds, did, f"{SCHEMA}.actor.profile", "self")
242
+
record = await get_record(client, pds, did, f"{SCHEMA}.actor.profile", "self")
235
243
if record is None and fallback_with_bluesky:
236
-
record = await get_record(pds, did, "app.bsky.actor.profile", "self")
244
+
record = await get_record(client, pds, did, "app.bsky.actor.profile", "self")
237
245
from_bluesky = True
238
246
if record is None:
239
247
return None, False
+23
-7
src/oauth.py
+23
-7
src/oauth.py
···
1
-
from typing import NamedTuple
1
+
from aiohttp.client import ClientSession
2
2
from authlib.jose import JsonWebKey, Key
3
3
from flask import Blueprint, current_app, jsonify, redirect, request, session, url_for
4
4
from flask.sessions import SessionMixin
5
+
from typing import NamedTuple
5
6
from urllib.parse import urlencode
6
7
7
8
import json
···
33
34
db = get_db(current_app)
34
35
pdskv = KV(db, "authserver_from_pds")
35
36
37
+
client = ClientSession()
38
+
36
39
if is_valid_handle(username) or is_valid_did(username):
37
40
login_hint = username
38
41
kv = KV(db, "did_from_handle")
39
-
identity = await resolve_identity(username, didkv=kv)
42
+
identity = await resolve_identity(client, username, didkv=kv)
40
43
if identity is None:
41
44
return "couldnt resolve identity", 500
42
45
did, handle, doc = identity
···
44
47
if not pds_url:
45
48
return "pds not found", 404
46
49
current_app.logger.debug(f"account PDS: {pds_url}")
47
-
authserver_url = await resolve_authserver_from_pds(pds_url, pdskv)
50
+
authserver_url = await resolve_authserver_from_pds(client, pds_url, pdskv)
48
51
if not authserver_url:
49
52
return "authserver not found", 404
50
53
51
54
elif username.startswith("https://") and is_safe_url(username):
52
55
did, handle, pds_url = None, None, None
53
56
login_hint = None
54
-
authserver_url = await resolve_authserver_from_pds(username, pdskv) or username
57
+
authserver_url = (
58
+
await resolve_authserver_from_pds(client, username, pdskv) or username
59
+
)
55
60
56
61
else:
57
62
return "not a valid handle, did or auth server", 400
58
63
59
64
current_app.logger.debug(f"Authserver: {authserver_url}")
60
65
assert is_safe_url(authserver_url)
61
-
authserver_meta = await fetch_authserver_meta(authserver_url)
66
+
authserver_meta = await fetch_authserver_meta(client, authserver_url)
62
67
if not authserver_meta:
63
68
return "no authserver meta", 404
64
69
70
+
await client.close()
71
+
65
72
# Auth
66
73
dpop_private_jwk: Key = JsonWebKey.generate_key("EC", "P-256", is_private=True)
67
74
scope = "atproto transition:generic"
···
133
140
assert auth_request.authserver_iss == authserver_iss
134
141
assert auth_request.state == state
135
142
143
+
client = ClientSession()
144
+
136
145
app_url = request.url_root.replace("http://", "https://")
137
146
CLIENT_SECRET_JWK = JsonWebKey.import_key(current_app.config["CLIENT_SECRET_JWK"])
138
147
tokens, dpop_authserver_nonce = await initial_token_request(
148
+
client,
139
149
auth_request,
140
150
authorization_code,
141
151
app_url,
···
155
165
else:
156
166
did = tokens.sub
157
167
assert is_valid_did(did)
158
-
identity = await resolve_identity(did, didkv=didkv)
168
+
identity = await resolve_identity(client, did, didkv=didkv)
159
169
if not identity:
160
170
return "could not resolve identity", 500
161
171
did, handle, did_doc = identity
162
172
pds_url = pds_endpoint_from_doc(did_doc)
163
173
if not pds_url:
164
174
return "could not resolve pds", 500
165
-
authserver_url = await resolve_authserver_from_pds(pds_url, authserverkv)
175
+
authserver_url = await resolve_authserver_from_pds(
176
+
client,
177
+
pds_url,
178
+
authserverkv,
179
+
)
166
180
assert authserver_url == authserver_iss
181
+
182
+
await client.close()
167
183
168
184
assert row.scope == tokens.scope
169
185
assert pds_url is not None