music on atproto
plyr.fm
1"""tests for concurrent token refresh locking."""
2
3import asyncio
4from unittest.mock import patch
5
6import pytest
7from atproto_oauth.models import OAuthSession
8from cachetools import LRUCache
9from cryptography.hazmat.backends import default_backend
10from cryptography.hazmat.primitives import serialization
11from cryptography.hazmat.primitives.asymmetric import ec
12
13from backend._internal import Session as AuthSession
14from backend._internal.atproto.client import _refresh_locks, _refresh_session_tokens
15
16
@pytest.fixture
def mock_auth_session() -> AuthSession:
    """build an auth session whose stored oauth data holds a fresh DPoP key.

    the stored oauth dict starts with the access token ``"old-token"`` and
    carries a newly generated SECP256R1 private key as PEM text.
    """
    test_did = "did:plc:test123"
    test_handle = "test.bsky.social"

    # throwaway EC key, serialized as unencrypted PKCS8 PEM
    signing_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
    pem_bytes = signing_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    stored_oauth = {
        "did": test_did,
        "handle": test_handle,
        "pds_url": "https://pds.test",
        "authserver_iss": "https://auth.test",
        "scope": "atproto transition:generic",
        "access_token": "old-token",
        "refresh_token": "refresh-token",
        "dpop_private_key_pem": pem_bytes.decode("utf-8"),
        "dpop_authserver_nonce": "nonce1",
        "dpop_pds_nonce": "nonce2",
    }

    return AuthSession(
        session_id="test-session-123",
        did=test_did,
        handle=test_handle,
        oauth_session=stored_oauth,
    )
45
46
@pytest.fixture
def mock_oauth_session() -> OAuthSession:
    """build an oauth session that still carries the stale access token.

    the session gets its own newly generated SECP256R1 private key as the
    DPoP key; every other field is a fixed test value.
    """
    session_fields = {
        "did": "did:plc:test123",
        "handle": "test.bsky.social",
        "pds_url": "https://pds.test",
        "authserver_iss": "https://auth.test",
        "access_token": "old-token",
        "refresh_token": "refresh-token",
        "dpop_authserver_nonce": "nonce1",
        "dpop_pds_nonce": "nonce2",
        "scope": "atproto transition:generic",
    }
    # the key object itself is passed, not a PEM string
    session_fields["dpop_private_key"] = ec.generate_private_key(
        ec.SECP256R1(), default_backend()
    )

    return OAuthSession(**session_fields)
