audio streaming app
plyr.fm
1"""tests for concurrent token refresh locking."""
2
3import asyncio
4from unittest.mock import patch
5
6import pytest
7from atproto_oauth.models import OAuthSession
8
9from backend._internal import Session as AuthSession
10from backend._internal.atproto.client import _refresh_session_tokens
11
12
@pytest.fixture
def mock_auth_session() -> AuthSession:
    """build a stored auth session whose oauth payload embeds a real dpop key.

    the key is a freshly generated P-256 (SECP256R1) private key, serialized as
    unencrypted PKCS#8 PEM text under ``dpop_private_key_pem``, so it can be
    loaded back as a genuine key rather than a placeholder string.
    """
    # imported inside the fixture so only tests that request it pay the cost
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import ec

    signing_key = ec.generate_private_key(ec.SECP256R1(), default_backend())
    pem_bytes = signing_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )

    # payload as the session store holds it; starts out with the stale "old-token"
    stored_oauth = {
        "did": "did:plc:test123",
        "handle": "test.bsky.social",
        "pds_url": "https://pds.test",
        "authserver_iss": "https://auth.test",
        "scope": "atproto transition:generic",
        "access_token": "old-token",
        "refresh_token": "refresh-token",
        "dpop_private_key_pem": pem_bytes.decode("utf-8"),
        "dpop_authserver_nonce": "nonce1",
        "dpop_pds_nonce": "nonce2",
    }

    return AuthSession(
        session_id="test-session-123",
        did="did:plc:test123",
        handle="test.bsky.social",
        oauth_session=stored_oauth,
    )
48
49
@pytest.fixture
def mock_oauth_session() -> OAuthSession:
    """build an in-memory oauth session that still holds the stale ``old-token``.

    the dpop key is generated here independently of the one serialized by the
    ``mock_auth_session`` fixture, so the two fixtures do not share a key.
    """
    # deferred import keeps cryptography out of module import time
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives.asymmetric import ec

    dpop_key = ec.generate_private_key(ec.SECP256R1(), default_backend())

    return OAuthSession(
        did="did:plc:test123",
        handle="test.bsky.social",
        pds_url="https://pds.test",
        authserver_iss="https://auth.test",
        scope="atproto transition:generic",
        access_token="old-token",
        refresh_token="refresh-token",
        dpop_private_key=dpop_key,
        dpop_authserver_nonce="nonce1",
        dpop_pds_nonce="nonce2",
    )
