audio streaming app plyr.fm
38
fork

Configure Feed

Select the types of activity you want to include in your feed.

at 2025.1230.154826 393 lines 14 kB view raw
"""test OAuth authentication and session management."""

import json
from datetime import UTC, datetime, timedelta
from unittest.mock import patch

from sqlalchemy import inspect as sa_inspect
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from backend._internal.auth import (
    _decrypt_data,
    _encrypt_data,
    consume_exchange_token,
    create_exchange_token,
    create_session,
    delete_session,
    get_public_jwks,
    get_session,
    is_confidential_client,
    update_session_tokens,
)
from backend.models import ExchangeToken, UserSession


def _as_utc(value: datetime) -> datetime:
    """normalize a datetime read back from the database to an aware UTC value.

    naive values are treated as already being UTC (the convention these tests
    have always assumed); aware values are converted with ``astimezone`` rather
    than relabeled, so a non-UTC offset is never silently discarded.
    """
    if value.tzinfo is None:
        return value.replace(tzinfo=UTC)
    return value.astimezone(UTC)


def _assert_expires_in_days(expires_at: datetime | None, days: int) -> None:
    """assert ``expires_at`` is set and roughly ``days`` days from now."""
    assert expires_at is not None
    expected_expiry = datetime.now(UTC) + timedelta(days=days)
    diff = abs((expected_expiry - _as_utc(expires_at)).total_seconds())
    assert diff < 60  # within 1 minute


def test_encryption_roundtrip():
    """verify encryption and decryption work correctly."""
    original_data = "sensitive oauth data"

    encrypted = _encrypt_data(original_data)
    decrypted = _decrypt_data(encrypted)

    assert decrypted == original_data
    assert encrypted != original_data  # ensure it's actually encrypted


def test_encryption_of_json_data():
    """verify encryption works with json-serialized data."""
    oauth_data = {
        "did": "did:plc:test123",
        "handle": "test.bsky.social",
        "access_token": "secret_token_123",
        "refresh_token": "secret_refresh_456",
        "dpop_private_key_pem": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBg==\n-----END PRIVATE KEY-----",
    }

    json_str = json.dumps(oauth_data)
    encrypted = _encrypt_data(json_str)
    decrypted = _decrypt_data(encrypted)

    assert decrypted is not None
    assert json.loads(decrypted) == oauth_data


async def test_create_session_with_encryption(db_session: AsyncSession):
    """verify session creation encrypts OAuth data."""
    did = "did:plc:session123"
    handle = "session.bsky.social"
    oauth_data = {
        "did": did,
        "handle": handle,
        "access_token": "secret_token",
        "refresh_token": "secret_refresh",
        "dpop_private_key_pem": "secret_key",
    }

    session_id = await create_session(did, handle, oauth_data)

    # retrieve session and verify it was created correctly
    session = await get_session(session_id)
    assert session is not None
    assert session.did == did
    assert session.handle == handle
    assert session.oauth_session["access_token"] == "secret_token"
    assert session.oauth_session["refresh_token"] == "secret_refresh"

    # the round trip above passes even if nothing is encrypted, so inspect the
    # persisted row directly: no loaded attribute may hold a plaintext secret.
    result = await db_session.execute(
        select(UserSession).where(UserSession.session_id == session_id)
    )
    db_session_record = result.scalar_one_or_none()
    assert db_session_record is not None

    # loaded_value reads only what is already loaded (never triggers a lazy
    # load, which would fail under the async session)
    stored_values = [
        str(attr.loaded_value) for attr in sa_inspect(db_session_record).attrs
    ]
    for secret in ("secret_token", "secret_refresh"):
        assert not any(secret in value for value in stored_values), secret


async def test_get_session_decrypts_data(db_session: AsyncSession):
    """verify get_session correctly decrypts OAuth data."""
    did = "did:plc:decrypt123"
    handle = "decrypt.bsky.social"
    oauth_data = {
        "did": did,
        "handle": handle,
        "access_token": "secret_token_xyz",
        "refresh_token": "secret_refresh_xyz",
    }

    session_id = await create_session(did, handle, oauth_data)

    # retrieve and verify decryption
    session = await get_session(session_id)

    assert session is not None
    assert session.did == did
    assert session.handle == handle
    assert session.oauth_session["access_token"] == "secret_token_xyz"
    assert session.oauth_session["refresh_token"] == "secret_refresh_xyz"


