"""test OAuth authentication and session management."""

import json
from datetime import UTC, datetime, timedelta
from unittest.mock import patch

import pytest
from fastapi import HTTPException
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from backend._internal.auth import (
    _decrypt_data,
    _encrypt_data,
    consume_exchange_token,
    create_exchange_token,
    create_session,
    delete_pending_add_account,
    delete_session,
    get_or_create_group_id,
    get_pending_add_account,
    get_public_jwks,
    get_refresh_token_lifetime_days,
    get_session,
    get_session_group,
    is_confidential_client,
    remove_account_from_group,
    save_pending_add_account,
    switch_active_account,
    update_session_tokens,
)
from backend.models import ExchangeToken, UserSession

# maximum drift (seconds) tolerated between the expected and the stored expiry
_EXPIRY_TOLERANCE_SECONDS = 60


# shared helpers (underscore-prefixed so pytest does not collect them)


async def _fetch_session_record(db: AsyncSession, session_id: str) -> UserSession:
    """load the UserSession row for session_id, failing the test if it is missing."""
    result = await db.execute(
        select(UserSession).where(UserSession.session_id == session_id)
    )
    record = result.scalar_one_or_none()
    assert record is not None
    return record


def _assert_expires_in_days(record: UserSession, expected_days: float) -> None:
    """assert record.expires_at is roughly expected_days from now.

    the stored value has its tzinfo overwritten with UTC before comparing
    (presumably it is persisted as naive UTC - confirm against the model).
    """
    assert record.expires_at is not None

    expected_expiry = datetime.now(UTC) + timedelta(days=expected_days)
    actual_expiry = record.expires_at.replace(tzinfo=UTC)
    diff = abs((expected_expiry - actual_expiry).total_seconds())
    assert diff < _EXPIRY_TOLERANCE_SECONDS  # within 1 minute


async def _link_to_group(db: AsyncSession, session_id: str, group_id: str) -> None:
    """set group_id directly on the session's row and commit."""
    result = await db.execute(
        select(UserSession).where(UserSession.session_id == session_id)
    )
    # scalar_one: the session was just created, so exactly one row must exist
    record = result.scalar_one()
    record.group_id = group_id
    await db.commit()


# encryption tests


def test_encryption_roundtrip() -> None:
    """verify encryption and decryption work correctly."""
    original_data = "sensitive oauth data"

    encrypted = _encrypt_data(original_data)
    decrypted = _decrypt_data(encrypted)

    assert decrypted == original_data
    assert encrypted != original_data  # ensure it's actually encrypted


def test_encryption_of_json_data() -> None:
    """verify encryption works with json-serialized data."""
    oauth_data = {
        "did": "did:plc:test123",
        "handle": "test.bsky.social",
        "access_token": "secret_token_123",
        "refresh_token": "secret_refresh_456",
        "dpop_private_key_pem": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBg==\n-----END PRIVATE KEY-----",
    }

    json_str = json.dumps(oauth_data)
    encrypted = _encrypt_data(json_str)
    decrypted = _decrypt_data(encrypted)

    assert decrypted is not None
    assert json.loads(decrypted) == oauth_data


# session tests


async def test_create_session_with_encryption(db_session: AsyncSession) -> None:
    """verify session creation encrypts OAuth data."""
    did = "did:plc:session123"
    handle = "session.bsky.social"
    oauth_data = {
        "did": did,
        "handle": handle,
        "access_token": "secret_token",
        "refresh_token": "secret_refresh",
        "dpop_private_key_pem": "secret_key",
    }

    session_id = await create_session(did, handle, oauth_data)

    # retrieve session and verify it was created correctly
    session = await get_session(session_id)
    assert session is not None
    assert session.did == did
    assert session.handle == handle
    assert session.oauth_session["access_token"] == "secret_token"
    assert session.oauth_session["refresh_token"] == "secret_refresh"


async def test_get_session_decrypts_data(db_session: AsyncSession) -> None:
    """verify get_session correctly decrypts OAuth data."""
    did = "did:plc:decrypt123"
    handle = "decrypt.bsky.social"
    oauth_data = {
        "did": did,
        "handle": handle,
        "access_token": "secret_token_xyz",
        "refresh_token": "secret_refresh_xyz",
    }

    session_id = await create_session(did, handle, oauth_data)

    # retrieve and verify decryption
    session = await get_session(session_id)

    assert session is not None
    assert session.did == did
    assert session.handle == handle
    assert session.oauth_session["access_token"] == "secret_token_xyz"
    assert session.oauth_session["refresh_token"] == "secret_refresh_xyz"


