+1
-1
backend/src/backend/_internal/atproto/client.py
+1
-1
backend/src/backend/_internal/atproto/client.py
···
17
17
# uses LRUCache (not TTLCache) to bound memory - LRU eviction is safe because:
18
18
# 1. recently-used locks won't be evicted while in use
19
19
# 2. TTL expiration could evict a lock while a coroutine holds it, breaking mutual exclusion
20
-
_refresh_locks: LRUCache[str, asyncio.Lock] = LRUCache(maxsize=10000)
20
+
_refresh_locks: LRUCache[str, asyncio.Lock] = LRUCache(maxsize=10_000)
21
21
22
22
23
23
def reconstruct_oauth_session(oauth_data: dict[str, Any]) -> OAuthSession:
+11
-73
backend/tests/test_token_refresh.py
+11
-73
backend/tests/test_token_refresh.py
···
5
5
6
6
import pytest
7
7
from atproto_oauth.models import OAuthSession
8
+
from cachetools import LRUCache
9
+
from cryptography.hazmat.backends import default_backend
10
+
from cryptography.hazmat.primitives import serialization
11
+
from cryptography.hazmat.primitives.asymmetric import ec
8
12
9
13
from backend._internal import Session as AuthSession
10
-
from backend._internal.atproto.client import _refresh_session_tokens
14
+
from backend._internal.atproto.client import _refresh_locks, _refresh_session_tokens
11
15
12
16
13
17
@pytest.fixture
14
18
def mock_auth_session() -> AuthSession:
15
19
"""create mock auth session."""
16
-
# generate a real EC key and serialize it
17
-
import cryptography.hazmat.backends
18
-
import cryptography.hazmat.primitives.asymmetric.ec as ec
19
-
from cryptography.hazmat.primitives import serialization
20
-
21
-
private_key = ec.generate_private_key(
22
-
ec.SECP256R1(), cryptography.hazmat.backends.default_backend()
23
-
)
20
+
private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
24
21
25
22
dpop_key_pem = private_key.private_bytes(
26
23
encoding=serialization.Encoding.PEM,
···
50
47
@pytest.fixture
51
48
def mock_oauth_session() -> OAuthSession:
52
49
"""create mock oauth session."""
53
-
# defer cryptography import to avoid overhead
54
-
import cryptography.hazmat.backends
55
-
import cryptography.hazmat.primitives.asymmetric.ec as ec
56
-
57
-
# generate a real key for the mock
58
-
private_key = ec.generate_private_key(
59
-
ec.SECP256R1(), cryptography.hazmat.backends.default_backend()
60
-
)
50
+
private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
61
51
62
52
return OAuthSession(
63
53
did="did:plc:test123",
···
84
74
new_token = "new-refreshed-token"
85
75
86
76
async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
87
-
"""mock OAuth client refresh with delay to simulate race."""
88
77
nonlocal refresh_call_count
89
78
refresh_call_count += 1
90
-
91
-
# simulate network delay
92
79
await asyncio.sleep(0.1)
93
-
94
-
# return updated session with new token
95
80
session.access_token = new_token
96
81
return session
97
82
98
83
async def mock_get_session(session_id: str) -> AuthSession | None:
99
-
"""mock get_session that returns updated tokens after first refresh."""
100
84
if refresh_call_count > 0:
101
-
# first refresh completed - return new token
102
85
mock_auth_session.oauth_session["access_token"] = new_token
103
86
return mock_auth_session
104
87
105
88
async def mock_update_session_tokens(
106
89
session_id: str, oauth_session_data: dict
107
90
) -> None:
108
-
"""mock session update."""
109
91
mock_auth_session.oauth_session.update(oauth_session_data)
110
92
111
-
# create a mock OAuth client with the refresh method
112
93
mock_oauth_client = type(
113
94
"MockOAuthClient", (), {"refresh_session": mock_refresh_session}
114
95
)()
···
127
108
side_effect=mock_update_session_tokens,
128
109
),
129
110
):
130
-
# launch 5 concurrent refresh attempts
131
111
tasks = [
132
112
_refresh_session_tokens(mock_auth_session, mock_oauth_session)
133
113
for _ in range(5)
134
114
]
135
115
results = await asyncio.gather(*tasks)
136
116
137
-
# all should succeed and get the new token
138
117
assert all(result.access_token == new_token for result in results)
139
-
140
-
# but OAuth client should only be called once (the lock worked!)
141
118
assert refresh_call_count == 1
142
119
143
120
async def test_refresh_failure_uses_fallback(
···
150
127
async def mock_refresh_session_fails(
151
128
self, session: OAuthSession
152
129
) -> OAuthSession:
153
-
"""mock refresh that always fails."""
154
130
nonlocal refresh_called
155
131
refresh_called = True
156
132
await asyncio.sleep(0.05)
···
159
135
get_session_calls = 0
160
136
161
137
async def mock_get_session(session_id: str) -> AuthSession | None:
162
-
"""mock get_session that returns updated tokens on retry."""
163
138
nonlocal get_session_calls
164
139
get_session_calls += 1
165
-
166
-
# on retry (after failure), return new tokens as if another request succeeded
167
140
if get_session_calls >= 2:
168
141
mock_auth_session.oauth_session["access_token"] = new_token
169
-
170
142
return mock_auth_session
171
143
172
144
async def mock_update_session_tokens(
173
145
session_id: str, oauth_session_data: dict
174
146
) -> None:
175
-
"""mock session update."""
176
147
mock_auth_session.oauth_session.update(oauth_session_data)
177
148
178
-
# create a mock OAuth client with the failing refresh method
179
149
mock_oauth_client = type(
180
150
"MockOAuthClient", (), {"refresh_session": mock_refresh_session_fails}
181
151
)()
···
194
164
side_effect=mock_update_session_tokens,
195
165
),
196
166
):
197
-
# this should fail to refresh but succeed via fallback
198
167
result = await _refresh_session_tokens(
199
168
mock_auth_session, mock_oauth_session
200
169
)
201
170
202
-
# verify it tried to refresh
203
171
assert refresh_called
204
-
205
-
# verify it fell back to reloaded tokens
206
172
assert result.access_token == new_token
207
173
208
174
async def test_second_request_skips_refresh_if_already_done(
···
213
179
new_token = "already-refreshed-token"
214
180
215
181
async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
216
-
"""mock OAuth client refresh."""
217
182
nonlocal refresh_call_count
218
183
refresh_call_count += 1
219
184
await asyncio.sleep(0.1)
···
223
188
get_session_calls = 0
224
189
225
190
async def mock_get_session(session_id: str) -> AuthSession | None:
226
-
"""mock get_session that simulates first refresh completing quickly."""
227
191
nonlocal get_session_calls
228
192
get_session_calls += 1
229
-
230
-
# on second+ call, act like refresh already happened
231
193
if get_session_calls > 1:
232
194
mock_auth_session.oauth_session["access_token"] = new_token
233
-
234
195
return mock_auth_session
235
196
236
197
async def mock_update_session_tokens(
237
198
session_id: str, oauth_session_data: dict
238
199
) -> None:
239
-
"""mock session update."""
240
200
mock_auth_session.oauth_session.update(oauth_session_data)
241
201
242
-
# create a mock OAuth client with the refresh method
243
202
mock_oauth_client = type(
244
203
"MockOAuthClient", (), {"refresh_session": mock_refresh_session}
245
204
)()
···
258
217
side_effect=mock_update_session_tokens,
259
218
),
260
219
):
261
-
# first refresh
262
220
result1 = await _refresh_session_tokens(
263
221
mock_auth_session, mock_oauth_session
264
222
)
265
223
assert result1.access_token == new_token
266
224
267
-
# second refresh attempt should skip network call
268
225
result2 = await _refresh_session_tokens(
269
226
mock_auth_session, mock_oauth_session
270
227
)
271
228
assert result2.access_token == new_token
272
229
273
-
# OAuth client should have been called exactly once
274
230
assert refresh_call_count == 1
275
231
276
232
···
279
235
280
236
def test_same_session_returns_same_lock(self):
281
237
"""same session_id should return the same lock instance."""
282
-
from backend._internal.atproto.client import _refresh_locks
283
-
284
-
# clear for isolated test
285
238
_refresh_locks.clear()
286
239
287
-
# create lock for session
288
240
_refresh_locks["session-a"] = asyncio.Lock()
289
241
lock1 = _refresh_locks["session-a"]
290
-
291
-
# accessing again should return same lock
292
242
lock2 = _refresh_locks["session-a"]
243
+
293
244
assert lock1 is lock2
294
245
295
246
def test_different_sessions_have_different_locks(self):
296
247
"""different session_ids should have different lock instances."""
297
-
from backend._internal.atproto.client import _refresh_locks
298
-
299
248
_refresh_locks.clear()
300
249
301
250
_refresh_locks["session-a"] = asyncio.Lock()
···
305
254
306
255
def test_cache_is_bounded_by_maxsize(self):
307
256
"""cache should evict entries when full (LRU behavior)."""
308
-
from backend._internal.atproto.client import _refresh_locks
309
-
310
257
_refresh_locks.clear()
311
258
312
-
# fill cache beyond maxsize (maxsize=10000, but we'll test the behavior)
313
-
# just verify the maxsize property is set
314
-
assert _refresh_locks.maxsize == 10000
259
+
assert _refresh_locks.maxsize == 10_000
315
260
316
-
# add some entries and verify they exist
317
261
for i in range(100):
318
262
_refresh_locks[f"session-{i}"] = asyncio.Lock()
319
263
···
321
265
322
266
def test_lru_eviction_order(self):
323
267
"""LRU cache should evict least recently used entries first."""
324
-
from cachetools import LRUCache
325
-
326
-
# use a small cache to test eviction behavior
327
268
small_cache: LRUCache[str, asyncio.Lock] = LRUCache(maxsize=3)
328
269
329
270
small_cache["a"] = asyncio.Lock()
330
271
small_cache["b"] = asyncio.Lock()
331
272
small_cache["c"] = asyncio.Lock()
332
273
333
-
# access "a" to make it recently used
334
274
_ = small_cache["a"]
335
-
336
-
# add "d" - should evict "b" (least recently used)
337
275
small_cache["d"] = asyncio.Lock()
338
276
339
-
assert "a" in small_cache # recently accessed
340
-
assert "b" not in small_cache # evicted (LRU)
277
+
assert "a" in small_cache
278
+
assert "b" not in small_cache
341
279
assert "c" in small_cache
342
280
assert "d" in small_cache