"""test OAuth authentication and session management."""

import json
from datetime import UTC, datetime, timedelta
from unittest.mock import patch

from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from backend._internal.auth import (
    _decrypt_data,
    _encrypt_data,
    consume_exchange_token,
    create_exchange_token,
    create_session,
    delete_session,
    get_public_jwks,
    get_session,
    is_confidential_client,
    update_session_tokens,
)
from backend.models import ExchangeToken, UserSession


# helpers (underscore-prefixed so pytest does not collect them)


async def _fetch_user_session_record(
    db_session: AsyncSession, session_id: str
) -> UserSession | None:
    """load the raw UserSession row for session_id, or None if it does not exist."""
    result = await db_session.execute(
        select(UserSession).where(UserSession.session_id == session_id)
    )
    return result.scalar_one_or_none()


def _assert_expires_in(expires_at: datetime, expected: timedelta) -> None:
    """assert expires_at is within one minute of now + expected.

    the stored value is stamped with UTC before comparing so it can be compared
    with an aware datetime. NOTE(review): assumes the column holds UTC values.
    """
    expected_expiry = datetime.now(UTC) + expected
    actual_expiry = expires_at.replace(tzinfo=UTC)
    diff = abs((expected_expiry - actual_expiry).total_seconds())
    assert diff < 60  # within 1 minute


# encryption tests


def test_encryption_roundtrip():
    """verify encryption and decryption work correctly."""
    original_data = "sensitive oauth data"

    encrypted = _encrypt_data(original_data)
    decrypted = _decrypt_data(encrypted)

    assert decrypted == original_data
    assert encrypted != original_data  # ensure it's actually encrypted


def test_encryption_of_json_data():
    """verify encryption works with json-serialized data."""
    oauth_data = {
        "did": "did:plc:test123",
        "handle": "test.bsky.social",
        "access_token": "secret_token_123",
        "refresh_token": "secret_refresh_456",
        "dpop_private_key_pem": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBg==\n-----END PRIVATE KEY-----",
    }

    json_str = json.dumps(oauth_data)
    encrypted = _encrypt_data(json_str)
    decrypted = _decrypt_data(encrypted)

    assert decrypted is not None
    assert json.loads(decrypted) == oauth_data


# session tests


async def test_create_session_with_encryption(db_session: AsyncSession):
    """verify a newly created session's OAuth data reads back intact.

    NOTE(review): this checks the create/get round trip only; it does not
    inspect the stored row to confirm the OAuth data is encrypted at rest.
    """
    did = "did:plc:session123"
    handle = "session.bsky.social"
    oauth_data = {
        "did": did,
        "handle": handle,
        "access_token": "secret_token",
        "refresh_token": "secret_refresh",
        "dpop_private_key_pem": "secret_key",
    }

    session_id = await create_session(did, handle, oauth_data)

    # retrieve session and verify it was created correctly
    session = await get_session(session_id)
    assert session is not None
    assert session.did == did
    assert session.handle == handle
    assert session.oauth_session["access_token"] == "secret_token"
    assert session.oauth_session["refresh_token"] == "secret_refresh"


async def test_get_session_decrypts_data(db_session: AsyncSession):
    """verify get_session correctly decrypts OAuth data."""
    did = "did:plc:decrypt123"
    handle = "decrypt.bsky.social"
    oauth_data = {
        "did": did,
        "handle": handle,
        "access_token": "secret_token_xyz",
        "refresh_token": "secret_refresh_xyz",
    }

    session_id = await create_session(did, handle, oauth_data)

    # retrieve and verify decryption
    session = await get_session(session_id)

    assert session is not None
    assert session.did == did
    assert session.handle == handle
    assert session.oauth_session["access_token"] == "secret_token_xyz"
    assert session.oauth_session["refresh_token"] == "secret_refresh_xyz"


async def test_get_session_returns_none_for_invalid_id(db_session: AsyncSession):
    """verify get_session returns None for non-existent session."""
    session = await get_session("invalid_session_id_that_does_not_exist")
    assert session is None