async def test_get_session_returns_none_for_invalid_id(
    db_session: AsyncSession,
) -> None:
    """verify get_session returns None for non-existent session."""
    session = await get_session("invalid_session_id_that_does_not_exist")
    assert session is None


async def test_update_session_tokens(db_session: AsyncSession) -> None:
    """verify session token update encrypts new data."""
    did = "did:plc:update123"
    handle = "update.bsky.social"
    original_oauth_data = {
        "access_token": "original_token",
        "refresh_token": "original_refresh",
    }

    session_id = await create_session(did, handle, original_oauth_data)

    # update with new tokens
    updated_oauth_data = {"access_token": "new_token", "refresh_token": "new_refresh"}
    await update_session_tokens(session_id, updated_oauth_data)

    # verify tokens were updated
    session = await get_session(session_id)
    assert session is not None
    assert session.oauth_session["access_token"] == "new_token"
    assert session.oauth_session["refresh_token"] == "new_refresh"


async def test_delete_session(db_session: AsyncSession) -> None:
    """verify session deletion works."""
    did = "did:plc:delete123"
    handle = "delete.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)

    # verify session exists
    session = await get_session(session_id)
    assert session is not None

    # delete session
    await delete_session(session_id)

    # verify session is gone
    session = await get_session(session_id)
    assert session is None


# exchange token tests


async def test_create_exchange_token(db_session: AsyncSession) -> None:
    """verify exchange token creation."""
    did = "did:plc:exchange123"
    handle = "exchange.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)

    # create exchange token
    token = await create_exchange_token(session_id)

    # verify token can be consumed (proves it was created correctly)
    result = await consume_exchange_token(token)
    assert result is not None
    returned_session_id, is_dev_token = result
    assert returned_session_id == session_id
    assert is_dev_token is False

    # verify token can't be reused
    second_attempt = await consume_exchange_token(token)
    assert second_attempt is None


async def test_consume_exchange_token(db_session: AsyncSession) -> None:
    """verify exchange token consumption works."""
    did = "did:plc:consume123"
    handle = "consume.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)
    token = await create_exchange_token(session_id)

    # consume token
    result = await consume_exchange_token(token)
    assert result is not None
    returned_session_id, is_dev_token = result
    assert returned_session_id == session_id
    assert is_dev_token is False

    # verify token can't be consumed again (proves it was marked as used)
    second_attempt = await consume_exchange_token(token)
    assert second_attempt is None


async def test_exchange_token_cannot_be_reused(db_session: AsyncSession) -> None:
    """verify exchange token can only be used once."""
    did = "did:plc:reuse123"
    handle = "reuse.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)
    token = await create_exchange_token(session_id)

    # consume token first time
    result = await consume_exchange_token(token)
    assert result is not None
    returned_session_id, _is_dev_token = result
    assert returned_session_id == session_id

    # try to consume again - should return None
    second_result = await consume_exchange_token(token)
    assert second_result is None


async def test_exchange_token_returns_none_for_invalid_token(
    db_session: AsyncSession,
) -> None:
    """verify consume_exchange_token returns None for invalid token."""
    result = await consume_exchange_token("invalid_token_that_does_not_exist")
    assert result is None


async def test_exchange_token_expires(db_session: AsyncSession) -> None:
    """verify expired exchange token returns None."""
    # use a separate database session to manually expire the token
    # (imported locally and aliased so it does not shadow the db_session fixture)
    from backend.utilities.database import db_session as get_db_session

    did = "did:plc:expire123"
    handle = "expire.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)
    token = await create_exchange_token(session_id)

    # manually expire the token by updating its expiration
    async with get_db_session() as db:
        result = await db.execute(
            select(ExchangeToken).where(ExchangeToken.token == token)
        )
        exchange_token = result.scalar_one_or_none()
        assert exchange_token is not None

        # set expiration to past
        exchange_token.expires_at = datetime.now(UTC) - timedelta(seconds=1)
        await db.commit()

    # try to consume expired token - should return None
    consumed = await consume_exchange_token(token)
    assert consumed is None