74
75
class TestConcurrentTokenRefresh:
    """test concurrent token refresh race condition handling.

    each test patches ``get_oauth_client``, ``get_session`` and
    ``update_session_tokens`` in ``backend._internal.atproto.client`` with
    in-memory fakes, then checks how often the OAuth client's
    ``refresh_session`` is actually invoked.
    """

    async def test_concurrent_refresh_only_calls_once(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that concurrent refresh attempts only call the OAuth client once.

        the fake session store exposes the new token only after the first
        refresh has *finished*. callers racing with an in-flight refresh
        therefore still see the stale token, so only a lock around the refresh
        can stop them from issuing their own OAuth call.
        """
        refresh_call_count = 0
        refresh_completed = False
        new_token = "new-refreshed-token"

        async def mock_refresh_session(
            _client, session: OAuthSession
        ) -> OAuthSession:
            """mock OAuth client refresh with delay to simulate race."""
            nonlocal refresh_call_count, refresh_completed
            refresh_call_count += 1

            # simulate network delay so the other tasks get scheduled mid-refresh
            await asyncio.sleep(0.1)

            # return updated session with new token, and only now mark it done
            session.access_token = new_token
            refresh_completed = True
            return session

        async def mock_get_session(session_id: str) -> AuthSession | None:
            """mock get_session that returns the new token once a refresh finished."""
            # gating on the call counter instead (bumped when refresh *starts*)
            # would hand the new token to concurrent callers before the refresh
            # completed, letting this test pass even with no lock at all.
            if refresh_completed:
                mock_auth_session.oauth_session["access_token"] = new_token
            return mock_auth_session

        async def mock_update_session_tokens(
            session_id: str, oauth_session_data: dict
        ) -> None:
            """mock session update."""
            mock_auth_session.oauth_session.update(oauth_session_data)

        # create a mock OAuth client with the refresh method
        mock_oauth_client = type(
            "MockOAuthClient", (), {"refresh_session": mock_refresh_session}
        )()

        with (
            patch(
                "backend._internal.atproto.client.get_oauth_client",
                return_value=mock_oauth_client,
            ),
            patch(
                "backend._internal.atproto.client.get_session",
                side_effect=mock_get_session,
            ),
            patch(
                "backend._internal.atproto.client.update_session_tokens",
                side_effect=mock_update_session_tokens,
            ),
        ):
            # launch 5 concurrent refresh attempts
            tasks = [
                _refresh_session_tokens(mock_auth_session, mock_oauth_session)
                for _ in range(5)
            ]
            results = await asyncio.gather(*tasks)

        # all should succeed and get the new token
        assert all(result.access_token == new_token for result in results)

        # but OAuth client should only be called once (the lock worked!)
        assert refresh_call_count == 1, (
            f"expected exactly one OAuth refresh, got {refresh_call_count}"
        )

    async def test_refresh_failure_uses_fallback(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that on refresh failure, retries with reload from DB.

        the OAuth refresh always raises; from the second ``get_session`` lookup
        onward the fake store reports a new token, as if another request had
        refreshed successfully in the meantime.
        """
        new_token = "new-refreshed-token"
        refresh_called = False
        get_session_calls = 0

        async def mock_refresh_session_fails(
            _client, session: OAuthSession
        ) -> OAuthSession:
            """mock refresh that always fails."""
            nonlocal refresh_called
            refresh_called = True
            await asyncio.sleep(0.05)
            raise Exception("500 server_error from PDS")

        async def mock_get_session(session_id: str) -> AuthSession | None:
            """mock get_session that returns updated tokens on the second+ lookup."""
            nonlocal get_session_calls
            get_session_calls += 1

            # on retry (after failure), return new tokens as if another request succeeded
            if get_session_calls >= 2:
                mock_auth_session.oauth_session["access_token"] = new_token

            return mock_auth_session

        async def mock_update_session_tokens(
            session_id: str, oauth_session_data: dict
        ) -> None:
            """mock session update."""
            mock_auth_session.oauth_session.update(oauth_session_data)

        # create a mock OAuth client with the failing refresh method
        mock_oauth_client = type(
            "MockOAuthClient", (), {"refresh_session": mock_refresh_session_fails}
        )()

        with (
            patch(
                "backend._internal.atproto.client.get_oauth_client",
                return_value=mock_oauth_client,
            ),
            patch(
                "backend._internal.atproto.client.get_session",
                side_effect=mock_get_session,
            ),
            patch(
                "backend._internal.atproto.client.update_session_tokens",
                side_effect=mock_update_session_tokens,
            ),
        ):
            # this should fail to refresh but succeed via fallback
            result = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )

        # verify it tried to refresh
        assert refresh_called

        # verify it fell back to reloaded tokens
        assert result.access_token == new_token

    async def test_second_request_skips_refresh_if_already_done(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that second request sees new token and skips refresh."""
        refresh_call_count = 0
        new_token = "already-refreshed-token"
        get_session_calls = 0

        async def mock_refresh_session(
            _client, session: OAuthSession
        ) -> OAuthSession:
            """mock OAuth client refresh."""
            nonlocal refresh_call_count
            refresh_call_count += 1
            await asyncio.sleep(0.1)
            session.access_token = new_token
            return session

        async def mock_get_session(session_id: str) -> AuthSession | None:
            """mock get_session: stale token on the first lookup, new token afterwards."""
            nonlocal get_session_calls
            get_session_calls += 1

            # on second+ call, act like refresh already happened
            if get_session_calls > 1:
                mock_auth_session.oauth_session["access_token"] = new_token

            return mock_auth_session

        async def mock_update_session_tokens(
            session_id: str, oauth_session_data: dict
        ) -> None:
            """mock session update."""
            mock_auth_session.oauth_session.update(oauth_session_data)

        # create a mock OAuth client with the refresh method
        mock_oauth_client = type(
            "MockOAuthClient", (), {"refresh_session": mock_refresh_session}
        )()

        with (
            patch(
                "backend._internal.atproto.client.get_oauth_client",
                return_value=mock_oauth_client,
            ),
            patch(
                "backend._internal.atproto.client.get_session",
                side_effect=mock_get_session,
            ),
            patch(
                "backend._internal.atproto.client.update_session_tokens",
                side_effect=mock_update_session_tokens,
            ),
        ):
            # first refresh
            result1 = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )
            assert result1.access_token == new_token

            # second refresh attempt should skip network call
            result2 = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )
            assert result2.access_token == new_token

        # OAuth client should have been called exactly once
        assert refresh_call_count == 1, (
            f"expected exactly one OAuth refresh, got {refresh_call_count}"
        )
273 # OAuth client should have been called exactly once
274 assert refresh_call_count == 1