Merge pull request #734 from zzstoatzz/fix/auth-token-mismatch

Align session expiry with refresh token lifetime

authored by nate nowack and committed by GitHub 5fb8a757 ea8d67d8

Changed files
+159 -31
backend
src
backend
_internal
atproto
api
tests
docs
+15
backend/src/backend/_internal/atproto/client.py
··· 3 import asyncio 4 import json 5 import logging 6 from typing import Any 7 8 from atproto_oauth.models import OAuthSession ··· 10 11 from backend._internal import Session as AuthSession 12 from backend._internal import get_oauth_client, get_session, update_session_tokens 13 14 logger = logging.getLogger(__name__) 15 ··· 123 "dpop_authserver_nonce": refreshed_session.dpop_authserver_nonce, 124 "dpop_pds_nonce": refreshed_session.dpop_pds_nonce or "", 125 } 126 127 # update session in database 128 await update_session_tokens(session_id, updated_session_data)
··· 3 import asyncio 4 import json 5 import logging 6 + from datetime import UTC, datetime, timedelta 7 from typing import Any 8 9 from atproto_oauth.models import OAuthSession ··· 11 12 from backend._internal import Session as AuthSession 13 from backend._internal import get_oauth_client, get_session, update_session_tokens 14 + from backend._internal.auth import ( 15 + get_client_auth_method, 16 + get_refresh_token_lifetime_days, 17 + ) 18 19 logger = logging.getLogger(__name__) 20 ··· 128 "dpop_authserver_nonce": refreshed_session.dpop_authserver_nonce, 129 "dpop_pds_nonce": refreshed_session.dpop_pds_nonce or "", 130 } 131 + client_auth_method = get_client_auth_method(updated_oauth_data) 132 + refresh_lifetime_days = get_refresh_token_lifetime_days(client_auth_method) 133 + refresh_expires_at = datetime.now(UTC) + timedelta( 134 + days=refresh_lifetime_days 135 + ) 136 + updated_session_data["client_auth_method"] = client_auth_method 137 + updated_session_data["refresh_token_lifetime_days"] = refresh_lifetime_days 138 + updated_session_data["refresh_token_expires_at"] = ( 139 + refresh_expires_at.isoformat() 140 + ) 141 142 # update session in database 143 await update_session_tokens(session_id, updated_session_data)
+113 -9
backend/src/backend/_internal/auth.py
··· 25 26 logger = logging.getLogger(__name__) 27 28 29 def _parse_scopes(scope_string: str) -> set[str]: 30 """parse an OAuth scope string into a set of individual scopes. ··· 163 return bool(settings.atproto.oauth_jwk) 164 165 166 def get_oauth_client(include_teal: bool = False) -> OAuthClient: 167 """create an OAuth client with the appropriate scopes. 168 ··· 249 did: user's decentralized identifier 250 handle: user's ATProto handle 251 oauth_session: OAuth session data to encrypt and store 252 - expires_in_days: session expiration in days (default 14, use 0 for no expiration) 253 is_developer_token: whether this is a developer token (for listing/revocation) 254 token_name: optional name for the token (only for developer tokens) 255 group_id: optional session group ID for multi-account support 256 """ 257 session_id = secrets.token_urlsafe(32) 258 259 - encrypted_data = _encrypt_data(json.dumps(oauth_session)) 260 261 - expires_at = ( 262 - datetime.now(UTC) + timedelta(days=expires_in_days) 263 - if expires_in_days > 0 264 - else None 265 ) 266 267 async with db_session() as db: 268 user_session = UserSession( ··· 301 if decrypted_data is None: 302 # decryption failed - session is invalid (key changed or data corrupted) 303 # delete the corrupted session 304 await delete_session(session_id) 305 return None 306 ··· 308 session_id=user_session.session_id, 309 did=user_session.did, 310 handle=user_session.handle, 311 - oauth_session=json.loads(decrypted_data), 312 ) 313 314 ··· 445 encryption_algorithm=serialization.NoEncryption(), 446 ).decode("utf-8") 447 448 # store full OAuth session with tokens in database 449 session_data = { 450 "did": oauth_session.did, ··· 457 "dpop_private_key_pem": dpop_key_pem, 458 "dpop_authserver_nonce": oauth_session.dpop_authserver_nonce, 459 "dpop_pds_nonce": oauth_session.dpop_pds_nonce or "", 460 } 461 return oauth_session.did, oauth_session.handle, session_data 462 except Exception as e: ··· 658 sessions = result.scalars().all() 659 660 tokens = [] 661 for session in sessions: 662 # check if expired 663 - if session.expires_at and datetime.now(UTC) > session.expires_at: 664 continue # skip expired tokens 665 666 tokens.append( ··· 668 session_id=session.session_id, 669 token_name=session.token_name, 670 created_at=session.created_at, 671 - expires_at=session.expires_at, 672 ) 673 ) 674
··· 25 26 logger = logging.getLogger(__name__) 27 28 + PUBLIC_REFRESH_TOKEN_DAYS = 14 29 + CONFIDENTIAL_REFRESH_TOKEN_DAYS = 180 30 + 31 32 def _parse_scopes(scope_string: str) -> set[str]: 33 """parse an OAuth scope string into a set of individual scopes. ··· 166 return bool(settings.atproto.oauth_jwk) 167 168 169 + def get_client_auth_method(oauth_session_data: dict[str, Any] | None = None) -> str: 170 + """resolve client auth method for a session.""" 171 + if oauth_session_data: 172 + method = oauth_session_data.get("client_auth_method") 173 + if method in {"public", "confidential"}: 174 + return method 175 + return "confidential" if is_confidential_client() else "public" 176 + 177 + 178 + def get_refresh_token_lifetime_days(client_auth_method: str | None) -> int: 179 + """get expected refresh token lifetime in days.""" 180 + method = client_auth_method or get_client_auth_method() 181 + return ( 182 + CONFIDENTIAL_REFRESH_TOKEN_DAYS 183 + if method == "confidential" 184 + else PUBLIC_REFRESH_TOKEN_DAYS 185 + ) 186 + 187 + 188 + def _compute_refresh_token_expires_at( 189 + now: datetime, client_auth_method: str | None 190 + ) -> datetime: 191 + """compute refresh token expiration time.""" 192 + return now + timedelta(days=get_refresh_token_lifetime_days(client_auth_method)) 193 + 194 + 195 + def _parse_datetime(value: str | None) -> datetime | None: 196 + """parse ISO datetime string safely.""" 197 + if not value: 198 + return None 199 + try: 200 + return datetime.fromisoformat(value) 201 + except ValueError: 202 + return None 203 + 204 + 205 + def _get_refresh_token_expires_at( 206 + user_session: UserSession, 207 + oauth_session_data: dict[str, Any], 208 + ) -> datetime | None: 209 + """determine refresh token expiry for a session.""" 210 + parsed = _parse_datetime(oauth_session_data.get("refresh_token_expires_at")) 211 + if parsed: 212 + return parsed 213 + 214 + client_auth_method = oauth_session_data.get("client_auth_method") 215 + if client_auth_method: 216 + return user_session.created_at + timedelta( 217 + days=get_refresh_token_lifetime_days(client_auth_method) 218 + ) 219 + 220 + if user_session.is_developer_token: 221 + return user_session.created_at + timedelta(days=PUBLIC_REFRESH_TOKEN_DAYS) 222 + 223 + return None 224 + 225 + 226 def get_oauth_client(include_teal: bool = False) -> OAuthClient: 227 """create an OAuth client with the appropriate scopes. 228 ··· 309 did: user's decentralized identifier 310 handle: user's ATProto handle 311 oauth_session: OAuth session data to encrypt and store 312 + expires_in_days: session expiration in days (default 14, capped by refresh lifetime) 313 is_developer_token: whether this is a developer token (for listing/revocation) 314 token_name: optional name for the token (only for developer tokens) 315 group_id: optional session group ID for multi-account support 316 """ 317 session_id = secrets.token_urlsafe(32) 318 + now = datetime.now(UTC) 319 320 + client_auth_method = get_client_auth_method(oauth_session) 321 + refresh_lifetime_days = get_refresh_token_lifetime_days(client_auth_method) 322 + refresh_expires_at = _compute_refresh_token_expires_at(now, client_auth_method) 323 324 + oauth_session = dict(oauth_session) 325 + oauth_session.setdefault("client_auth_method", client_auth_method) 326 + oauth_session.setdefault("refresh_token_lifetime_days", refresh_lifetime_days) 327 + oauth_session.setdefault("refresh_token_expires_at", refresh_expires_at.isoformat()) 328 + 329 + effective_days = ( 330 + refresh_lifetime_days 331 + if expires_in_days <= 0 332 + else min(expires_in_days, refresh_lifetime_days) 333 ) 334 + expires_at = now + timedelta(days=effective_days) 335 + 336 + encrypted_data = _encrypt_data(json.dumps(oauth_session)) 337 338 async with db_session() as db: 339 user_session = UserSession( ··· 372 if decrypted_data is None: 373 # decryption failed - session is invalid (key changed or data corrupted) 374 # delete the corrupted session 375 + await delete_session(session_id) 376 + return None 377 + 378 + oauth_session_data = json.loads(decrypted_data) 379 + 380 + refresh_expires_at = _get_refresh_token_expires_at( 381 + user_session, oauth_session_data 382 + ) 383 + if refresh_expires_at and datetime.now(UTC) > refresh_expires_at: 384 await delete_session(session_id) 385 return None 386 ··· 388 session_id=user_session.session_id, 389 did=user_session.did, 390 handle=user_session.handle, 391 + oauth_session=oauth_session_data, 392 ) 393 394 ··· 525 encryption_algorithm=serialization.NoEncryption(), 526 ).decode("utf-8") 527 528 + client_auth_method = get_client_auth_method() 529 + refresh_lifetime_days = get_refresh_token_lifetime_days(client_auth_method) 530 + refresh_expires_at = _compute_refresh_token_expires_at( 531 + datetime.now(UTC), client_auth_method 532 + ) 533 + 534 # store full OAuth session with tokens in database 535 session_data = { 536 "did": oauth_session.did, ··· 543 "dpop_private_key_pem": dpop_key_pem, 544 "dpop_authserver_nonce": oauth_session.dpop_authserver_nonce, 545 "dpop_pds_nonce": oauth_session.dpop_pds_nonce or "", 546 + "client_auth_method": client_auth_method, 547 + "refresh_token_lifetime_days": refresh_lifetime_days, 548 + "refresh_token_expires_at": refresh_expires_at.isoformat(), 549 } 550 return oauth_session.did, oauth_session.handle, session_data 551 except Exception as e: ··· 747 sessions = result.scalars().all() 748 749 tokens = [] 750 + now = datetime.now(UTC) 751 for session in sessions: 752 + decrypted_data = _decrypt_data(session.oauth_session_data) 753 + oauth_session_data = ( 754 + json.loads(decrypted_data) if decrypted_data is not None else {} 755 + ) 756 + refresh_expires_at = _get_refresh_token_expires_at( 757 + session, oauth_session_data 758 + ) 759 + effective_expires_at = session.expires_at 760 + if refresh_expires_at and ( 761 + effective_expires_at is None 762 + or refresh_expires_at < effective_expires_at 763 + ): 764 + effective_expires_at = refresh_expires_at 765 + 766 # check if expired 767 + if effective_expires_at and now > effective_expires_at: 768 continue # skip expired tokens 769 770 tokens.append( ··· 772 session_id=session.session_id, 773 token_name=session.token_name, 774 created_at=session.created_at, 775 + expires_at=effective_expires_at, 776 ) 777 ) 778
+6 -1
backend/src/backend/api/auth.py
··· 35 start_oauth_flow_with_scopes, 36 switch_active_account, 37 ) 38 from backend._internal.background_tasks import schedule_atproto_sync 39 from backend.config import settings 40 from backend.models import Artist, get_db ··· 466 if expires_in_days > max_days: 467 raise HTTPException( 468 status_code=400, 469 - detail=f"expires_in_days cannot exceed {max_days} (use 0 for no expiration)", 470 ) 471 472 # start OAuth flow using the user's handle 473 auth_url, state = await start_oauth_flow(session.handle)
··· 35 start_oauth_flow_with_scopes, 36 switch_active_account, 37 ) 38 + from backend._internal.auth import get_refresh_token_lifetime_days 39 from backend._internal.background_tasks import schedule_atproto_sync 40 from backend.config import settings 41 from backend.models import Artist, get_db ··· 467 if expires_in_days > max_days: 468 raise HTTPException( 469 status_code=400, 470 + detail=f"expires_in_days cannot exceed {max_days}", 471 ) 472 + 473 + refresh_lifetime_days = get_refresh_token_lifetime_days(None) 474 + if expires_in_days <= 0 or expires_in_days > refresh_lifetime_days: 475 + expires_in_days = refresh_lifetime_days 476 477 # start OAuth flow using the user's handle 478 auth_url, state = await start_oauth_flow(session.handle)
+1 -1
backend/src/backend/config.py
··· 666 667 developer_token_default_days: int = Field( 668 default=90, 669 - description="Default expiration in days for developer tokens (0 = no expiration)", 670 ) 671 developer_token_max_days: int = Field( 672 default=365,
··· 666 667 developer_token_default_days: int = Field( 668 default=90, 669 + description="Default expiration in days for developer tokens (capped by refresh lifetime)", 670 ) 671 developer_token_max_days: int = Field( 672 default=365,
+14 -6
backend/tests/test_auth.py
··· 17 create_session, 18 delete_session, 19 get_public_jwks, 20 get_session, 21 is_confidential_client, 22 update_session_tokens, ··· 259 260 261 async def test_create_session_with_custom_expiration(db_session: AsyncSession): 262 - """verify session creation with custom expiration works.""" 263 did = "did:plc:customexp123" 264 handle = "customexp.bsky.social" 265 oauth_data = {"access_token": "token", "refresh_token": "refresh"} ··· 280 assert db_session_record is not None 281 assert db_session_record.expires_at is not None 282 283 - # should expire roughly 30 days from now 284 - expected_expiry = datetime.now(UTC) + timedelta(days=30) 285 actual_expiry = db_session_record.expires_at.replace(tzinfo=UTC) 286 diff = abs((expected_expiry - actual_expiry).total_seconds()) 287 assert diff < 60 # within 1 minute 288 289 290 async def test_create_session_with_no_expiration(db_session: AsyncSession): 291 - """verify session creation with expires_in_days=0 creates non-expiring session.""" 292 did = "did:plc:noexp123" 293 handle = "noexp.bsky.social" 294 oauth_data = {"access_token": "token", "refresh_token": "refresh"} ··· 301 assert session is not None 302 assert session.did == did 303 304 - # verify expires_at is None 305 result = await db_session.execute( 306 select(UserSession).where(UserSession.session_id == session_id) 307 ) 308 db_session_record = result.scalar_one_or_none() 309 assert db_session_record is not None 310 - assert db_session_record.expires_at is None 311 312 313 async def test_create_session_default_expiration(db_session: AsyncSession):
··· 17 create_session, 18 delete_session, 19 get_public_jwks, 20 + get_refresh_token_lifetime_days, 21 get_session, 22 is_confidential_client, 23 update_session_tokens, ··· 260 261 262 async def test_create_session_with_custom_expiration(db_session: AsyncSession): 263 + """verify session creation with custom expiration is capped by refresh lifetime.""" 264 did = "did:plc:customexp123" 265 handle = "customexp.bsky.social" 266 oauth_data = {"access_token": "token", "refresh_token": "refresh"} ··· 281 assert db_session_record is not None 282 assert db_session_record.expires_at is not None 283 284 + expected_days = min(30, get_refresh_token_lifetime_days(None)) 285 + # should expire roughly expected_days from now 286 + expected_expiry = datetime.now(UTC) + timedelta(days=expected_days) 287 actual_expiry = db_session_record.expires_at.replace(tzinfo=UTC) 288 diff = abs((expected_expiry - actual_expiry).total_seconds()) 289 assert diff < 60 # within 1 minute 290 291 292 async def test_create_session_with_no_expiration(db_session: AsyncSession): 293 + """verify session creation with expires_in_days=0 caps to refresh lifetime.""" 294 did = "did:plc:noexp123" 295 handle = "noexp.bsky.social" 296 oauth_data = {"access_token": "token", "refresh_token": "refresh"} ··· 303 assert session is not None 304 assert session.did == did 305 306 + # verify expires_at is capped to refresh token lifetime 307 result = await db_session.execute( 308 select(UserSession).where(UserSession.session_id == session_id) 309 ) 310 db_session_record = result.scalar_one_or_none() 311 assert db_session_record is not None 312 + assert db_session_record.expires_at is not None 313 + 314 + expected_days = get_refresh_token_lifetime_days(None) 315 + expected_expiry = datetime.now(UTC) + timedelta(days=expected_days) 316 + actual_expiry = db_session_record.expires_at.replace(tzinfo=UTC) 317 + diff = abs((expected_expiry - actual_expiry).total_seconds()) 318 + assert diff < 60 # within 1 minute 319 320 321 async def test_create_session_default_expiration(db_session: AsyncSession):
+3 -3
docs/authentication.md
··· 439 backend settings in `AuthSettings`: 440 - `developer_token_default_days`: default expiration (90 days) 441 - `developer_token_max_days`: max allowed expiration (365 days) 442 - - use `expires_in_days: 0` for tokens that never expire 443 444 ### how it works 445 ··· 485 with public clients, the underlying ATProto refresh token expires after 2 weeks regardless of what we store in our database. users would need to re-authenticate with their PDS every 2 weeks. 486 487 with confidential clients: 488 - - **developer tokens actually work long-term** - not limited to 2 weeks 489 - **users don't get randomly kicked out** after 2 weeks of inactivity 490 - - **sessions last effectively forever** as long as tokens are refreshed within 180 days 491 492 ### how it works 493
··· 439 backend settings in `AuthSettings`: 440 - `developer_token_default_days`: default expiration (90 days) 441 - `developer_token_max_days`: max allowed expiration (365 days) 442 + - use `expires_in_days: 0` to request the maximum allowed by refresh lifetime 443 444 ### how it works 445 ··· 485 with public clients, the underlying ATProto refresh token expires after 2 weeks regardless of what we store in our database. users would need to re-authenticate with their PDS every 2 weeks. 486 487 with confidential clients: 488 + - **developer tokens work long-term** - not limited to 2 weeks 489 - **users don't get randomly kicked out** after 2 weeks of inactivity 490 + - **sessions last up to refresh lifetime** as long as tokens are refreshed within 180 days 491 492 ### how it works 493
+7 -11
docs/frontend/data-loading.md
··· 44 45 used for: 46 - auth-dependent data (liked tracks, user preferences) 47 - - data that needs client context (localStorage, cookies) 48 - progressive enhancement 49 50 ```typescript 51 // frontend/src/routes/liked/+page.ts 52 export const load: PageLoad = async ({ fetch }) => { 53 - const sessionId = localStorage.getItem('session_id'); 54 - if (!sessionId) return { tracks: [] }; 55 - 56 const response = await fetch(`${API_URL}/tracks/liked`, { 57 - headers: { 'Authorization': `Bearer ${sessionId}` } 58 }); 59 60 return { tracks: await response.json() }; 61 }; 62 ``` 63 64 **benefits**: 65 - - access to browser APIs (localStorage, cookies) 66 - - runs on client, can use session tokens 67 - still loads before component mounts (faster than `onMount`) 68 69 ### layout loading (`+layout.ts`) ··· 75 ```typescript 76 // frontend/src/routes/+layout.ts 77 export async function load({ fetch }: LoadEvent) { 78 - const sessionId = localStorage.getItem('session_id'); 79 - if (!sessionId) return { user: null, isAuthenticated: false }; 80 - 81 const response = await fetch(`${API_URL}/auth/me`, { 82 - headers: { 'Authorization': `Bearer ${sessionId}` } 83 }); 84 85 if (response.ok) {
··· 44 45 used for: 46 - auth-dependent data (liked tracks, user preferences) 47 + - data that needs client context (local caches, media state) 48 - progressive enhancement 49 50 ```typescript 51 // frontend/src/routes/liked/+page.ts 52 export const load: PageLoad = async ({ fetch }) => { 53 const response = await fetch(`${API_URL}/tracks/liked`, { 54 + credentials: 'include' 55 }); 56 + 57 + if (!response.ok) return { tracks: [] }; 58 59 return { tracks: await response.json() }; 60 }; 61 ``` 62 63 **benefits**: 64 + - access to browser APIs (window, local caches) 65 + - runs on client, can use HttpOnly cookie auth 66 - still loads before component mounts (faster than `onMount`) 67 68 ### layout loading (`+layout.ts`) ··· 74 ```typescript 75 // frontend/src/routes/+layout.ts 76 export async function load({ fetch }: LoadEvent) { 77 const response = await fetch(`${API_URL}/auth/me`, { 78 + credentials: 'include' 79 }); 80 81 if (response.ok) {