async def test_update_session_tokens(db_session: AsyncSession):
    """verify update_session_tokens replaces the session's stored OAuth tokens."""
    did = "did:plc:update123"
    handle = "update.bsky.social"
    original_oauth_data = {
        "access_token": "original_token",
        "refresh_token": "original_refresh",
    }

    session_id = await create_session(did, handle, original_oauth_data)

    # update with new tokens
    updated_oauth_data = {"access_token": "new_token", "refresh_token": "new_refresh"}
    await update_session_tokens(session_id, updated_oauth_data)

    # verify tokens were updated
    session = await get_session(session_id)
    assert session is not None
    assert session.oauth_session["access_token"] == "new_token"
    assert session.oauth_session["refresh_token"] == "new_refresh"


async def test_delete_session(db_session: AsyncSession):
    """verify session deletion works."""
    did = "did:plc:delete123"
    handle = "delete.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)

    # verify session exists
    session = await get_session(session_id)
    assert session is not None

    # delete session
    await delete_session(session_id)

    # verify session is gone
    session = await get_session(session_id)
    assert session is None


# exchange token tests


async def test_create_exchange_token(db_session: AsyncSession):
    """verify exchange token creation persists an unexpired, single-use token."""
    did = "did:plc:exchange123"
    handle = "exchange.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)

    # create exchange token
    token = await create_exchange_token(session_id)

    # the token row itself should be persisted with an expiration in the future
    result = await db_session.execute(
        select(ExchangeToken).where(ExchangeToken.token == token)
    )
    exchange_token = result.scalar_one_or_none()
    assert exchange_token is not None
    assert exchange_token.expires_at is not None
    assert exchange_token.expires_at.replace(tzinfo=UTC) > datetime.now(UTC)

    # verify token can be consumed (proves it is linked to the session)
    consumed = await consume_exchange_token(token)
    assert consumed is not None
    returned_session_id, is_dev_token = consumed
    assert returned_session_id == session_id
    assert is_dev_token is False

    # verify token can't be reused
    second_attempt = await consume_exchange_token(token)
    assert second_attempt is None


async def test_consume_exchange_token(db_session: AsyncSession):
    """verify exchange token consumption works."""
    did = "did:plc:consume123"
    handle = "consume.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)
    token = await create_exchange_token(session_id)

    # consume token
    result = await consume_exchange_token(token)
    assert result is not None
    returned_session_id, is_dev_token = result
    assert returned_session_id == session_id
    assert is_dev_token is False

    # verify token can't be consumed again (proves it was marked as used)
    second_attempt = await consume_exchange_token(token)
    assert second_attempt is None


async def test_exchange_token_cannot_be_reused(db_session: AsyncSession):
    """verify exchange token can only be used once."""
    did = "did:plc:reuse123"
    handle = "reuse.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)
    token = await create_exchange_token(session_id)

    # consume token first time
    result = await consume_exchange_token(token)
    assert result is not None
    returned_session_id, _is_dev_token = result
    assert returned_session_id == session_id

    # try to consume again - should return None
    second_result = await consume_exchange_token(token)
    assert second_result is None


async def test_exchange_token_returns_none_for_invalid_token(db_session: AsyncSession):
    """verify consume_exchange_token returns None for invalid token."""
    result = await consume_exchange_token("invalid_token_that_does_not_exist")
    assert result is None


async def test_exchange_token_expires(db_session: AsyncSession):
    """verify expired exchange token returns None."""
    # use a separate database session to manually expire the token
    from backend.utilities.database import db_session as get_db_session

    did = "did:plc:expire123"
    handle = "expire.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)
    token = await create_exchange_token(session_id)

    # manually expire the token by updating its expiration
    async with get_db_session() as db:
        result = await db.execute(
            select(ExchangeToken).where(ExchangeToken.token == token)
        )
        exchange_token = result.scalar_one_or_none()
        assert exchange_token is not None

        # set expiration to past
        exchange_token.expires_at = datetime.now(UTC) - timedelta(seconds=1)
        await db.commit()

    # try to consume expired token - should return None
    consumed = await consume_exchange_token(token)
    assert consumed is None


async def test_session_isolation(db_session: AsyncSession):
    """verify each test starts with clean database."""
    # this should not see sessions from other tests
    result = await db_session.execute(select(UserSession))
    sessions = result.scalars().all()
    assert len(sessions) == 0

    result = await db_session.execute(select(ExchangeToken))
    tokens = result.scalars().all()
    assert len(tokens) == 0


