audio streaming app plyr.fm
38
fork

Configure Feed

Select the types of activity you want to include in your feed.

at 35ba4e175332b35076a5fd46ddf9c89fad5c8457 280 lines 9.8 kB view raw
"""tests for concurrent token refresh locking."""

import asyncio
from collections.abc import Awaitable, Callable, Iterator
from contextlib import contextmanager
from unittest.mock import patch

import pytest
from atproto_oauth.models import OAuthSession
from cachetools import LRUCache
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec

from backend._internal import Session as AuthSession
from backend._internal.atproto.client import _refresh_locks, _refresh_session_tokens

# dotted path of the module whose collaborators are patched in these tests.
_CLIENT_MODULE = "backend._internal.atproto.client"


@pytest.fixture(autouse=True)
def _isolated_refresh_locks() -> Iterator[None]:
    """reset the module-global lock cache around every test.

    `_refresh_locks` is shared module state; clearing it before and after each
    test means no test can observe lock entries created by another test.
    """
    _refresh_locks.clear()
    yield
    _refresh_locks.clear()


@pytest.fixture
def mock_auth_session() -> AuthSession:
    """create mock auth session."""
    private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())

    # the stored session keeps the DPoP key as an unencrypted PKCS8 PEM string.
    dpop_key_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")

    return AuthSession(
        session_id="test-session-123",
        did="did:plc:test123",
        handle="test.bsky.social",
        oauth_session={
            "did": "did:plc:test123",
            "handle": "test.bsky.social",
            "pds_url": "https://pds.test",
            "authserver_iss": "https://auth.test",
            "scope": "atproto transition:generic",
            "access_token": "old-token",
            "refresh_token": "refresh-token",
            "dpop_private_key_pem": dpop_key_pem,
            "dpop_authserver_nonce": "nonce1",
            "dpop_pds_nonce": "nonce2",
        },
    )


@pytest.fixture
def mock_oauth_session() -> OAuthSession:
    """create mock oauth session."""
    private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())

    return OAuthSession(
        did="did:plc:test123",
        handle="test.bsky.social",
        pds_url="https://pds.test",
        authserver_iss="https://auth.test",
        access_token="old-token",
        refresh_token="refresh-token",
        dpop_private_key=private_key,
        dpop_authserver_nonce="nonce1",
        dpop_pds_nonce="nonce2",
        scope="atproto transition:generic",
    )


@contextmanager
def _patched_client(
    auth_session: AuthSession,
    refresh_session: Callable[..., Awaitable[OAuthSession]],
    get_session: Callable[[str], Awaitable[AuthSession | None]],
) -> Iterator[None]:
    """patch the client module's collaborators for the duration of a test.

    args:
        auth_session: session whose `oauth_session` dict receives whatever
            the code under test passes to `update_session_tokens`.
        refresh_session: coroutine function installed as the
            `refresh_session` method of a throwaway OAuth client class.
        get_session: coroutine function used as the side effect of the
            patched `get_session`.
    """

    async def update_session_tokens(
        session_id: str, oauth_session_data: dict
    ) -> None:
        auth_session.oauth_session.update(oauth_session_data)

    # build a one-off class so `refresh_session` is bound like a real method
    # (it receives the client instance as `self`).
    oauth_client = type(
        "MockOAuthClient", (), {"refresh_session": refresh_session}
    )()

    with (
        patch(
            f"{_CLIENT_MODULE}.get_oauth_client",
            return_value=oauth_client,
        ),
        patch(
            f"{_CLIENT_MODULE}.get_session",
            side_effect=get_session,
        ),
        patch(
            f"{_CLIENT_MODULE}.update_session_tokens",
            side_effect=update_session_tokens,
        ),
    ):
        yield


