at main 9.8 kB view raw
"""tests for concurrent token refresh locking."""

import asyncio
from collections.abc import Awaitable, Callable, Iterator
from contextlib import contextmanager
from unittest.mock import patch

import pytest
from atproto_oauth.models import OAuthSession
from cachetools import LRUCache
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.asymmetric import ec

from backend._internal import Session as AuthSession
from backend._internal.atproto.client import _refresh_locks, _refresh_session_tokens

# module whose collaborators are patched in the refresh tests
_CLIENT_MODULE = "backend._internal.atproto.client"


@pytest.fixture(autouse=True)
def _reset_refresh_locks() -> Iterator[None]:
    """isolate every test from entries left in the module-level lock cache.

    `_refresh_locks` is process-global state. without this, locks created by
    one test (all tests here share the session id "test-session-123";
    presumably the cache is keyed by session id) would survive into the next
    test, and the cache tests below would depend on run order.
    """
    _refresh_locks.clear()
    yield
    _refresh_locks.clear()


@pytest.fixture
def mock_auth_session() -> AuthSession:
    """create mock auth session."""
    private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())

    dpop_key_pem = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    ).decode("utf-8")

    return AuthSession(
        session_id="test-session-123",
        did="did:plc:test123",
        handle="test.bsky.social",
        oauth_session={
            "did": "did:plc:test123",
            "handle": "test.bsky.social",
            "pds_url": "https://pds.test",
            "authserver_iss": "https://auth.test",
            "scope": "atproto transition:generic",
            "access_token": "old-token",
            "refresh_token": "refresh-token",
            "dpop_private_key_pem": dpop_key_pem,
            "dpop_authserver_nonce": "nonce1",
            "dpop_pds_nonce": "nonce2",
        },
    )


@pytest.fixture
def mock_oauth_session() -> OAuthSession:
    """create mock oauth session."""
    private_key = ec.generate_private_key(ec.SECP256R1(), default_backend())

    return OAuthSession(
        did="did:plc:test123",
        handle="test.bsky.social",
        pds_url="https://pds.test",
        authserver_iss="https://auth.test",
        access_token="old-token",
        refresh_token="refresh-token",
        dpop_private_key=private_key,
        dpop_authserver_nonce="nonce1",
        dpop_pds_nonce="nonce2",
        scope="atproto transition:generic",
    )


@contextmanager
def _patched_client(
    auth_session: AuthSession,
    refresh_session: Callable[..., Awaitable[OAuthSession]],
    get_session: Callable[[str], Awaitable[AuthSession | None]],
) -> Iterator[None]:
    """patch the OAuth client and session-store helpers of the client module.

    args:
        auth_session: session whose `oauth_session` dict receives any tokens
            written through the patched `update_session_tokens`.
        refresh_session: coroutine function installed as the mock client's
            `refresh_session` method (it receives the client as `self`).
        get_session: coroutine function used as the `get_session` side effect.
    """

    async def update_session_tokens(session_id: str, oauth_session_data: dict) -> None:
        # stand-in for the DB write: merge the saved tokens into the session
        auth_session.oauth_session.update(oauth_session_data)

    # a throwaway class so `refresh_session` is bound and called as a method
    mock_oauth_client = type(
        "MockOAuthClient", (), {"refresh_session": refresh_session}
    )()

    with (
        patch(f"{_CLIENT_MODULE}.get_oauth_client", return_value=mock_oauth_client),
        patch(f"{_CLIENT_MODULE}.get_session", side_effect=get_session),
        patch(
            f"{_CLIENT_MODULE}.update_session_tokens",
            side_effect=update_session_tokens,
        ),
    ):
        yield


