+6
-2
backend/src/backend/_internal/atproto/client.py
+6
-2
backend/src/backend/_internal/atproto/client.py
···
6
from typing import Any
7
8
from atproto_oauth.models import OAuthSession
9
10
from backend._internal import Session as AuthSession
11
from backend._internal import get_oauth_client, get_session, update_session_tokens
12
13
logger = logging.getLogger(__name__)
14
15
-
# per-session locks for token refresh to prevent concurrent refresh races
16
-
_refresh_locks: dict[str, asyncio.Lock] = {}
17
18
19
def reconstruct_oauth_session(oauth_data: dict[str, Any]) -> OAuthSession:
···
6
from typing import Any
7
8
from atproto_oauth.models import OAuthSession
9
+
from cachetools import LRUCache
10
11
from backend._internal import Session as AuthSession
12
from backend._internal import get_oauth_client, get_session, update_session_tokens
13
14
logger = logging.getLogger(__name__)
15
16
+
# per-session locks for token refresh to prevent concurrent refresh races.
17
+
# uses LRUCache (not TTLCache) to bound memory - LRU eviction is safe because:
18
+
# 1. recently-used locks won't be evicted while in use
19
+
# 2. TTL expiration could evict a lock while a coroutine holds it, breaking mutual exclusion
20
+
_refresh_locks: LRUCache[str, asyncio.Lock] = LRUCache(maxsize=10_000)
21
22
23
def reconstruct_oauth_session(oauth_data: dict[str, Any]) -> OAuthSession:
+57
-51
backend/tests/test_token_refresh.py
+57
-51
backend/tests/test_token_refresh.py
···
5
6
import pytest
7
from atproto_oauth.models import OAuthSession
8
9
from backend._internal import Session as AuthSession
10
-
from backend._internal.atproto.client import _refresh_session_tokens
11
12
13
@pytest.fixture
14
def mock_auth_session() -> AuthSession:
15
"""create mock auth session."""
16
-
# generate a real EC key and serialize it
17
-
import cryptography.hazmat.backends
18
-
import cryptography.hazmat.primitives.asymmetric.ec as ec
19
-
from cryptography.hazmat.primitives import serialization
20
-
21
-
private_key = ec.generate_private_key(
22
-
ec.SECP256R1(), cryptography.hazmat.backends.default_backend()
23
-
)
24
25
dpop_key_pem = private_key.private_bytes(
26
encoding=serialization.Encoding.PEM,
···
50
@pytest.fixture
51
def mock_oauth_session() -> OAuthSession:
52
"""create mock oauth session."""
53
-
# defer cryptography import to avoid overhead
54
-
import cryptography.hazmat.backends
55
-
import cryptography.hazmat.primitives.asymmetric.ec as ec
56
-
57
-
# generate a real key for the mock
58
-
private_key = ec.generate_private_key(
59
-
ec.SECP256R1(), cryptography.hazmat.backends.default_backend()
60
-
)
61
62
return OAuthSession(
63
did="did:plc:test123",
···
84
new_token = "new-refreshed-token"
85
86
async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
87
-
"""mock OAuth client refresh with delay to simulate race."""
88
nonlocal refresh_call_count
89
refresh_call_count += 1
90
-
91
-
# simulate network delay
92
await asyncio.sleep(0.1)
93
-
94
-
# return updated session with new token
95
session.access_token = new_token
96
return session
97
98
async def mock_get_session(session_id: str) -> AuthSession | None:
99
-
"""mock get_session that returns updated tokens after first refresh."""
100
if refresh_call_count > 0:
101
-
# first refresh completed - return new token
102
mock_auth_session.oauth_session["access_token"] = new_token
103
return mock_auth_session
104
105
async def mock_update_session_tokens(
106
session_id: str, oauth_session_data: dict
107
) -> None:
108
-
"""mock session update."""
109
mock_auth_session.oauth_session.update(oauth_session_data)
110
111
-
# create a mock OAuth client with the refresh method
112
mock_oauth_client = type(
113
"MockOAuthClient", (), {"refresh_session": mock_refresh_session}
114
)()
···
127
side_effect=mock_update_session_tokens,
128
),
129
):
130
-
# launch 5 concurrent refresh attempts
131
tasks = [
132
_refresh_session_tokens(mock_auth_session, mock_oauth_session)
133
for _ in range(5)
134
]
135
results = await asyncio.gather(*tasks)
136
137
-
# all should succeed and get the new token
138
assert all(result.access_token == new_token for result in results)
139
-
140
-
# but OAuth client should only be called once (the lock worked!)
141
assert refresh_call_count == 1
142
143
async def test_refresh_failure_uses_fallback(
···
150
async def mock_refresh_session_fails(
151
self, session: OAuthSession
152
) -> OAuthSession:
153
-
"""mock refresh that always fails."""
154
nonlocal refresh_called
155
refresh_called = True
156
await asyncio.sleep(0.05)
···
159
get_session_calls = 0
160
161
async def mock_get_session(session_id: str) -> AuthSession | None:
162
-
"""mock get_session that returns updated tokens on retry."""
163
nonlocal get_session_calls
164
get_session_calls += 1
165
-
166
-
# on retry (after failure), return new tokens as if another request succeeded
167
if get_session_calls >= 2:
168
mock_auth_session.oauth_session["access_token"] = new_token
169
-
170
return mock_auth_session
171
172
async def mock_update_session_tokens(
173
session_id: str, oauth_session_data: dict
174
) -> None:
175
-
"""mock session update."""
176
mock_auth_session.oauth_session.update(oauth_session_data)
177
178
-
# create a mock OAuth client with the failing refresh method
179
mock_oauth_client = type(
180
"MockOAuthClient", (), {"refresh_session": mock_refresh_session_fails}
181
)()
···
194
side_effect=mock_update_session_tokens,
195
),
196
):
197
-
# this should fail to refresh but succeed via fallback
198
result = await _refresh_session_tokens(
199
mock_auth_session, mock_oauth_session
200
)
201
202
-
# verify it tried to refresh
203
assert refresh_called
204
-
205
-
# verify it fell back to reloaded tokens
206
assert result.access_token == new_token
207
208
async def test_second_request_skips_refresh_if_already_done(
···
213
new_token = "already-refreshed-token"
214
215
async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
216
-
"""mock OAuth client refresh."""
217
nonlocal refresh_call_count
218
refresh_call_count += 1
219
await asyncio.sleep(0.1)
···
223
get_session_calls = 0
224
225
async def mock_get_session(session_id: str) -> AuthSession | None:
226
-
"""mock get_session that simulates first refresh completing quickly."""
227
nonlocal get_session_calls
228
get_session_calls += 1
229
-
230
-
# on second+ call, act like refresh already happened
231
if get_session_calls > 1:
232
mock_auth_session.oauth_session["access_token"] = new_token
233
-
234
return mock_auth_session
235
236
async def mock_update_session_tokens(
237
session_id: str, oauth_session_data: dict
238
) -> None:
239
-
"""mock session update."""
240
mock_auth_session.oauth_session.update(oauth_session_data)
241
242
-
# create a mock OAuth client with the refresh method
243
mock_oauth_client = type(
244
"MockOAuthClient", (), {"refresh_session": mock_refresh_session}
245
)()
···
258
side_effect=mock_update_session_tokens,
259
),
260
):
261
-
# first refresh
262
result1 = await _refresh_session_tokens(
263
mock_auth_session, mock_oauth_session
264
)
265
assert result1.access_token == new_token
266
267
-
# second refresh attempt should skip network call
268
result2 = await _refresh_session_tokens(
269
mock_auth_session, mock_oauth_session
270
)
271
assert result2.access_token == new_token
272
273
-
# OAuth client should have been called exactly once
274
assert refresh_call_count == 1
···
5
6
import pytest
7
from atproto_oauth.models import OAuthSession
8
+
from cachetools import LRUCache
9
+
from cryptography.hazmat.backends import default_backend
10
+
from cryptography.hazmat.primitives import serialization
11
+
from cryptography.hazmat.primitives.asymmetric import ec
12
13
from backend._internal import Session as AuthSession
14
+
from backend._internal.atproto.client import _refresh_locks, _refresh_session_tokens
15
16
17
@pytest.fixture
18
def mock_auth_session() -> AuthSession:
19
"""create mock auth session."""
20
+
private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
21
22
dpop_key_pem = private_key.private_bytes(
23
encoding=serialization.Encoding.PEM,
···
47
@pytest.fixture
48
def mock_oauth_session() -> OAuthSession:
49
"""create mock oauth session."""
50
+
private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
51
52
return OAuthSession(
53
did="did:plc:test123",
···
74
new_token = "new-refreshed-token"
75
76
async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
77
nonlocal refresh_call_count
78
refresh_call_count += 1
79
await asyncio.sleep(0.1)
80
session.access_token = new_token
81
return session
82
83
async def mock_get_session(session_id: str) -> AuthSession | None:
84
if refresh_call_count > 0:
85
mock_auth_session.oauth_session["access_token"] = new_token
86
return mock_auth_session
87
88
async def mock_update_session_tokens(
89
session_id: str, oauth_session_data: dict
90
) -> None:
91
mock_auth_session.oauth_session.update(oauth_session_data)
92
93
mock_oauth_client = type(
94
"MockOAuthClient", (), {"refresh_session": mock_refresh_session}
95
)()
···
108
side_effect=mock_update_session_tokens,
109
),
110
):
111
tasks = [
112
_refresh_session_tokens(mock_auth_session, mock_oauth_session)
113
for _ in range(5)
114
]
115
results = await asyncio.gather(*tasks)
116
117
assert all(result.access_token == new_token for result in results)
118
assert refresh_call_count == 1
119
120
async def test_refresh_failure_uses_fallback(
···
127
async def mock_refresh_session_fails(
128
self, session: OAuthSession
129
) -> OAuthSession:
130
nonlocal refresh_called
131
refresh_called = True
132
await asyncio.sleep(0.05)
···
135
get_session_calls = 0
136
137
async def mock_get_session(session_id: str) -> AuthSession | None:
138
nonlocal get_session_calls
139
get_session_calls += 1
140
if get_session_calls >= 2:
141
mock_auth_session.oauth_session["access_token"] = new_token
142
return mock_auth_session
143
144
async def mock_update_session_tokens(
145
session_id: str, oauth_session_data: dict
146
) -> None:
147
mock_auth_session.oauth_session.update(oauth_session_data)
148
149
mock_oauth_client = type(
150
"MockOAuthClient", (), {"refresh_session": mock_refresh_session_fails}
151
)()
···
164
side_effect=mock_update_session_tokens,
165
),
166
):
167
result = await _refresh_session_tokens(
168
mock_auth_session, mock_oauth_session
169
)
170
171
assert refresh_called
172
assert result.access_token == new_token
173
174
async def test_second_request_skips_refresh_if_already_done(
···
179
new_token = "already-refreshed-token"
180
181
async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
182
nonlocal refresh_call_count
183
refresh_call_count += 1
184
await asyncio.sleep(0.1)
···
188
get_session_calls = 0
189
190
async def mock_get_session(session_id: str) -> AuthSession | None:
191
nonlocal get_session_calls
192
get_session_calls += 1
193
if get_session_calls > 1:
194
mock_auth_session.oauth_session["access_token"] = new_token
195
return mock_auth_session
196
197
async def mock_update_session_tokens(
198
session_id: str, oauth_session_data: dict
199
) -> None:
200
mock_auth_session.oauth_session.update(oauth_session_data)
201
202
mock_oauth_client = type(
203
"MockOAuthClient", (), {"refresh_session": mock_refresh_session}
204
)()
···
217
side_effect=mock_update_session_tokens,
218
),
219
):
220
result1 = await _refresh_session_tokens(
221
mock_auth_session, mock_oauth_session
222
)
223
assert result1.access_token == new_token
224
225
result2 = await _refresh_session_tokens(
226
mock_auth_session, mock_oauth_session
227
)
228
assert result2.access_token == new_token
229
230
assert refresh_call_count == 1
231
+
232
+
233
+
class TestRefreshLocksCache:
234
+
"""test _refresh_locks cache behavior (memory leak prevention)."""
235
+
236
+
def test_same_session_returns_same_lock(self):
237
+
"""same session_id should return the same lock instance."""
238
+
_refresh_locks.clear()
239
+
240
+
_refresh_locks["session-a"] = asyncio.Lock()
241
+
lock1 = _refresh_locks["session-a"]
242
+
lock2 = _refresh_locks["session-a"]
243
+
244
+
assert lock1 is lock2
245
+
246
+
def test_different_sessions_have_different_locks(self):
247
+
"""different session_ids should have different lock instances."""
248
+
_refresh_locks.clear()
249
+
250
+
_refresh_locks["session-a"] = asyncio.Lock()
251
+
_refresh_locks["session-b"] = asyncio.Lock()
252
+
253
+
assert _refresh_locks["session-a"] is not _refresh_locks["session-b"]
254
+
255
+
def test_cache_is_bounded_by_maxsize(self):
256
+
"""cache should evict entries when full (LRU behavior)."""
257
+
_refresh_locks.clear()
258
+
259
+
assert _refresh_locks.maxsize == 10_000
260
+
261
+
for i in range(100):
262
+
_refresh_locks[f"session-{i}"] = asyncio.Lock()
263
+
264
+
assert len(_refresh_locks) == 100
265
+
266
+
def test_lru_eviction_order(self):
267
+
"""LRU cache should evict least recently used entries first."""
268
+
small_cache: LRUCache[str, asyncio.Lock] = LRUCache(maxsize=3)
269
+
270
+
small_cache["a"] = asyncio.Lock()
271
+
small_cache["b"] = asyncio.Lock()
272
+
small_cache["c"] = asyncio.Lock()
273
+
274
+
_ = small_cache["a"]
275
+
small_cache["d"] = asyncio.Lock()
276
+
277
+
assert "a" in small_cache
278
+
assert "b" not in small_cache
279
+
assert "c" in small_cache
280
+
assert "d" in small_cache