class TestConcurrentTokenRefresh:
    """test concurrent token refresh race condition handling."""

    async def test_concurrent_refresh_only_calls_once(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that concurrent refresh attempts only call the OAuth client once."""
        refresh_call_count = 0
        new_token = "new-refreshed-token"

        async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
            nonlocal refresh_call_count
            refresh_call_count += 1
            # hold the refresh open long enough for the other tasks to pile up.
            await asyncio.sleep(0.1)
            session.access_token = new_token
            return session

        async def mock_get_session(session_id: str) -> AuthSession | None:
            # once a refresh has started, the "stored" session carries the new token.
            if refresh_call_count > 0:
                mock_auth_session.oauth_session["access_token"] = new_token
            return mock_auth_session

        with _patched_client(
            mock_auth_session, mock_refresh_session, mock_get_session
        ):
            tasks = [
                _refresh_session_tokens(mock_auth_session, mock_oauth_session)
                for _ in range(5)
            ]
            results = await asyncio.gather(*tasks)

            assert all(result.access_token == new_token for result in results)
            assert refresh_call_count == 1

    async def test_refresh_failure_uses_fallback(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that on refresh failure, retries with reload from DB."""
        new_token = "new-refreshed-token"
        refresh_called = False

        async def mock_refresh_session_fails(
            self, session: OAuthSession
        ) -> OAuthSession:
            nonlocal refresh_called
            refresh_called = True
            await asyncio.sleep(0.05)
            raise Exception("500 server_error from PDS")

        get_session_calls = 0

        async def mock_get_session(session_id: str) -> AuthSession | None:
            nonlocal get_session_calls
            get_session_calls += 1
            # from the second lookup on, the stored session holds the new token.
            if get_session_calls >= 2:
                mock_auth_session.oauth_session["access_token"] = new_token
            return mock_auth_session

        with _patched_client(
            mock_auth_session, mock_refresh_session_fails, mock_get_session
        ):
            result = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )

            assert refresh_called
            assert result.access_token == new_token

    async def test_second_request_skips_refresh_if_already_done(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that second request sees new token and skips refresh."""
        refresh_call_count = 0
        new_token = "already-refreshed-token"

        async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
            nonlocal refresh_call_count
            refresh_call_count += 1
            await asyncio.sleep(0.1)
            session.access_token = new_token
            return session

        get_session_calls = 0

        async def mock_get_session(session_id: str) -> AuthSession | None:
            nonlocal get_session_calls
            get_session_calls += 1
            # every lookup after the first returns the already-refreshed token.
            if get_session_calls > 1:
                mock_auth_session.oauth_session["access_token"] = new_token
            return mock_auth_session

        with _patched_client(
            mock_auth_session, mock_refresh_session, mock_get_session
        ):
            result1 = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )
            assert result1.access_token == new_token

            result2 = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )
            assert result2.access_token == new_token

            assert refresh_call_count == 1


class TestRefreshLocksCache:
    """test _refresh_locks cache behavior (memory leak prevention).

    the cache is emptied around each test by the autouse
    `_isolated_refresh_locks` fixture in this module.
    """

    def test_same_session_returns_same_lock(self):
        """same session_id should return the same lock instance."""
        _refresh_locks["session-a"] = asyncio.Lock()
        lock1 = _refresh_locks["session-a"]
        lock2 = _refresh_locks["session-a"]

        assert lock1 is lock2

    def test_different_sessions_have_different_locks(self):
        """different session_ids should have different lock instances."""
        _refresh_locks["session-a"] = asyncio.Lock()
        _refresh_locks["session-b"] = asyncio.Lock()

        assert _refresh_locks["session-a"] is not _refresh_locks["session-b"]

    def test_cache_is_bounded_by_maxsize(self):
        """cache should have a finite maxsize and keep entries below that bound.

        this checks the configured bound only; eviction order is covered by
        `test_lru_eviction_order`.
        """
        assert _refresh_locks.maxsize == 10_000

        for i in range(100):
            _refresh_locks[f"session-{i}"] = asyncio.Lock()

        assert len(_refresh_locks) == 100

    def test_lru_eviction_order(self):
        """LRU cache should evict least recently used entries first."""
        small_cache: LRUCache[str, asyncio.Lock] = LRUCache(maxsize=3)

        small_cache["a"] = asyncio.Lock()
        small_cache["b"] = asyncio.Lock()
        small_cache["c"] = asyncio.Lock()

        # reading "a" marks it as recently used, leaving "b" as the oldest entry.
        _ = small_cache["a"]
        small_cache["d"] = asyncio.Lock()

        assert "a" in small_cache
        assert "b" not in small_cache
        assert "c" in small_cache
        assert "d" in small_cache