async def test_session_isolation(db_session: AsyncSession) -> None:
    """verify each test starts with clean database."""
    # this should not see sessions from other tests
    result = await db_session.execute(select(UserSession))
    sessions = result.scalars().all()
    assert len(sessions) == 0

    result = await db_session.execute(select(ExchangeToken))
    tokens = result.scalars().all()
    assert len(tokens) == 0


# session expiration tests


async def test_create_session_with_custom_expiration(
    db_session: AsyncSession,
) -> None:
    """verify session creation with custom expiration is capped by refresh lifetime."""
    did = "did:plc:customexp123"
    handle = "customexp.bsky.social"
    oauth_data = {"access_token": "token", "refresh_token": "refresh"}

    # create session with 30-day expiration
    session_id = await create_session(did, handle, oauth_data, expires_in_days=30)

    # verify session exists and works
    session = await get_session(session_id)
    assert session is not None
    assert session.did == did

    # verify expiration is set (within reasonable range)
    db_session_record = await _fetch_session_record(db_session, session_id)

    # should expire roughly expected_days from now
    expected_days = min(30, get_refresh_token_lifetime_days(None))
    _assert_expires_in_days(db_session_record, expected_days)


async def test_create_session_with_no_expiration(db_session: AsyncSession) -> None:
    """verify session creation with expires_in_days=0 caps to refresh lifetime."""
    did = "did:plc:noexp123"
    handle = "noexp.bsky.social"
    oauth_data = {"access_token": "token", "refresh_token": "refresh"}

    # create session with no expiration
    session_id = await create_session(did, handle, oauth_data, expires_in_days=0)

    # verify session exists
    session = await get_session(session_id)
    assert session is not None
    assert session.did == did

    # verify expires_at is capped to refresh token lifetime
    db_session_record = await _fetch_session_record(db_session, session_id)
    _assert_expires_in_days(db_session_record, get_refresh_token_lifetime_days(None))


async def test_create_session_default_expiration(db_session: AsyncSession) -> None:
    """verify session creation uses default 14-day expiration."""
    did = "did:plc:defaultexp123"
    handle = "defaultexp.bsky.social"
    oauth_data = {"access_token": "token", "refresh_token": "refresh"}

    # create session with default expiration
    session_id = await create_session(did, handle, oauth_data)

    # verify expiration is set to default (14 days)
    db_session_record = await _fetch_session_record(db_session, session_id)
    _assert_expires_in_days(db_session_record, 14)


# confidential client tests


def test_is_confidential_client_false_by_default() -> None:
    """verify is_confidential_client returns False when OAUTH_JWK not set."""
    with patch("backend._internal.auth.settings.atproto.oauth_jwk", None):
        assert is_confidential_client() is False


def test_is_confidential_client_true_when_configured() -> None:
    """verify is_confidential_client returns True when OAUTH_JWK is set."""
    test_jwk = '{"kty":"EC","crv":"P-256","x":"test","y":"test","d":"test"}'

    with patch("backend._internal.auth.settings.atproto.oauth_jwk", test_jwk):
        assert is_confidential_client() is True


def test_get_public_jwks_returns_none_without_config() -> None:
    """verify get_public_jwks returns None when OAUTH_JWK not configured."""
    with patch("backend._internal.auth.settings.atproto.oauth_jwk", None):
        assert get_public_jwks() is None


def test_get_public_jwks_returns_public_key() -> None:
    """verify get_public_jwks returns JWKS with public key only."""
    # generate a test JWK
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import ec
    from jose import jwk as jose_jwk

    # generate test key
    private_key = ec.generate_private_key(ec.SECP256R1())
    pem_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    key_obj = jose_jwk.construct(pem_bytes, algorithm="ES256")
    jwk_dict = key_obj.to_dict()
    jwk_dict["kid"] = "test-key-id"  # add kid to test preservation
    test_jwk = json.dumps(jwk_dict)

    with patch("backend._internal.auth.settings.atproto.oauth_jwk", test_jwk):
        jwks = get_public_jwks()

    assert jwks is not None
    assert "keys" in jwks
    assert len(jwks["keys"]) == 1

    public_key = jwks["keys"][0]
    # should NOT have private key component
    assert "d" not in public_key
    # should have public key components
    assert "x" in public_key
    assert "y" in public_key
    assert public_key["kty"] == "EC"
    assert public_key["alg"] == "ES256"
    assert public_key["use"] == "sig"
    # should preserve kid from original JWK
    assert public_key["kid"] == "test-key-id"


