Merge pull request #734 from zzstoatzz/fix/auth-token-mismatch

Align session expiry with refresh token lifetime

authored by nate nowack and committed by GitHub 5fb8a757 ea8d67d8

Changed files
+159 -31
backend
src
backend
_internal
atproto
api
tests
docs
+15
backend/src/backend/_internal/atproto/client.py
··· 3 3 import asyncio 4 4 import json 5 5 import logging 6 + from datetime import UTC, datetime, timedelta 6 7 from typing import Any 7 8 8 9 from atproto_oauth.models import OAuthSession ··· 10 11 11 12 from backend._internal import Session as AuthSession 12 13 from backend._internal import get_oauth_client, get_session, update_session_tokens 14 + from backend._internal.auth import ( 15 + get_client_auth_method, 16 + get_refresh_token_lifetime_days, 17 + ) 13 18 14 19 logger = logging.getLogger(__name__) 15 20 ··· 123 128 "dpop_authserver_nonce": refreshed_session.dpop_authserver_nonce, 124 129 "dpop_pds_nonce": refreshed_session.dpop_pds_nonce or "", 125 130 } 131 + client_auth_method = get_client_auth_method(updated_oauth_data) 132 + refresh_lifetime_days = get_refresh_token_lifetime_days(client_auth_method) 133 + refresh_expires_at = datetime.now(UTC) + timedelta( 134 + days=refresh_lifetime_days 135 + ) 136 + updated_session_data["client_auth_method"] = client_auth_method 137 + updated_session_data["refresh_token_lifetime_days"] = refresh_lifetime_days 138 + updated_session_data["refresh_token_expires_at"] = ( 139 + refresh_expires_at.isoformat() 140 + ) 126 141 127 142 # update session in database 128 143 await update_session_tokens(session_id, updated_session_data)
+113 -9
backend/src/backend/_internal/auth.py
··· 25 25 26 26 logger = logging.getLogger(__name__) 27 27 28 + PUBLIC_REFRESH_TOKEN_DAYS = 14 29 + CONFIDENTIAL_REFRESH_TOKEN_DAYS = 180 30 + 28 31 29 32 def _parse_scopes(scope_string: str) -> set[str]: 30 33 """parse an OAuth scope string into a set of individual scopes. ··· 163 166 return bool(settings.atproto.oauth_jwk) 164 167 165 168 169 + def get_client_auth_method(oauth_session_data: dict[str, Any] | None = None) -> str: 170 + """resolve client auth method for a session.""" 171 + if oauth_session_data: 172 + method = oauth_session_data.get("client_auth_method") 173 + if method in {"public", "confidential"}: 174 + return method 175 + return "confidential" if is_confidential_client() else "public" 176 + 177 + 178 + def get_refresh_token_lifetime_days(client_auth_method: str | None) -> int: 179 + """get expected refresh token lifetime in days.""" 180 + method = client_auth_method or get_client_auth_method() 181 + return ( 182 + CONFIDENTIAL_REFRESH_TOKEN_DAYS 183 + if method == "confidential" 184 + else PUBLIC_REFRESH_TOKEN_DAYS 185 + ) 186 + 187 + 188 + def _compute_refresh_token_expires_at( 189 + now: datetime, client_auth_method: str | None 190 + ) -> datetime: 191 + """compute refresh token expiration time.""" 192 + return now + timedelta(days=get_refresh_token_lifetime_days(client_auth_method)) 193 + 194 + 195 + def _parse_datetime(value: str | None) -> datetime | None: 196 + """parse ISO datetime string safely.""" 197 + if not value: 198 + return None 199 + try: 200 + return datetime.fromisoformat(value) 201 + except ValueError: 202 + return None 203 + 204 + 205 + def _get_refresh_token_expires_at( 206 + user_session: UserSession, 207 + oauth_session_data: dict[str, Any], 208 + ) -> datetime | None: 209 + """determine refresh token expiry for a session.""" 210 + parsed = _parse_datetime(oauth_session_data.get("refresh_token_expires_at")) 211 + if parsed: 212 + return parsed 213 + 214 + client_auth_method = oauth_session_data.get("client_auth_method") 215 + if client_auth_method: 216 + return user_session.created_at + timedelta( 217 + days=get_refresh_token_lifetime_days(client_auth_method) 218 + ) 219 + 220 + if user_session.is_developer_token: 221 + return user_session.created_at + timedelta(days=PUBLIC_REFRESH_TOKEN_DAYS) 222 + 223 + return None 224 + 225 + 166 226 def get_oauth_client(include_teal: bool = False) -> OAuthClient: 167 227 """create an OAuth client with the appropriate scopes. 168 228 ··· 249 309 did: user's decentralized identifier 250 310 handle: user's ATProto handle 251 311 oauth_session: OAuth session data to encrypt and store 252 - expires_in_days: session expiration in days (default 14, use 0 for no expiration) 312 + expires_in_days: session expiration in days (default 14, capped by refresh lifetime) 253 313 is_developer_token: whether this is a developer token (for listing/revocation) 254 314 token_name: optional name for the token (only for developer tokens) 255 315 group_id: optional session group ID for multi-account support 256 316 """ 257 317 session_id = secrets.token_urlsafe(32) 318 + now = datetime.now(UTC) 258 319 259 - encrypted_data = _encrypt_data(json.dumps(oauth_session)) 320 + client_auth_method = get_client_auth_method(oauth_session) 321 + refresh_lifetime_days = get_refresh_token_lifetime_days(client_auth_method) 322 + refresh_expires_at = _compute_refresh_token_expires_at(now, client_auth_method) 260 323 261 - expires_at = ( 262 - datetime.now(UTC) + timedelta(days=expires_in_days) 263 - if expires_in_days > 0 264 - else None 324 + oauth_session = dict(oauth_session) 325 + oauth_session.setdefault("client_auth_method", client_auth_method) 326 + oauth_session.setdefault("refresh_token_lifetime_days", refresh_lifetime_days) 327 + oauth_session.setdefault("refresh_token_expires_at", refresh_expires_at.isoformat()) 328 + 329 + effective_days = ( 330 + refresh_lifetime_days 331 + if expires_in_days <= 0 332 + else min(expires_in_days, refresh_lifetime_days) 265 333 ) 334 + expires_at = now + timedelta(days=effective_days) 335 + 336 + encrypted_data = _encrypt_data(json.dumps(oauth_session)) 266 337 267 338 async with db_session() as db: 268 339 user_session = UserSession( ··· 301 372 if decrypted_data is None: 302 373 # decryption failed - session is invalid (key changed or data corrupted) 303 374 # delete the corrupted session 375 + await delete_session(session_id) 376 + return None 377 + 378 + oauth_session_data = json.loads(decrypted_data) 379 + 380 + refresh_expires_at = _get_refresh_token_expires_at( 381 + user_session, oauth_session_data 382 + ) 383 + if refresh_expires_at and datetime.now(UTC) > refresh_expires_at: 304 384 await delete_session(session_id) 305 385 return None 306 386 ··· 308 388 session_id=user_session.session_id, 309 389 did=user_session.did, 310 390 handle=user_session.handle, 311 - oauth_session=json.loads(decrypted_data), 391 + oauth_session=oauth_session_data, 312 392 ) 313 393 314 394 ··· 445 525 encryption_algorithm=serialization.NoEncryption(), 446 526 ).decode("utf-8") 447 527 528 + client_auth_method = get_client_auth_method() 529 + refresh_lifetime_days = get_refresh_token_lifetime_days(client_auth_method) 530 + refresh_expires_at = _compute_refresh_token_expires_at( 531 + datetime.now(UTC), client_auth_method 532 + ) 533 + 448 534 # store full OAuth session with tokens in database 449 535 session_data = { 450 536 "did": oauth_session.did, ··· 457 543 "dpop_private_key_pem": dpop_key_pem, 458 544 "dpop_authserver_nonce": oauth_session.dpop_authserver_nonce, 459 545 "dpop_pds_nonce": oauth_session.dpop_pds_nonce or "", 546 + "client_auth_method": client_auth_method, 547 + "refresh_token_lifetime_days": refresh_lifetime_days, 548 + "refresh_token_expires_at": refresh_expires_at.isoformat(), 460 549 } 461 550 return oauth_session.did, oauth_session.handle, session_data 462 551 except Exception as e: ··· 658 747 sessions = result.scalars().all() 659 748 660 749 tokens = [] 750 + now = datetime.now(UTC) 661 751 for session in sessions: 752 + decrypted_data = _decrypt_data(session.oauth_session_data) 753 + oauth_session_data = ( 754 + json.loads(decrypted_data) if decrypted_data is not None else {} 755 + ) 756 + refresh_expires_at = _get_refresh_token_expires_at( 757 + session, oauth_session_data 758 + ) 759 + effective_expires_at = session.expires_at 760 + if refresh_expires_at and ( 761 + effective_expires_at is None 762 + or refresh_expires_at < effective_expires_at 763 + ): 764 + effective_expires_at = refresh_expires_at 765 + 662 766 # check if expired 663 - if session.expires_at and datetime.now(UTC) > session.expires_at: 767 + if effective_expires_at and now > effective_expires_at: 664 768 continue # skip expired tokens 665 769 666 770 tokens.append( ··· 668 772 session_id=session.session_id, 669 773 token_name=session.token_name, 670 774 created_at=session.created_at, 671 - expires_at=session.expires_at, 775 + expires_at=effective_expires_at, 672 776 ) 673 777 ) 674 778
+6 -1
backend/src/backend/api/auth.py
··· 35 35 start_oauth_flow_with_scopes, 36 36 switch_active_account, 37 37 ) 38 + from backend._internal.auth import get_refresh_token_lifetime_days 38 39 from backend._internal.background_tasks import schedule_atproto_sync 39 40 from backend.config import settings 40 41 from backend.models import Artist, get_db ··· 466 467 if expires_in_days > max_days: 467 468 raise HTTPException( 468 469 status_code=400, 469 - detail=f"expires_in_days cannot exceed {max_days} (use 0 for no expiration)", 470 + detail=f"expires_in_days cannot exceed {max_days}", 470 471 ) 472 + 473 + refresh_lifetime_days = get_refresh_token_lifetime_days(None) 474 + if expires_in_days <= 0 or expires_in_days > refresh_lifetime_days: 475 + expires_in_days = refresh_lifetime_days 471 476 472 477 # start OAuth flow using the user's handle 473 478 auth_url, state = await start_oauth_flow(session.handle)
+1 -1
backend/src/backend/config.py
··· 666 666 667 667 developer_token_default_days: int = Field( 668 668 default=90, 669 - description="Default expiration in days for developer tokens (0 = no expiration)", 669 + description="Default expiration in days for developer tokens (capped by refresh lifetime)", 670 670 ) 671 671 developer_token_max_days: int = Field( 672 672 default=365,
+14 -6
backend/tests/test_auth.py
··· 17 17 create_session, 18 18 delete_session, 19 19 get_public_jwks, 20 + get_refresh_token_lifetime_days, 20 21 get_session, 21 22 is_confidential_client, 22 23 update_session_tokens, ··· 259 260 260 261 261 262 async def test_create_session_with_custom_expiration(db_session: AsyncSession): 262 - """verify session creation with custom expiration works.""" 263 + """verify session creation with custom expiration is capped by refresh lifetime.""" 263 264 did = "did:plc:customexp123" 264 265 handle = "customexp.bsky.social" 265 266 oauth_data = {"access_token": "token", "refresh_token": "refresh"} ··· 280 281 assert db_session_record is not None 281 282 assert db_session_record.expires_at is not None 282 283 283 - # should expire roughly 30 days from now 284 - expected_expiry = datetime.now(UTC) + timedelta(days=30) 284 + expected_days = min(30, get_refresh_token_lifetime_days(None)) 285 + # should expire roughly expected_days from now 286 + expected_expiry = datetime.now(UTC) + timedelta(days=expected_days) 285 287 actual_expiry = db_session_record.expires_at.replace(tzinfo=UTC) 286 288 diff = abs((expected_expiry - actual_expiry).total_seconds()) 287 289 assert diff < 60 # within 1 minute 288 290 289 291 290 292 async def test_create_session_with_no_expiration(db_session: AsyncSession): 291 - """verify session creation with expires_in_days=0 creates non-expiring session.""" 293 + """verify session creation with expires_in_days=0 caps to refresh lifetime.""" 292 294 did = "did:plc:noexp123" 293 295 handle = "noexp.bsky.social" 294 296 oauth_data = {"access_token": "token", "refresh_token": "refresh"} ··· 301 303 assert session is not None 302 304 assert session.did == did 303 305 304 - # verify expires_at is None 306 + # verify expires_at is capped to refresh token lifetime 305 307 result = await db_session.execute( 306 308 select(UserSession).where(UserSession.session_id == session_id) 307 309 ) 308 310 db_session_record = result.scalar_one_or_none() 309 311 assert db_session_record is not None 310 - assert db_session_record.expires_at is None 312 + assert db_session_record.expires_at is not None 313 + 314 + expected_days = get_refresh_token_lifetime_days(None) 315 + expected_expiry = datetime.now(UTC) + timedelta(days=expected_days) 316 + actual_expiry = db_session_record.expires_at.replace(tzinfo=UTC) 317 + diff = abs((expected_expiry - actual_expiry).total_seconds()) 318 + assert diff < 60 # within 1 minute 311 319 312 320 313 321 async def test_create_session_default_expiration(db_session: AsyncSession):
+3 -3
docs/authentication.md
··· 439 439 backend settings in `AuthSettings`: 440 440 - `developer_token_default_days`: default expiration (90 days) 441 441 - `developer_token_max_days`: max allowed expiration (365 days) 442 - - use `expires_in_days: 0` for tokens that never expire 442 + - use `expires_in_days: 0` to request the maximum allowed by refresh lifetime 443 443 444 444 ### how it works 445 445 ··· 485 485 with public clients, the underlying ATProto refresh token expires after 2 weeks regardless of what we store in our database. users would need to re-authenticate with their PDS every 2 weeks. 486 486 487 487 with confidential clients: 488 - - **developer tokens actually work long-term** - not limited to 2 weeks 488 + - **developer tokens work long-term** - not limited to 2 weeks 489 489 - **users don't get randomly kicked out** after 2 weeks of inactivity 490 - - **sessions last effectively forever** as long as tokens are refreshed within 180 days 490 + - **sessions last up to refresh lifetime** as long as tokens are refreshed within 180 days 491 491 492 492 ### how it works 493 493
+7 -11
docs/frontend/data-loading.md
··· 44 44 45 45 used for: 46 46 - auth-dependent data (liked tracks, user preferences) 47 - - data that needs client context (localStorage, cookies) 47 + - data that needs client context (local caches, media state) 48 48 - progressive enhancement 49 49 50 50 ```typescript 51 51 // frontend/src/routes/liked/+page.ts 52 52 export const load: PageLoad = async ({ fetch }) => { 53 - const sessionId = localStorage.getItem('session_id'); 54 - if (!sessionId) return { tracks: [] }; 55 - 56 53 const response = await fetch(`${API_URL}/tracks/liked`, { 57 - headers: { 'Authorization': `Bearer ${sessionId}` } 54 + credentials: 'include' 58 55 }); 56 + 57 + if (!response.ok) return { tracks: [] }; 59 58 60 59 return { tracks: await response.json() }; 61 60 }; 62 61 ``` 63 62 64 63 **benefits**: 65 - - access to browser APIs (localStorage, cookies) 66 - - runs on client, can use session tokens 64 + - access to browser APIs (window, local caches) 65 + - runs on client, can use HttpOnly cookie auth 67 66 - still loads before component mounts (faster than `onMount`) 68 67 69 68 ### layout loading (`+layout.ts`) ··· 75 74 ```typescript 76 75 // frontend/src/routes/+layout.ts 77 76 export async function load({ fetch }: LoadEvent) { 78 - const sessionId = localStorage.getItem('session_id'); 79 - if (!sessionId) return { user: null, isAuthenticated: false }; 80 - 81 77 const response = await fetch(`${API_URL}/auth/me`, { 82 - headers: { 'Authorization': `Bearer ${sessionId}` } 78 + credentials: 'include' 83 79 }); 84 80 85 81 if (response.ok) {