async def test_get_session_returns_none_for_invalid_id(db_session: AsyncSession):
    """verify get_session returns None for non-existent session."""
    session = await get_session("invalid_session_id_that_does_not_exist")
    assert session is None


async def test_update_session_tokens(db_session: AsyncSession):
    """verify session token update encrypts new data."""
    did = "did:plc:update123"
    handle = "update.bsky.social"
    original_oauth_data = {
        "access_token": "original_token",
        "refresh_token": "original_refresh",
    }

    session_id = await create_session(did, handle, original_oauth_data)

    # update with new tokens
    updated_oauth_data = {"access_token": "new_token", "refresh_token": "new_refresh"}
    await update_session_tokens(session_id, updated_oauth_data)

    # verify tokens were updated
    session = await get_session(session_id)
    assert session is not None
    assert session.oauth_session["access_token"] == "new_token"
    assert session.oauth_session["refresh_token"] == "new_refresh"


async def test_delete_session(db_session: AsyncSession):
    """verify session deletion works."""
    did = "did:plc:delete123"
    handle = "delete.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)

    # verify session exists
    session = await get_session(session_id)
    assert session is not None

    # delete session
    await delete_session(session_id)

    # verify session is gone
    session = await get_session(session_id)
    assert session is None


async def test_create_exchange_token(db_session: AsyncSession):
    """verify exchange token creation."""
    did = "did:plc:exchange123"
    handle = "exchange.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)

    # create exchange token
    token = await create_exchange_token(session_id)

    # verify token can be consumed (proves it was created correctly)
    result = await consume_exchange_token(token)
    assert result is not None
    returned_session_id, is_dev_token = result
    assert returned_session_id == session_id
    assert is_dev_token is False

    # verify token can't be reused
    second_attempt = await consume_exchange_token(token)
    assert second_attempt is None


async def test_consume_exchange_token(db_session: AsyncSession):
    """verify exchange token consumption works."""
    did = "did:plc:consume123"
    handle = "consume.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)
    token = await create_exchange_token(session_id)

    # consume token
    result = await consume_exchange_token(token)
    assert result is not None
    returned_session_id, is_dev_token = result
    assert returned_session_id == session_id
    assert is_dev_token is False

    # verify token can't be consumed again (proves it was marked as used)
    second_attempt = await consume_exchange_token(token)
    assert second_attempt is None


async def test_exchange_token_cannot_be_reused(db_session: AsyncSession):
    """verify exchange token can only be used once."""
    did = "did:plc:reuse123"
    handle = "reuse.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)
    token = await create_exchange_token(session_id)

    # consume token first time
    result = await consume_exchange_token(token)
    assert result is not None
    returned_session_id, _is_dev_token = result
    assert returned_session_id == session_id

    # try to consume again - should return None
    second_result = await consume_exchange_token(token)
    assert second_result is None


async def test_exchange_token_returns_none_for_invalid_token(db_session: AsyncSession):
    """verify consume_exchange_token returns None for invalid token."""
    result = await consume_exchange_token("invalid_token_that_does_not_exist")
    assert result is None


async def test_exchange_token_expires(db_session: AsyncSession):
    """verify expired exchange token returns None."""
    # use a separate database session to manually expire the token
    from backend.utilities.database import db_session as get_db_session

    did = "did:plc:expire123"
    handle = "expire.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)
    token = await create_exchange_token(session_id)

    # manually expire the token by updating its expiration
    async with get_db_session() as db:
        result = await db.execute(
            select(ExchangeToken).where(ExchangeToken.token == token)
        )
        exchange_token = result.scalar_one_or_none()
        assert exchange_token is not None

        # set expiration to past
        exchange_token.expires_at = datetime.now(UTC) - timedelta(seconds=1)
        await db.commit()

    # try to consume expired token - should return None
    consumed = await consume_exchange_token(token)
    assert consumed is None


async def test_session_isolation(db_session: AsyncSession):
    """verify each test starts with clean database."""
    # this should not see sessions from other tests
    result = await db_session.execute(select(UserSession))
    sessions = result.scalars().all()
    assert len(sessions) == 0

    result = await db_session.execute(select(ExchangeToken))
    tokens = result.scalars().all()
    assert len(tokens) == 0