# multi-account tests


async def test_get_or_create_group_id_creates_new(db_session: AsyncSession) -> None:
    """verify group_id is created when session has none."""
    session_id = await create_session(
        "did:plc:group1", "group1.bsky.social", {"access_token": "t1"}
    )

    group_id = await get_or_create_group_id(session_id)
    assert group_id is not None

    # calling again returns same group_id
    assert await get_or_create_group_id(session_id) == group_id


async def test_get_session_group_empty_without_group(
    db_session: AsyncSession,
) -> None:
    """verify get_session_group returns empty list for ungrouped session."""
    session_id = await create_session(
        "did:plc:solo", "solo.bsky.social", {"access_token": "t1"}
    )

    accounts = await get_session_group(session_id)
    assert accounts == []


async def test_get_session_group_returns_linked_accounts(
    db_session: AsyncSession,
) -> None:
    """verify get_session_group returns all accounts in group."""
    session1 = await create_session(
        "did:plc:user1", "user1.bsky.social", {"access_token": "t1"}
    )
    session2 = await create_session(
        "did:plc:user2", "user2.bsky.social", {"access_token": "t2"}
    )

    # link sessions to same group
    group_id = await get_or_create_group_id(session1)
    await _link_to_group(db_session, session2, group_id)

    accounts = await get_session_group(session1)
    assert len(accounts) == 2
    dids = {a.did for a in accounts}
    assert dids == {"did:plc:user1", "did:plc:user2"}


async def test_switch_active_account_validates_group(
    db_session: AsyncSession,
) -> None:
    """verify switch_active_account rejects sessions not in same group."""
    session1 = await create_session(
        "did:plc:s1", "s1.bsky.social", {"access_token": "t1"}
    )
    session2 = await create_session(
        "did:plc:s2", "s2.bsky.social", {"access_token": "t2"}
    )

    await get_or_create_group_id(session1)

    # session2 not in group - should fail
    with pytest.raises(HTTPException) as exc_info:
        await switch_active_account(session1, session2)
    # checked outside the `with` body: statements after the raising call
    # inside it would never execute
    assert isinstance(exc_info.value, HTTPException)
    assert exc_info.value.status_code == 403


async def test_switch_active_account_success(db_session: AsyncSession) -> None:
    """verify switch_active_account works for same-group sessions."""
    session1 = await create_session(
        "did:plc:sw1", "sw1.bsky.social", {"access_token": "t1"}
    )
    session2 = await create_session(
        "did:plc:sw2", "sw2.bsky.social", {"access_token": "t2"}
    )

    group_id = await get_or_create_group_id(session1)
    await _link_to_group(db_session, session2, group_id)

    target = await switch_active_account(session1, session2)
    assert target == session2


async def test_remove_account_from_group_last_account(
    db_session: AsyncSession,
) -> None:
    """verify remove_account_from_group returns None when last account removed."""
    session_id = await create_session(
        "did:plc:last", "last.bsky.social", {"access_token": "t1"}
    )

    result = await remove_account_from_group(session_id)
    assert result is None

    # session should be deleted
    assert await get_session(session_id) is None


async def test_remove_account_from_group_returns_next(
    db_session: AsyncSession,
) -> None:
    """verify remove_account_from_group returns next session when others remain."""
    session1 = await create_session(
        "did:plc:rem1", "rem1.bsky.social", {"access_token": "t1"}
    )
    session2 = await create_session(
        "did:plc:rem2", "rem2.bsky.social", {"access_token": "t2"}
    )

    group_id = await get_or_create_group_id(session1)
    await _link_to_group(db_session, session2, group_id)

    next_session = await remove_account_from_group(session1)
    assert next_session == session2

    # session1 deleted, session2 remains
    assert await get_session(session1) is None
    assert await get_session(session2) is not None


async def test_pending_add_account_crud(db_session: AsyncSession) -> None:
    """verify pending add account save/get/delete cycle."""
    state = "test-oauth-state-123"
    group_id = "test-group-id-456"

    await save_pending_add_account(state, group_id)

    pending = await get_pending_add_account(state)
    assert pending is not None
    assert pending.state == state
    assert pending.group_id == group_id

    await delete_pending_add_account(state)
    assert await get_pending_add_account(state) is None