class TestConcurrentTokenRefresh:
    """test concurrent token refresh race condition handling."""

    async def test_concurrent_refresh_only_calls_once(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that concurrent refresh attempts only call the OAuth client once."""
        refresh_call_count = 0
        new_token = "new-refreshed-token"

        async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
            nonlocal refresh_call_count
            refresh_call_count += 1
            # keep the refresh in flight so the other tasks overlap with it
            await asyncio.sleep(0.1)
            session.access_token = new_token
            return session

        async def mock_get_session(session_id: str) -> AuthSession | None:
            # once a refresh has happened, the stored session has the new token
            if refresh_call_count > 0:
                mock_auth_session.oauth_session["access_token"] = new_token
            return mock_auth_session

        with _patched_client(mock_auth_session, mock_refresh_session, mock_get_session):
            results = await asyncio.gather(
                *(
                    _refresh_session_tokens(mock_auth_session, mock_oauth_session)
                    for _ in range(5)
                )
            )

        assert all(result.access_token == new_token for result in results)
        assert refresh_call_count == 1

    async def test_refresh_failure_uses_fallback(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that on refresh failure, retries with reload from DB."""
        new_token = "new-refreshed-token"
        refresh_called = False

        async def mock_refresh_session_fails(
            self, session: OAuthSession
        ) -> OAuthSession:
            nonlocal refresh_called
            refresh_called = True
            await asyncio.sleep(0.05)
            raise Exception("500 server_error from PDS")

        get_session_calls = 0

        async def mock_get_session(session_id: str) -> AuthSession | None:
            nonlocal get_session_calls
            get_session_calls += 1
            # from the second load on, simulate another worker having
            # already stored a fresh token
            if get_session_calls >= 2:
                mock_auth_session.oauth_session["access_token"] = new_token
            return mock_auth_session

        with _patched_client(
            mock_auth_session, mock_refresh_session_fails, mock_get_session
        ):
            result = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )

        assert refresh_called
        assert result.access_token == new_token

    async def test_second_request_skips_refresh_if_already_done(
        self, mock_auth_session: AuthSession, mock_oauth_session: OAuthSession
    ):
        """test that second request sees new token and skips refresh."""
        refresh_call_count = 0
        new_token = "already-refreshed-token"

        async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession:
            nonlocal refresh_call_count
            refresh_call_count += 1
            await asyncio.sleep(0.1)
            session.access_token = new_token
            return session

        get_session_calls = 0

        async def mock_get_session(session_id: str) -> AuthSession | None:
            nonlocal get_session_calls
            get_session_calls += 1
            # after the first load, the stored session carries the new token
            if get_session_calls > 1:
                mock_auth_session.oauth_session["access_token"] = new_token
            return mock_auth_session

        with _patched_client(mock_auth_session, mock_refresh_session, mock_get_session):
            result1 = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )
            assert result1.access_token == new_token

            result2 = await _refresh_session_tokens(
                mock_auth_session, mock_oauth_session
            )
            assert result2.access_token == new_token

        assert refresh_call_count == 1


class TestRefreshLocksCache:
    """test _refresh_locks cache behavior (memory leak prevention).

    the autouse `_reset_refresh_locks` fixture empties the cache around each test.
    """

    def test_same_session_returns_same_lock(self):
        """same session_id should return the same lock instance."""
        _refresh_locks["session-a"] = asyncio.Lock()
        lock1 = _refresh_locks["session-a"]
        lock2 = _refresh_locks["session-a"]

        assert lock1 is lock2

    def test_different_sessions_have_different_locks(self):
        """different session_ids should have different lock instances."""
        _refresh_locks["session-a"] = asyncio.Lock()
        _refresh_locks["session-b"] = asyncio.Lock()

        assert _refresh_locks["session-a"] is not _refresh_locks["session-b"]

    def test_cache_is_bounded_by_maxsize(self):
        """cache should evict entries when full (LRU behavior)."""
        assert _refresh_locks.maxsize == 10_000

        maxsize = int(_refresh_locks.maxsize)
        # overfill by one entry: the cache must stay at maxsize by dropping
        # the least recently used (here: the first inserted) key
        for i in range(maxsize + 1):
            _refresh_locks[f"session-{i}"] = asyncio.Lock()

        assert len(_refresh_locks) == maxsize
        assert "session-0" not in _refresh_locks
        assert f"session-{maxsize}" in _refresh_locks

    def test_lru_eviction_order(self):
        """LRU cache should evict least recently used entries first."""
        small_cache: LRUCache[str, asyncio.Lock] = LRUCache(maxsize=3)

        small_cache["a"] = asyncio.Lock()
        small_cache["b"] = asyncio.Lock()
        small_cache["c"] = asyncio.Lock()

        # touching "a" makes "b" the least recently used entry
        _ = small_cache["a"]
        small_cache["d"] = asyncio.Lock()

        assert "a" in small_cache
        assert "b" not in small_cache
        assert "c" in small_cache
        assert "d" in small_cache