64
65
class TestConcurrentTokenRefresh:
    """test concurrent token refresh race condition handling.

    every test calls ``_refresh_session_tokens`` while ``get_oauth_client``,
    ``get_session`` and ``update_session_tokens`` in
    ``backend._internal.atproto.client`` are replaced by local stand-ins.
    """

    @pytest.fixture(autouse=True)
    def _isolate_refresh_locks(self):
        """empty the module-level ``_refresh_locks`` cache around each test.

        all tests here use the same session id (from ``mock_auth_session``),
        so without a reset any lock left in the shared cache by one test
        would still be there for the next one.
        NOTE(review): assumes ``_refresh_session_tokens`` creates missing
        locks on demand - confirm against the client module.
        """
        _refresh_locks.clear()
        yield
        _refresh_locks.clear()

    @staticmethod
    def _make_oauth_client(refresh_session):
        """return a stand-in oauth client exposing only ``refresh_session``.

        ``refresh_session`` is installed as a method, so it receives the
        client instance as its first argument.
        """
        return type("MockOAuthClient", (), {"refresh_session": refresh_session})()

    @staticmethod
    def _client_patches(oauth_client, get_session, auth_session: AuthSession):
        """return patchers for the three client-module collaborators.

        the patchers are returned un-started, in the order
        (get_oauth_client, get_session, update_session_tokens); enter them
        with a single ``with a, b, c:`` statement.
        """

        async def mock_update_session_tokens(
            session_id: str, oauth_session_data: dict
        ) -> None:
            # persist "refreshed" tokens into the shared in-memory session
            auth_session.oauth_session.update(oauth_session_data)

        module = "backend._internal.atproto.client"
        return (
            patch(f"{module}.get_oauth_client", return_value=oauth_client),
            patch(f"{module}.get_session", side_effect=get_session),
            patch(
                f"{module}.update_session_tokens",
                side_effect=mock_update_session_tokens,
            ),
        )

    async def test_concurrent_refresh_only_calls_once(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that concurrent refresh attempts only call the OAuth client once."""
        refresh_call_count = 0
        new_token = "new-refreshed-token"

        async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
            nonlocal refresh_call_count
            refresh_call_count += 1
            # stay inside the refresh long enough for the other tasks to queue up
            await asyncio.sleep(0.1)
            session.access_token = new_token
            return session

        async def mock_get_session(session_id: str) -> AuthSession | None:
            # once a refresh has started, the stored session reports the new token
            if refresh_call_count > 0:
                mock_auth_session.oauth_session["access_token"] = new_token
            return mock_auth_session

        client_patch, session_patch, tokens_patch = self._client_patches(
            self._make_oauth_client(mock_refresh_session),
            mock_get_session,
            mock_auth_session,
        )
        with client_patch, session_patch, tokens_patch:
            tasks = [
                _refresh_session_tokens(mock_auth_session, mock_oauth_session)
                for _ in range(5)
            ]
            results = await asyncio.gather(*tasks)

        assert all(result.access_token == new_token for result in results)
        assert refresh_call_count == 1

    async def test_refresh_failure_uses_fallback(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that on refresh failure, retries with reload from DB."""
        new_token = "new-refreshed-token"
        refresh_called = False

        async def mock_refresh_session_fails(
            self, session: OAuthSession
        ) -> OAuthSession:
            nonlocal refresh_called
            refresh_called = True
            await asyncio.sleep(0.05)
            raise Exception("500 server_error from PDS")

        get_session_calls = 0

        async def mock_get_session(session_id: str) -> AuthSession | None:
            nonlocal get_session_calls
            get_session_calls += 1
            # from the second lookup on, the stored session carries the new token
            if get_session_calls >= 2:
                mock_auth_session.oauth_session["access_token"] = new_token
            return mock_auth_session

        client_patch, session_patch, tokens_patch = self._client_patches(
            self._make_oauth_client(mock_refresh_session_fails),
            mock_get_session,
            mock_auth_session,
        )
        with client_patch, session_patch, tokens_patch:
            result = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )

        assert refresh_called
        assert result.access_token == new_token

    async def test_second_request_skips_refresh_if_already_done(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that second request sees new token and skips refresh."""
        refresh_call_count = 0
        new_token = "already-refreshed-token"

        async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
            nonlocal refresh_call_count
            refresh_call_count += 1
            await asyncio.sleep(0.1)
            session.access_token = new_token
            return session

        get_session_calls = 0

        async def mock_get_session(session_id: str) -> AuthSession | None:
            nonlocal get_session_calls
            get_session_calls += 1
            # every lookup after the first one reports the refreshed token
            if get_session_calls > 1:
                mock_auth_session.oauth_session["access_token"] = new_token
            return mock_auth_session

        client_patch, session_patch, tokens_patch = self._client_patches(
            self._make_oauth_client(mock_refresh_session),
            mock_get_session,
            mock_auth_session,
        )
        with client_patch, session_patch, tokens_patch:
            result1 = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )
            assert result1.access_token == new_token

            result2 = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )
            assert result2.access_token == new_token

        assert refresh_call_count == 1
231
232
class TestRefreshLocksCache:
    """test _refresh_locks cache behavior (memory leak prevention)."""

    @pytest.fixture(autouse=True)
    def _isolate_refresh_locks(self):
        """empty the shared ``_refresh_locks`` cache before and after each test.

        clearing afterwards as well keeps the entries inserted here from
        leaking into whatever test runs next.
        """
        _refresh_locks.clear()
        yield
        _refresh_locks.clear()

    def test_same_session_returns_same_lock(self):
        """same session_id should return the same lock instance."""
        _refresh_locks["session-a"] = asyncio.Lock()
        lock1 = _refresh_locks["session-a"]
        lock2 = _refresh_locks["session-a"]

        assert lock1 is lock2

    def test_different_sessions_have_different_locks(self):
        """different session_ids should have different lock instances."""
        _refresh_locks["session-a"] = asyncio.Lock()
        _refresh_locks["session-b"] = asyncio.Lock()

        assert _refresh_locks["session-a"] is not _refresh_locks["session-b"]

    def test_cache_is_bounded_by_maxsize(self):
        """cache should evict entries when full (LRU behavior)."""
        capacity = _refresh_locks.maxsize
        assert capacity == 10_000

        # below capacity nothing is evicted
        for i in range(100):
            _refresh_locks[f"session-{i}"] = asyncio.Lock()
        assert len(_refresh_locks) == 100

        # push past capacity: the cache must stop growing at maxsize
        overflow = 5
        for i in range(100, capacity + overflow):
            _refresh_locks[f"session-{i}"] = asyncio.Lock()

        assert len(_refresh_locks) == capacity
        # the most recent insert must have survived the evictions
        assert f"session-{capacity + overflow - 1}" in _refresh_locks

    def test_lru_eviction_order(self):
        """LRU cache should evict least recently used entries first."""
        small_cache: LRUCache[str, asyncio.Lock] = LRUCache(maxsize=3)

        small_cache["a"] = asyncio.Lock()
        small_cache["b"] = asyncio.Lock()
        small_cache["c"] = asyncio.Lock()

        # reading "a" makes it most recently used, leaving "b" as the oldest
        _ = small_cache["a"]
        small_cache["d"] = asyncio.Lock()

        assert "a" in small_cache
        assert "b" not in small_cache
        assert "c" in small_cache
        assert "d" in small_cache
279 assert "c" in small_cache
280 assert "d" in small_cache