+6
-2
backend/src/backend/_internal/atproto/client.py
+6
-2
backend/src/backend/_internal/atproto/client.py
···
6
6
from typing import Any
7
7
8
8
from atproto_oauth.models import OAuthSession
9
+
from cachetools import LRUCache
9
10
10
11
from backend._internal import Session as AuthSession
11
12
from backend._internal import get_oauth_client, get_session, update_session_tokens
12
13
13
14
logger = logging.getLogger(__name__)
14
15
15
-
# per-session locks for token refresh to prevent concurrent refresh races
16
-
_refresh_locks: dict[str, asyncio.Lock] = {}
16
+
# per-session locks for token refresh to prevent concurrent refresh races.
17
+
# uses LRUCache (not TTLCache) to bound memory - LRU eviction is safe because:
18
+
# 1. recently-used locks won't be evicted while in use
19
+
# 2. TTL expiration could evict a lock while a coroutine holds it, breaking mutual exclusion
20
+
_refresh_locks: LRUCache[str, asyncio.Lock] = LRUCache(maxsize=10_000)
17
21
18
22
19
23
def reconstruct_oauth_session(oauth_data: dict[str, Any]) -> OAuthSession:
+57
-51
backend/tests/test_token_refresh.py
+57
-51
backend/tests/test_token_refresh.py
···
5
5
6
6
import pytest
7
7
from atproto_oauth.models import OAuthSession
8
+
from cachetools import LRUCache
9
+
from cryptography.hazmat.backends import default_backend
10
+
from cryptography.hazmat.primitives import serialization
11
+
from cryptography.hazmat.primitives.asymmetric import ec
8
12
9
13
from backend._internal import Session as AuthSession
10
-
from backend._internal.atproto.client import _refresh_session_tokens
14
+
from backend._internal.atproto.client import _refresh_locks, _refresh_session_tokens
11
15
12
16
13
17
@pytest.fixture
14
18
def mock_auth_session() -> AuthSession:
15
19
"""create mock auth session."""
16
-
# generate a real EC key and serialize it
17
-
import cryptography.hazmat.backends
18
-
import cryptography.hazmat.primitives.asymmetric.ec as ec
19
-
from cryptography.hazmat.primitives import serialization
20
-
21
-
private_key = ec.generate_private_key(
22
-
ec.SECP256R1(), cryptography.hazmat.backends.default_backend()
23
-
)
20
+
private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
24
21
25
22
dpop_key_pem = private_key.private_bytes(
26
23
encoding=serialization.Encoding.PEM,
···
50
47
@pytest.fixture
51
48
def mock_oauth_session() -> OAuthSession:
52
49
"""create mock oauth session."""
53
-
# defer cryptography import to avoid overhead
54
-
import cryptography.hazmat.backends
55
-
import cryptography.hazmat.primitives.asymmetric.ec as ec
56
-
57
-
# generate a real key for the mock
58
-
private_key = ec.generate_private_key(
59
-
ec.SECP256R1(), cryptography.hazmat.backends.default_backend()
60
-
)
50
+
private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
61
51
62
52
return OAuthSession(
63
53
did="did:plc:test123",
···
84
74
new_token = "new-refreshed-token"
85
75
86
76
async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
87
-
"""mock OAuth client refresh with delay to simulate race."""
88
77
nonlocal refresh_call_count
89
78
refresh_call_count += 1
90
-
91
-
# simulate network delay
92
79
await asyncio.sleep(0.1)
93
-
94
-
# return updated session with new token
95
80
session.access_token = new_token
96
81
return session
97
82
98
83
async def mock_get_session(session_id: str) -> AuthSession | None:
99
-
"""mock get_session that returns updated tokens after first refresh."""
100
84
if refresh_call_count > 0:
101
-
# first refresh completed - return new token
102
85
mock_auth_session.oauth_session["access_token"] = new_token
103
86
return mock_auth_session
104
87
105
88
async def mock_update_session_tokens(
106
89
session_id: str, oauth_session_data: dict
107
90
) -> None:
108
-
"""mock session update."""
109
91
mock_auth_session.oauth_session.update(oauth_session_data)
110
92
111
-
# create a mock OAuth client with the refresh method
112
93
mock_oauth_client = type(
113
94
"MockOAuthClient", (), {"refresh_session": mock_refresh_session}
114
95
)()
···
127
108
side_effect=mock_update_session_tokens,
128
109
),
129
110
):
130
-
# launch 5 concurrent refresh attempts
131
111
tasks = [
132
112
_refresh_session_tokens(mock_auth_session, mock_oauth_session)
133
113
for _ in range(5)
134
114
]
135
115
results = await asyncio.gather(*tasks)
136
116
137
-
# all should succeed and get the new token
138
117
assert all(result.access_token == new_token for result in results)
139
-
140
-
# but OAuth client should only be called once (the lock worked!)
141
118
assert refresh_call_count == 1
142
119
143
120
async def test_refresh_failure_uses_fallback(
···
150
127
async def mock_refresh_session_fails(
151
128
self, session: OAuthSession
152
129
) -> OAuthSession:
153
-
"""mock refresh that always fails."""
154
130
nonlocal refresh_called
155
131
refresh_called = True
156
132
await asyncio.sleep(0.05)
···
159
135
get_session_calls = 0
160
136
161
137
async def mock_get_session(session_id: str) -> AuthSession | None:
162
-
"""mock get_session that returns updated tokens on retry."""
163
138
nonlocal get_session_calls
164
139
get_session_calls += 1
165
-
166
-
# on retry (after failure), return new tokens as if another request succeeded
167
140
if get_session_calls >= 2:
168
141
mock_auth_session.oauth_session["access_token"] = new_token
169
-
170
142
return mock_auth_session
171
143
172
144
async def mock_update_session_tokens(
173
145
session_id: str, oauth_session_data: dict
174
146
) -> None:
175
-
"""mock session update."""
176
147
mock_auth_session.oauth_session.update(oauth_session_data)
177
148
178
-
# create a mock OAuth client with the failing refresh method
179
149
mock_oauth_client = type(
180
150
"MockOAuthClient", (), {"refresh_session": mock_refresh_session_fails}
181
151
)()
···
194
164
side_effect=mock_update_session_tokens,
195
165
),
196
166
):
197
-
# this should fail to refresh but succeed via fallback
198
167
result = await _refresh_session_tokens(
199
168
mock_auth_session, mock_oauth_session
200
169
)
201
170
202
-
# verify it tried to refresh
203
171
assert refresh_called
204
-
205
-
# verify it fell back to reloaded tokens
206
172
assert result.access_token == new_token
207
173
208
174
async def test_second_request_skips_refresh_if_already_done(
···
213
179
new_token = "already-refreshed-token"
214
180
215
181
async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
216
-
"""mock OAuth client refresh."""
217
182
nonlocal refresh_call_count
218
183
refresh_call_count += 1
219
184
await asyncio.sleep(0.1)
···
223
188
get_session_calls = 0
224
189
225
190
async def mock_get_session(session_id: str) -> AuthSession | None:
226
-
"""mock get_session that simulates first refresh completing quickly."""
227
191
nonlocal get_session_calls
228
192
get_session_calls += 1
229
-
230
-
# on second+ call, act like refresh already happened
231
193
if get_session_calls > 1:
232
194
mock_auth_session.oauth_session["access_token"] = new_token
233
-
234
195
return mock_auth_session
235
196
236
197
async def mock_update_session_tokens(
237
198
session_id: str, oauth_session_data: dict
238
199
) -> None:
239
-
"""mock session update."""
240
200
mock_auth_session.oauth_session.update(oauth_session_data)
241
201
242
-
# create a mock OAuth client with the refresh method
243
202
mock_oauth_client = type(
244
203
"MockOAuthClient", (), {"refresh_session": mock_refresh_session}
245
204
)()
···
258
217
side_effect=mock_update_session_tokens,
259
218
),
260
219
):
261
-
# first refresh
262
220
result1 = await _refresh_session_tokens(
263
221
mock_auth_session, mock_oauth_session
264
222
)
265
223
assert result1.access_token == new_token
266
224
267
-
# second refresh attempt should skip network call
268
225
result2 = await _refresh_session_tokens(
269
226
mock_auth_session, mock_oauth_session
270
227
)
271
228
assert result2.access_token == new_token
272
229
273
-
# OAuth client should have been called exactly once
274
230
assert refresh_call_count == 1
231
+
232
+
233
+
class TestRefreshLocksCache:
234
+
"""test _refresh_locks cache behavior (memory leak prevention)."""
235
+
236
+
def test_same_session_returns_same_lock(self):
237
+
"""same session_id should return the same lock instance."""
238
+
_refresh_locks.clear()
239
+
240
+
_refresh_locks["session-a"] = asyncio.Lock()
241
+
lock1 = _refresh_locks["session-a"]
242
+
lock2 = _refresh_locks["session-a"]
243
+
244
+
assert lock1 is lock2
245
+
246
+
def test_different_sessions_have_different_locks(self):
247
+
"""different session_ids should have different lock instances."""
248
+
_refresh_locks.clear()
249
+
250
+
_refresh_locks["session-a"] = asyncio.Lock()
251
+
_refresh_locks["session-b"] = asyncio.Lock()
252
+
253
+
assert _refresh_locks["session-a"] is not _refresh_locks["session-b"]
254
+
255
+
def test_cache_is_bounded_by_maxsize(self):
256
+
"""cache should evict entries when full (LRU behavior)."""
257
+
_refresh_locks.clear()
258
+
259
+
assert _refresh_locks.maxsize == 10_000
260
+
261
+
for i in range(100):
262
+
_refresh_locks[f"session-{i}"] = asyncio.Lock()
263
+
264
+
assert len(_refresh_locks) == 100
265
+
266
+
def test_lru_eviction_order(self):
267
+
"""LRU cache should evict least recently used entries first."""
268
+
small_cache: LRUCache[str, asyncio.Lock] = LRUCache(maxsize=3)
269
+
270
+
small_cache["a"] = asyncio.Lock()
271
+
small_cache["b"] = asyncio.Lock()
272
+
small_cache["c"] = asyncio.Lock()
273
+
274
+
_ = small_cache["a"]
275
+
small_cache["d"] = asyncio.Lock()
276
+
277
+
assert "a" in small_cache
278
+
assert "b" not in small_cache
279
+
assert "c" in small_cache
280
+
assert "d" in small_cache