audio streaming app plyr.fm
38
fork

Configure Feed

Select the types of activity you want to include in your feed.

at feat/admin-flag-count 274 lines 9.9 kB view raw
"""tests for concurrent token refresh locking."""

import asyncio
import contextlib
from collections.abc import Awaitable, Callable, Iterator
from unittest.mock import patch

import pytest
from atproto_oauth.models import OAuthSession

from backend._internal import Session as AuthSession
from backend._internal.atproto.client import _refresh_session_tokens

# module whose collaborators (oauth client factory, session store) get patched
_CLIENT_MODULE = "backend._internal.atproto.client"


def _generate_dpop_key():
    """generate a fresh SECP256R1 EC private key to act as the DPoP key."""
    # deferred import to avoid cryptography overhead at module import time
    import cryptography.hazmat.backends
    import cryptography.hazmat.primitives.asymmetric.ec as ec

    return ec.generate_private_key(
        ec.SECP256R1(), cryptography.hazmat.backends.default_backend()
    )


def _make_oauth_client(
    refresh_session: Callable[..., Awaitable[OAuthSession]],
) -> object:
    """return a stand-in OAuth client exposing ``refresh_session``.

    the callable is installed as a class attribute, so it binds as a method and
    receives the client instance as its first (``self``) argument.
    """
    return type("MockOAuthClient", (), {"refresh_session": refresh_session})()


def _make_update_session_tokens(
    auth_session: AuthSession,
) -> Callable[[str, dict], Awaitable[None]]:
    """return a fake ``update_session_tokens`` that merges into *auth_session*."""

    async def mock_update_session_tokens(
        session_id: str, oauth_session_data: dict
    ) -> None:
        """mock session update: merge the persisted fields into the session."""
        auth_session.oauth_session.update(oauth_session_data)

    return mock_update_session_tokens


@contextlib.contextmanager
def _patched_client_module(
    oauth_client: object,
    get_session: Callable[[str], Awaitable[AuthSession | None]],
    update_session_tokens: Callable[[str, dict], Awaitable[None]],
) -> Iterator[None]:
    """patch the client module's OAuth client factory and session-store calls.

    Args:
        oauth_client: returned by the patched ``get_oauth_client``.
        get_session: side effect for the patched ``get_session``.
        update_session_tokens: side effect for the patched
            ``update_session_tokens``.
    """
    with (
        patch(f"{_CLIENT_MODULE}.get_oauth_client", return_value=oauth_client),
        patch(f"{_CLIENT_MODULE}.get_session", side_effect=get_session),
        patch(
            f"{_CLIENT_MODULE}.update_session_tokens",
            side_effect=update_session_tokens,
        ),
    ):
        yield


@pytest.fixture
def mock_auth_session() -> AuthSession:
    """create mock auth session whose DPoP key is a real PEM-serialized EC key."""
    from cryptography.hazmat.primitives import serialization

    dpop_key_pem = (
        _generate_dpop_key()
        .private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )
        .decode("utf-8")
    )

    return AuthSession(
        session_id="test-session-123",
        did="did:plc:test123",
        handle="test.bsky.social",
        oauth_session={
            "did": "did:plc:test123",
            "handle": "test.bsky.social",
            "pds_url": "https://pds.test",
            "authserver_iss": "https://auth.test",
            "scope": "atproto transition:generic",
            "access_token": "old-token",
            "refresh_token": "refresh-token",
            "dpop_private_key_pem": dpop_key_pem,
            "dpop_authserver_nonce": "nonce1",
            "dpop_pds_nonce": "nonce2",
        },
    )


@pytest.fixture
def mock_oauth_session() -> OAuthSession:
    """create mock oauth session holding a freshly generated EC key object."""
    return OAuthSession(
        did="did:plc:test123",
        handle="test.bsky.social",
        pds_url="https://pds.test",
        authserver_iss="https://auth.test",
        access_token="old-token",
        refresh_token="refresh-token",
        dpop_private_key=_generate_dpop_key(),
        dpop_authserver_nonce="nonce1",
        dpop_pds_nonce="nonce2",
        scope="atproto transition:generic",
    )