# session expiration tests


async def test_create_session_with_custom_expiration(db_session: AsyncSession):
    """verify session creation with custom expiration works."""
    did = "did:plc:customexp123"
    handle = "customexp.bsky.social"
    oauth_data = {"access_token": "token", "refresh_token": "refresh"}

    # create session with 30-day expiration
    session_id = await create_session(did, handle, oauth_data, expires_in_days=30)

    # verify session exists and works
    session = await get_session(session_id)
    assert session is not None
    assert session.did == did

    # verify expiration is set (within reasonable range)
    db_session_record = await _fetch_user_session_record(db_session, session_id)
    assert db_session_record is not None
    assert db_session_record.expires_at is not None

    # should expire roughly 30 days from now
    _assert_expires_in(db_session_record.expires_at, timedelta(days=30))


async def test_create_session_with_no_expiration(db_session: AsyncSession):
    """verify session creation with expires_in_days=0 creates non-expiring session."""
    did = "did:plc:noexp123"
    handle = "noexp.bsky.social"
    oauth_data = {"access_token": "token", "refresh_token": "refresh"}

    # create session with no expiration
    session_id = await create_session(did, handle, oauth_data, expires_in_days=0)

    # verify session exists
    session = await get_session(session_id)
    assert session is not None
    assert session.did == did

    # verify expires_at is None
    db_session_record = await _fetch_user_session_record(db_session, session_id)
    assert db_session_record is not None
    assert db_session_record.expires_at is None


async def test_create_session_default_expiration(db_session: AsyncSession):
    """verify session creation uses default 14-day expiration."""
    did = "did:plc:defaultexp123"
    handle = "defaultexp.bsky.social"
    oauth_data = {"access_token": "token", "refresh_token": "refresh"}

    # create session with default expiration
    session_id = await create_session(did, handle, oauth_data)

    # verify expiration is set to default (14 days)
    db_session_record = await _fetch_user_session_record(db_session, session_id)
    assert db_session_record is not None
    assert db_session_record.expires_at is not None

    _assert_expires_in(db_session_record.expires_at, timedelta(days=14))


# confidential client tests


def test_is_confidential_client_false_by_default():
    """verify is_confidential_client returns False when OAUTH_JWK not set."""
    with patch("backend._internal.auth.settings.atproto.oauth_jwk", None):
        assert is_confidential_client() is False


def test_is_confidential_client_true_when_configured():
    """verify is_confidential_client returns True when OAUTH_JWK is set."""
    test_jwk = '{"kty":"EC","crv":"P-256","x":"test","y":"test","d":"test"}'

    with patch("backend._internal.auth.settings.atproto.oauth_jwk", test_jwk):
        assert is_confidential_client() is True


def test_get_public_jwks_returns_none_without_config():
    """verify get_public_jwks returns None when OAUTH_JWK not configured."""
    with patch("backend._internal.auth.settings.atproto.oauth_jwk", None):
        assert get_public_jwks() is None


def test_get_public_jwks_returns_public_key():
    """verify get_public_jwks returns JWKS with public key only."""
    # generate a test JWK
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import ec
    from jose import jwk as jose_jwk

    # generate test key
    private_key = ec.generate_private_key(ec.SECP256R1())
    pem_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    key_obj = jose_jwk.construct(pem_bytes, algorithm="ES256")
    jwk_dict = key_obj.to_dict()
    jwk_dict["kid"] = "test-key-id"  # add kid to test preservation
    test_jwk = json.dumps(jwk_dict)

    with patch("backend._internal.auth.settings.atproto.oauth_jwk", test_jwk):
        jwks = get_public_jwks()

    assert jwks is not None
    assert "keys" in jwks
    assert len(jwks["keys"]) == 1

    public_key = jwks["keys"][0]
    # should NOT have private key component
    assert "d" not in public_key
    # should have public key components
    assert "x" in public_key
    assert "y" in public_key
    assert public_key["kty"] == "EC"
    assert public_key["alg"] == "ES256"
    assert public_key["use"] == "sig"
    # should preserve kid from original JWK
    assert public_key["kid"] == "test-key-id"