Merge pull request #710 from zzstoatzz/fix/refresh-locks-memory-leak

fix: prevent unbounded memory growth in token refresh locks

authored by zzstoatzz.io and committed by GitHub 2c5c8f23 31c33167

Changed files
+63 -53
backend
src
backend
_internal
atproto
tests
+6 -2
backend/src/backend/_internal/atproto/client.py
··· 6 6 from typing import Any 7 7 8 8 from atproto_oauth.models import OAuthSession 9 + from cachetools import LRUCache 9 10 10 11 from backend._internal import Session as AuthSession 11 12 from backend._internal import get_oauth_client, get_session, update_session_tokens 12 13 13 14 logger = logging.getLogger(__name__) 14 15 15 - # per-session locks for token refresh to prevent concurrent refresh races 16 - _refresh_locks: dict[str, asyncio.Lock] = {} 16 + # per-session locks for token refresh to prevent concurrent refresh races. 17 + # uses LRUCache (not TTLCache) to bound memory - LRU eviction is safe because: 18 + # 1. a lock that was just looked up is the most recently used entry, so it is evicted only if maxsize other sessions are touched while it is held 19 + # 2. TTL expiration could evict a lock while a coroutine holds it, breaking mutual exclusion 20 + _refresh_locks: LRUCache[str, asyncio.Lock] = LRUCache(maxsize=10_000) 17 21 18 22 19 23 def reconstruct_oauth_session(oauth_data: dict[str, Any]) -> OAuthSession:
+57 -51
backend/tests/test_token_refresh.py
··· 5 5 6 6 import pytest 7 7 from atproto_oauth.models import OAuthSession 8 + from cachetools import LRUCache 9 + from cryptography.hazmat.backends import default_backend 10 + from cryptography.hazmat.primitives import serialization 11 + from cryptography.hazmat.primitives.asymmetric import ec 8 12 9 13 from backend._internal import Session as AuthSession 10 - from backend._internal.atproto.client import _refresh_session_tokens 14 + from backend._internal.atproto.client import _refresh_locks, _refresh_session_tokens 11 15 12 16 13 17 @pytest.fixture 14 18 def mock_auth_session() -> AuthSession: 15 19 """create mock auth session.""" 16 - # generate a real EC key and serialize it 17 - import cryptography.hazmat.backends 18 - import cryptography.hazmat.primitives.asymmetric.ec as ec 19 - from cryptography.hazmat.primitives import serialization 20 - 21 - private_key = ec.generate_private_key( 22 - ec.SECP256R1(), cryptography.hazmat.backends.default_backend() 23 - ) 20 + private_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) 24 21 25 22 dpop_key_pem = private_key.private_bytes( 26 23 encoding=serialization.Encoding.PEM, ··· 50 47 @pytest.fixture 51 48 def mock_oauth_session() -> OAuthSession: 52 49 """create mock oauth session.""" 53 - # defer cryptography import to avoid overhead 54 - import cryptography.hazmat.backends 55 - import cryptography.hazmat.primitives.asymmetric.ec as ec 56 - 57 - # generate a real key for the mock 58 - private_key = ec.generate_private_key( 59 - ec.SECP256R1(), cryptography.hazmat.backends.default_backend() 60 - ) 50 + private_key = ec.generate_private_key(ec.SECP256R1(), default_backend()) 61 51 62 52 return OAuthSession( 63 53 did="did:plc:test123", ··· 84 74 new_token = "new-refreshed-token" 85 75 86 76 async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession: 87 - """mock OAuth client refresh with delay to simulate race.""" 88 77 nonlocal refresh_call_count 89 78 refresh_call_count += 1 90 - 91 - 
# simulate network delay 92 79 await asyncio.sleep(0.1) 93 - 94 - # return updated session with new token 95 80 session.access_token = new_token 96 81 return session 97 82 98 83 async def mock_get_session(session_id: str) -> AuthSession | None: 99 - """mock get_session that returns updated tokens after first refresh.""" 100 84 if refresh_call_count > 0: 101 - # first refresh completed - return new token 102 85 mock_auth_session.oauth_session["access_token"] = new_token 103 86 return mock_auth_session 104 87 105 88 async def mock_update_session_tokens( 106 89 session_id: str, oauth_session_data: dict 107 90 ) -> None: 108 - """mock session update.""" 109 91 mock_auth_session.oauth_session.update(oauth_session_data) 110 92 111 - # create a mock OAuth client with the refresh method 112 93 mock_oauth_client = type( 113 94 "MockOAuthClient", (), {"refresh_session": mock_refresh_session} 114 95 )() ··· 127 108 side_effect=mock_update_session_tokens, 128 109 ), 129 110 ): 130 - # launch 5 concurrent refresh attempts 131 111 tasks = [ 132 112 _refresh_session_tokens(mock_auth_session, mock_oauth_session) 133 113 for _ in range(5) 134 114 ] 135 115 results = await asyncio.gather(*tasks) 136 116 137 - # all should succeed and get the new token 138 117 assert all(result.access_token == new_token for result in results) 139 - 140 - # but OAuth client should only be called once (the lock worked!) 
141 118 assert refresh_call_count == 1 142 119 143 120 async def test_refresh_failure_uses_fallback( ··· 150 127 async def mock_refresh_session_fails( 151 128 self, session: OAuthSession 152 129 ) -> OAuthSession: 153 - """mock refresh that always fails.""" 154 130 nonlocal refresh_called 155 131 refresh_called = True 156 132 await asyncio.sleep(0.05) ··· 159 135 get_session_calls = 0 160 136 161 137 async def mock_get_session(session_id: str) -> AuthSession | None: 162 - """mock get_session that returns updated tokens on retry.""" 163 138 nonlocal get_session_calls 164 139 get_session_calls += 1 165 - 166 - # on retry (after failure), return new tokens as if another request succeeded 167 140 if get_session_calls >= 2: 168 141 mock_auth_session.oauth_session["access_token"] = new_token 169 - 170 142 return mock_auth_session 171 143 172 144 async def mock_update_session_tokens( 173 145 session_id: str, oauth_session_data: dict 174 146 ) -> None: 175 - """mock session update.""" 176 147 mock_auth_session.oauth_session.update(oauth_session_data) 177 148 178 - # create a mock OAuth client with the failing refresh method 179 149 mock_oauth_client = type( 180 150 "MockOAuthClient", (), {"refresh_session": mock_refresh_session_fails} 181 151 )() ··· 194 164 side_effect=mock_update_session_tokens, 195 165 ), 196 166 ): 197 - # this should fail to refresh but succeed via fallback 198 167 result = await _refresh_session_tokens( 199 168 mock_auth_session, mock_oauth_session 200 169 ) 201 170 202 - # verify it tried to refresh 203 171 assert refresh_called 204 - 205 - # verify it fell back to reloaded tokens 206 172 assert result.access_token == new_token 207 173 208 174 async def test_second_request_skips_refresh_if_already_done( ··· 213 179 new_token = "already-refreshed-token" 214 180 215 181 async def mock_refresh_session(self, session: OAuthSession) -> OAuthSession: 216 - """mock OAuth client refresh.""" 217 182 nonlocal refresh_call_count 218 183 refresh_call_count += 1 
219 184 await asyncio.sleep(0.1) ··· 223 188 get_session_calls = 0 224 189 225 190 async def mock_get_session(session_id: str) -> AuthSession | None: 226 - """mock get_session that simulates first refresh completing quickly.""" 227 191 nonlocal get_session_calls 228 192 get_session_calls += 1 229 - 230 - # on second+ call, act like refresh already happened 231 193 if get_session_calls > 1: 232 194 mock_auth_session.oauth_session["access_token"] = new_token 233 - 234 195 return mock_auth_session 235 196 236 197 async def mock_update_session_tokens( 237 198 session_id: str, oauth_session_data: dict 238 199 ) -> None: 239 - """mock session update.""" 240 200 mock_auth_session.oauth_session.update(oauth_session_data) 241 201 242 - # create a mock OAuth client with the refresh method 243 202 mock_oauth_client = type( 244 203 "MockOAuthClient", (), {"refresh_session": mock_refresh_session} 245 204 )() ··· 258 217 side_effect=mock_update_session_tokens, 259 218 ), 260 219 ): 261 - # first refresh 262 220 result1 = await _refresh_session_tokens( 263 221 mock_auth_session, mock_oauth_session 264 222 ) 265 223 assert result1.access_token == new_token 266 224 267 - # second refresh attempt should skip network call 268 225 result2 = await _refresh_session_tokens( 269 226 mock_auth_session, mock_oauth_session 270 227 ) 271 228 assert result2.access_token == new_token 272 229 273 - # OAuth client should have been called exactly once 274 230 assert refresh_call_count == 1 231 + 232 + 233 + class TestRefreshLocksCache: 234 + """test _refresh_locks cache behavior (memory leak prevention).""" 235 + 236 + def test_same_session_returns_same_lock(self): 237 + """same session_id should return the same lock instance.""" 238 + _refresh_locks.clear() 239 + 240 + _refresh_locks["session-a"] = asyncio.Lock() 241 + lock1 = _refresh_locks["session-a"] 242 + lock2 = _refresh_locks["session-a"] 243 + 244 + assert lock1 is lock2 245 + 246 + def test_different_sessions_have_different_locks(self): 
247 + """different session_ids should have different lock instances.""" 248 + _refresh_locks.clear() 249 + 250 + _refresh_locks["session-a"] = asyncio.Lock() 251 + _refresh_locks["session-b"] = asyncio.Lock() 252 + 253 + assert _refresh_locks["session-a"] is not _refresh_locks["session-b"] 254 + 255 + def test_cache_is_bounded_by_maxsize(self): 256 + """cache should evict entries when full (LRU behavior).""" 257 + _refresh_locks.clear() 258 + 259 + assert _refresh_locks.maxsize == 10_000 260 + 261 + for i in range(100): 262 + _refresh_locks[f"session-{i}"] = asyncio.Lock() 263 + 264 + assert len(_refresh_locks) == 100 265 + 266 + def test_lru_eviction_order(self): 267 + """LRU cache should evict least recently used entries first.""" 268 + small_cache: LRUCache[str, asyncio.Lock] = LRUCache(maxsize=3) 269 + 270 + small_cache["a"] = asyncio.Lock() 271 + small_cache["b"] = asyncio.Lock() 272 + small_cache["c"] = asyncio.Lock() 273 + 274 + _ = small_cache["a"] 275 + small_cache["d"] = asyncio.Lock() 276 + 277 + assert "a" in small_cache 278 + assert "b" not in small_cache 279 + assert "c" in small_cache 280 + assert "d" in small_cache