class TestConcurrentTokenRefresh:
    """test concurrent token refresh race condition handling."""

    async def test_concurrent_refresh_only_calls_once(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that concurrent refresh attempts only call the OAuth client once."""
        refresh_call_count = 0
        refresh_completed = False
        new_token = "new-refreshed-token"

        async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
            """mock OAuth client refresh with delay to simulate race."""
            nonlocal refresh_call_count, refresh_completed
            refresh_call_count += 1

            # simulate network delay; other tasks are scheduled while we wait
            await asyncio.sleep(0.1)

            session.access_token = new_token
            # mark completion only AFTER the delay: if the fake DB exposed the
            # new token while this refresh was still in flight, concurrent
            # callers could skip refreshing even without any lock, and the
            # call-count assertion below would not prove the lock works
            refresh_completed = True
            return session

        async def mock_get_session(session_id: str) -> AuthSession | None:
            """mock get_session that returns new tokens once a refresh has finished."""
            if refresh_completed:
                mock_auth_session.oauth_session["access_token"] = new_token
            return mock_auth_session

        with _patched_client_module(
            _make_oauth_client(mock_refresh_session),
            mock_get_session,
            _make_update_session_tokens(mock_auth_session),
        ):
            # launch 5 concurrent refresh attempts
            results = await asyncio.gather(
                *(
                    _refresh_session_tokens(mock_auth_session, mock_oauth_session)
                    for _ in range(5)
                )
            )

            # all should succeed and get the new token
            assert all(result.access_token == new_token for result in results)

            # but OAuth client should only be called once (the lock worked!)
            assert refresh_call_count == 1

    async def test_refresh_failure_uses_fallback(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that on refresh failure, retries with reload from DB."""
        new_token = "new-refreshed-token"
        refresh_called = False
        get_session_calls = 0

        async def mock_refresh_session_fails(
            self, session: OAuthSession
        ) -> OAuthSession:
            """mock refresh that always fails."""
            nonlocal refresh_called
            refresh_called = True
            await asyncio.sleep(0.05)
            raise Exception("500 server_error from PDS")

        async def mock_get_session(session_id: str) -> AuthSession | None:
            """mock get_session that returns updated tokens on retry."""
            nonlocal get_session_calls
            get_session_calls += 1

            # on retry (after failure), return new tokens as if another
            # request succeeded in the meantime
            if get_session_calls >= 2:
                mock_auth_session.oauth_session["access_token"] = new_token

            return mock_auth_session

        with _patched_client_module(
            _make_oauth_client(mock_refresh_session_fails),
            mock_get_session,
            _make_update_session_tokens(mock_auth_session),
        ):
            # this should fail to refresh but succeed via fallback
            result = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )

            # verify it tried to refresh
            assert refresh_called

            # verify it fell back to reloaded tokens
            assert result.access_token == new_token

    async def test_second_request_skips_refresh_if_already_done(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that second request sees new token and skips refresh."""
        refresh_call_count = 0
        get_session_calls = 0
        new_token = "already-refreshed-token"

        async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
            """mock OAuth client refresh."""
            nonlocal refresh_call_count
            refresh_call_count += 1
            await asyncio.sleep(0.1)
            session.access_token = new_token
            return session

        async def mock_get_session(session_id: str) -> AuthSession | None:
            """mock get_session that simulates first refresh completing quickly."""
            nonlocal get_session_calls
            get_session_calls += 1

            # on second+ call, act like refresh already happened
            if get_session_calls > 1:
                mock_auth_session.oauth_session["access_token"] = new_token

            return mock_auth_session

        with _patched_client_module(
            _make_oauth_client(mock_refresh_session),
            mock_get_session,
            _make_update_session_tokens(mock_auth_session),
        ):
            # first refresh
            result1 = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )
            assert result1.access_token == new_token

            # second refresh attempt should skip network call
            result2 = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )
            assert result2.access_token == new_token

            # OAuth client should have been called exactly once
            assert refresh_call_count == 1