async def test_create_session_with_custom_expiration(db_session: AsyncSession):
    """verify session creation with custom expiration works."""
    did = "did:plc:customexp123"
    handle = "customexp.bsky.social"
    oauth_data = {"access_token": "token", "refresh_token": "refresh"}

    # create session with 30-day expiration
    session_id = await create_session(did, handle, oauth_data, expires_in_days=30)

    # verify session exists and works
    session = await get_session(session_id)
    assert session is not None
    assert session.did == did

    # verify expiration is set (within reasonable range)
    result = await db_session.execute(
        select(UserSession).where(UserSession.session_id == session_id)
    )
    db_session_record = result.scalar_one_or_none()
    assert db_session_record is not None

    # should expire roughly 30 days from now
    _assert_expires_in_days(db_session_record.expires_at, 30)


async def test_create_session_with_no_expiration(db_session: AsyncSession):
    """verify session creation with expires_in_days=0 creates non-expiring session."""
    did = "did:plc:noexp123"
    handle = "noexp.bsky.social"
    oauth_data = {"access_token": "token", "refresh_token": "refresh"}

    # create session with no expiration
    session_id = await create_session(did, handle, oauth_data, expires_in_days=0)

    # verify session exists
    session = await get_session(session_id)
    assert session is not None
    assert session.did == did

    # verify expires_at is None
    result = await db_session.execute(
        select(UserSession).where(UserSession.session_id == session_id)
    )
    db_session_record = result.scalar_one_or_none()
    assert db_session_record is not None
    assert db_session_record.expires_at is None


async def test_create_session_default_expiration(db_session: AsyncSession):
    """verify session creation uses default 14-day expiration."""
    did = "did:plc:defaultexp123"
    handle = "defaultexp.bsky.social"
    oauth_data = {"access_token": "token", "refresh_token": "refresh"}

    # create session with default expiration
    session_id = await create_session(did, handle, oauth_data)

    # verify expiration is set to default (14 days)
    result = await db_session.execute(
        select(UserSession).where(UserSession.session_id == session_id)
    )
    db_session_record = result.scalar_one_or_none()
    assert db_session_record is not None

    _assert_expires_in_days(db_session_record.expires_at, 14)


# confidential client tests


def test_is_confidential_client_false_by_default():
    """verify is_confidential_client returns False when OAUTH_JWK not set."""
    with patch("backend._internal.auth.settings.atproto.oauth_jwk", None):
        assert is_confidential_client() is False


def test_is_confidential_client_true_when_configured():
    """verify is_confidential_client returns True when OAUTH_JWK is set."""
    test_jwk = '{"kty":"EC","crv":"P-256","x":"test","y":"test","d":"test"}'

    with patch("backend._internal.auth.settings.atproto.oauth_jwk", test_jwk):
        assert is_confidential_client() is True


def test_get_public_jwks_returns_none_without_config():
    """verify get_public_jwks returns None when OAUTH_JWK not configured."""
    with patch("backend._internal.auth.settings.atproto.oauth_jwk", None):
        assert get_public_jwks() is None


def test_get_public_jwks_returns_public_key():
    """verify get_public_jwks returns JWKS with public key only."""
    # generate a test JWK
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import ec
    from jose import jwk as jose_jwk

    # generate test key
    private_key = ec.generate_private_key(ec.SECP256R1())
    pem_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    key_obj = jose_jwk.construct(pem_bytes, algorithm="ES256")
    jwk_dict = key_obj.to_dict()
    jwk_dict["kid"] = "test-key-id"  # add kid to test preservation
    test_jwk = json.dumps(jwk_dict)

    with patch("backend._internal.auth.settings.atproto.oauth_jwk", test_jwk):
        jwks = get_public_jwks()

    assert jwks is not None
    assert "keys" in jwks
    assert len(jwks["keys"]) == 1

    public_key = jwks["keys"][0]
    # should NOT have private key component
    assert "d" not in public_key
    # should have public key components
    assert "x" in public_key
    assert "y" in public_key
    assert public_key["kty"] == "EC"
    assert public_key["alg"] == "ES256"
    assert public_key["use"] == "sig"
    # should preserve kid from original JWK
    assert public_key["kid"] == "test-key-id"