music on atproto
plyr.fm
1"""test OAuth authentication and session management."""
2
3import json
4from datetime import UTC, datetime, timedelta
5from unittest.mock import patch
6
7import pytest
8from fastapi import HTTPException
9from sqlalchemy import select
10from sqlalchemy.ext.asyncio import AsyncSession
11
12from backend._internal.auth import (
13 _decrypt_data,
14 _encrypt_data,
15 consume_exchange_token,
16 create_exchange_token,
17 create_session,
18 delete_session,
19 get_public_jwks,
20 get_refresh_token_lifetime_days,
21 get_session,
22 is_confidential_client,
23 update_session_tokens,
24)
25from backend.models import ExchangeToken, UserSession
26
27
def test_encryption_roundtrip():
    """verify that decrypting an encrypted string gives back the original text."""
    plaintext = "sensitive oauth data"

    ciphertext = _encrypt_data(plaintext)
    recovered = _decrypt_data(ciphertext)

    assert recovered == plaintext
    # if the ciphertext equals the input, nothing was actually encrypted
    assert ciphertext != plaintext
37
38
def test_encryption_of_json_data():
    """verify a json-serialized oauth payload survives an encrypt/decrypt round trip."""
    payload = {
        "did": "did:plc:test123",
        "handle": "test.bsky.social",
        "access_token": "secret_token_123",
        "refresh_token": "secret_refresh_456",
        "dpop_private_key_pem": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBg==\n-----END PRIVATE KEY-----",
    }

    roundtripped = _decrypt_data(_encrypt_data(json.dumps(payload)))

    assert roundtripped is not None
    # decoding must reproduce the exact original mapping
    assert json.loads(roundtripped) == payload
55
56
async def test_create_session_with_encryption(db_session: AsyncSession):
    """verify session creation encrypts OAuth data at rest.

    First checks that get_session returns the decrypted tokens. Then reads the
    raw ``UserSession`` row and asserts that no mapped column holds any of the
    secret values in plaintext. Without that second step this test would only
    cover the round trip, not the encryption its name promises.
    """
    from sqlalchemy import inspect as sa_inspect

    did = "did:plc:session123"
    handle = "session.bsky.social"
    oauth_data = {
        "did": did,
        "handle": handle,
        "access_token": "secret_token",
        "refresh_token": "secret_refresh",
        "dpop_private_key_pem": "secret_key",
    }

    session_id = await create_session(did, handle, oauth_data)

    # retrieve session and verify it was created correctly
    session = await get_session(session_id)
    assert session is not None
    assert session.did == did
    assert session.handle == handle
    assert session.oauth_session["access_token"] == "secret_token"
    assert session.oauth_session["refresh_token"] == "secret_refresh"

    # inspect the persisted row: no column may contain a secret verbatim
    result = await db_session.execute(
        select(UserSession).where(UserSession.session_id == session_id)
    )
    record = result.scalar_one()
    plaintext_secrets = ("secret_token", "secret_refresh", "secret_key")
    for column_attr in sa_inspect(record).mapper.column_attrs:
        stored = getattr(record, column_attr.key)
        if stored is None:
            continue
        for plaintext in plaintext_secrets:
            assert plaintext not in str(stored), (
                f"column {column_attr.key!r} stores {plaintext!r} unencrypted"
            )
78
79
async def test_get_session_decrypts_data(db_session: AsyncSession):
    """verify get_session hands back the stored oauth tokens in decrypted form."""
    did, handle = "did:plc:decrypt123", "decrypt.bsky.social"
    tokens = {
        "access_token": "secret_token_xyz",
        "refresh_token": "secret_refresh_xyz",
    }

    session_id = await create_session(
        did, handle, {"did": did, "handle": handle, **tokens}
    )

    loaded = await get_session(session_id)

    assert loaded is not None
    assert loaded.did == did
    assert loaded.handle == handle
    # every token must come back exactly as it was supplied
    for token_name, expected in tokens.items():
        assert loaded.oauth_session[token_name] == expected
101
102
async def test_get_session_returns_none_for_invalid_id(db_session: AsyncSession):
    """verify an unknown session id yields None instead of raising."""
    missing_id = "invalid_session_id_that_does_not_exist"
    assert await get_session(missing_id) is None
107
108
async def test_update_session_tokens(db_session: AsyncSession):
    """verify update_session_tokens replaces the stored OAuth tokens.

    After the update, get_session must return the new token values, and the
    session's identity (did and handle) must be unchanged. Encryption at rest
    is not inspected by this test.
    """
    did = "did:plc:update123"
    handle = "update.bsky.social"
    original_oauth_data = {
        "access_token": "original_token",
        "refresh_token": "original_refresh",
    }

    session_id = await create_session(did, handle, original_oauth_data)

    # update with new tokens
    updated_oauth_data = {"access_token": "new_token", "refresh_token": "new_refresh"}
    await update_session_tokens(session_id, updated_oauth_data)

    # verify tokens were updated
    session = await get_session(session_id)
    assert session is not None
    assert session.oauth_session["access_token"] == "new_token"
    assert session.oauth_session["refresh_token"] == "new_refresh"

    # a token refresh must not alter who the session belongs to
    assert session.did == did
    assert session.handle == handle
129
130
async def test_delete_session(db_session: AsyncSession):
    """verify a deleted session can no longer be fetched."""
    sid = await create_session(
        "did:plc:delete123", "delete.bsky.social", {"access_token": "token"}
    )

    # precondition: the freshly created session is retrievable
    assert await get_session(sid) is not None

    await delete_session(sid)

    assert await get_session(sid) is None
149
150
async def test_create_exchange_token(db_session: AsyncSession):
    """verify exchange token creation persists an unexpired, single-use token.

    The stored ``ExchangeToken`` row is checked directly, so a creation bug is
    reported here rather than only showing up in the consume tests. The test
    then confirms the token redeems exactly once for the originating session.
    """
    did = "did:plc:exchange123"
    handle = "exchange.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)

    # create exchange token
    token = await create_exchange_token(session_id)
    assert token

    # the token must be stored and not already expired
    result = await db_session.execute(
        select(ExchangeToken).where(ExchangeToken.token == token)
    )
    stored = result.scalar_one_or_none()
    assert stored is not None
    assert stored.expires_at.replace(tzinfo=UTC) > datetime.now(UTC)

    # the token redeems for the session it was issued for
    result = await consume_exchange_token(token)
    assert result is not None
    returned_session_id, is_dev_token = result
    assert returned_session_id == session_id
    assert is_dev_token is False

    # verify token can't be reused
    second_attempt = await consume_exchange_token(token)
    assert second_attempt is None
172
173
async def test_consume_exchange_token(db_session: AsyncSession):
    """verify consuming an exchange token yields its session and burns the token."""
    sid = await create_session(
        "did:plc:consume123", "consume.bsky.social", {"access_token": "token"}
    )
    exchange = await create_exchange_token(sid)

    outcome = await consume_exchange_token(exchange)
    assert outcome is not None
    redeemed_session, dev_flag = outcome
    assert redeemed_session == sid
    assert dev_flag is False

    # a consumed token is marked used, so redeeming it again must fail
    assert await consume_exchange_token(exchange) is None
193
194
async def test_exchange_token_cannot_be_reused(db_session: AsyncSession):
    """verify an exchange token is redeemable exactly once."""
    sid = await create_session(
        "did:plc:reuse123", "reuse.bsky.social", {"access_token": "token"}
    )
    exchange = await create_exchange_token(sid)

    first = await consume_exchange_token(exchange)
    assert first is not None
    redeemed_session, _ = first
    assert redeemed_session == sid

    # the second redemption of the same token is rejected
    second = await consume_exchange_token(exchange)
    assert second is None
213
214
async def test_exchange_token_returns_none_for_invalid_token(db_session: AsyncSession):
    """verify an unknown exchange token is rejected with None."""
    bogus = "invalid_token_that_does_not_exist"
    assert await consume_exchange_token(bogus) is None
219
220
async def test_exchange_token_expires(db_session: AsyncSession):
    """verify an exchange token whose expiry lies in the past cannot be consumed."""
    # a dedicated database session is used to back-date the token directly
    from backend.utilities.database import db_session as get_db_session

    sid = await create_session(
        "did:plc:expire123", "expire.bsky.social", {"access_token": "token"}
    )
    exchange = await create_exchange_token(sid)

    async with get_db_session() as db:
        lookup = await db.execute(
            select(ExchangeToken).where(ExchangeToken.token == exchange)
        )
        row = lookup.scalar_one_or_none()
        assert row is not None

        # push the expiry one second into the past
        row.expires_at = datetime.now(UTC) - timedelta(seconds=1)
        await db.commit()

    assert await consume_exchange_token(exchange) is None
248
249
async def test_session_isolation(db_session: AsyncSession):
    """verify every test starts against a database with no leftover rows."""
    # rows written by other tests must not leak into this one
    for model in (UserSession, ExchangeToken):
        rows = (await db_session.execute(select(model))).scalars().all()
        assert len(rows) == 0
260
261
async def test_create_session_with_custom_expiration(db_session: AsyncSession):
    """verify a 30-day expiration request is honoured up to the refresh-token lifetime."""
    did = "did:plc:customexp123"
    requested_days = 30

    session_id = await create_session(
        did,
        "customexp.bsky.social",
        {"access_token": "token", "refresh_token": "refresh"},
        expires_in_days=requested_days,
    )

    # the session must be usable right away
    session = await get_session(session_id)
    assert session is not None
    assert session.did == did

    rows = await db_session.execute(
        select(UserSession).where(UserSession.session_id == session_id)
    )
    record = rows.scalar_one_or_none()
    assert record is not None
    assert record.expires_at is not None

    # the stored expiry should land within a minute of the capped target
    capped_days = min(requested_days, get_refresh_token_lifetime_days(None))
    target = datetime.now(UTC) + timedelta(days=capped_days)
    drift = abs((target - record.expires_at.replace(tzinfo=UTC)).total_seconds())
    assert drift < 60
290
291
async def test_create_session_with_no_expiration(db_session: AsyncSession):
    """verify expires_in_days=0 falls back to the refresh-token lifetime."""
    did = "did:plc:noexp123"

    session_id = await create_session(
        did,
        "noexp.bsky.social",
        {"access_token": "token", "refresh_token": "refresh"},
        expires_in_days=0,
    )

    session = await get_session(session_id)
    assert session is not None
    assert session.did == did

    rows = await db_session.execute(
        select(UserSession).where(UserSession.session_id == session_id)
    )
    record = rows.scalar_one_or_none()
    assert record is not None
    # "no expiration" must still produce a concrete, capped expiry
    assert record.expires_at is not None

    lifetime_days = get_refresh_token_lifetime_days(None)
    target = datetime.now(UTC) + timedelta(days=lifetime_days)
    drift = abs((target - record.expires_at.replace(tzinfo=UTC)).total_seconds())
    assert drift < 60
319
320
async def test_create_session_default_expiration(db_session: AsyncSession):
    """verify session creation uses the default 14-day expiration.

    Explicit expirations are capped at the refresh-token lifetime. The default
    is assumed to follow the same rule, so the expected expiry is the smaller
    of 14 days and ``get_refresh_token_lifetime_days(None)``. The hard-coded
    14 days would break whenever the configured lifetime is shorter.
    """
    did = "did:plc:defaultexp123"
    handle = "defaultexp.bsky.social"
    oauth_data = {"access_token": "token", "refresh_token": "refresh"}

    # create session with default expiration
    session_id = await create_session(did, handle, oauth_data)

    result = await db_session.execute(
        select(UserSession).where(UserSession.session_id == session_id)
    )
    db_session_record = result.scalar_one_or_none()
    assert db_session_record is not None
    assert db_session_record.expires_at is not None

    # 14-day default, bounded by the refresh-token lifetime
    expected_days = min(14, get_refresh_token_lifetime_days(None))
    expected_expiry = datetime.now(UTC) + timedelta(days=expected_days)
    actual_expiry = db_session_record.expires_at.replace(tzinfo=UTC)
    diff = abs((expected_expiry - actual_expiry).total_seconds())
    assert diff < 60  # within 1 minute
342
343
344# confidential client tests
345
346
def test_is_confidential_client_false_by_default():
    """verify the app acts as a public client when no OAUTH_JWK is configured."""
    with patch("backend._internal.auth.settings.atproto.oauth_jwk", None):
        confidential = is_confidential_client()

    assert confidential is False
351
352
def test_is_confidential_client_true_when_configured():
    """verify a configured OAUTH_JWK switches the app into confidential-client mode."""
    configured_jwk = '{"kty":"EC","crv":"P-256","x":"test","y":"test","d":"test"}'

    with patch("backend._internal.auth.settings.atproto.oauth_jwk", configured_jwk):
        confidential = is_confidential_client()

    assert confidential is True
359
360
def test_get_public_jwks_returns_none_without_config():
    """verify no JWKS is published when OAUTH_JWK is absent."""
    with patch("backend._internal.auth.settings.atproto.oauth_jwk", None):
        published = get_public_jwks()

    assert published is None
365
366
def test_get_public_jwks_returns_public_key():
    """verify the published JWKS exposes only the public half of the configured key."""
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import ec
    from jose import jwk as jose_jwk

    # build a throwaway P-256 private key and express it as a JWK
    signing_key = ec.generate_private_key(ec.SECP256R1())
    private_pem = signing_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    configured = jose_jwk.construct(private_pem, algorithm="ES256").to_dict()
    configured["kid"] = "test-key-id"  # the kid must survive into the published key

    with patch(
        "backend._internal.auth.settings.atproto.oauth_jwk", json.dumps(configured)
    ):
        jwks = get_public_jwks()

    assert jwks is not None
    assert "keys" in jwks
    assert len(jwks["keys"]) == 1

    published = jwks["keys"][0]

    # the private scalar must never be exposed
    assert "d" not in published
    # both public curve coordinates must be present
    for coordinate in ("x", "y"):
        assert coordinate in published
    assert published["kty"] == "EC"
    assert published["alg"] == "ES256"
    assert published["use"] == "sig"
    assert published["kid"] == "test-key-id"
404
405
406# multi-account tests
407
408
async def test_get_or_create_group_id_creates_new(db_session: AsyncSession):
    """verify an ungrouped session gets a group id, and the same one on later calls."""
    from backend._internal.auth import get_or_create_group_id

    sid = await create_session(
        "did:plc:group1", "group1.bsky.social", {"access_token": "t1"}
    )

    first = await get_or_create_group_id(sid)
    assert first is not None

    # idempotent: a second call must reuse the existing group
    second = await get_or_create_group_id(sid)
    assert second == first
422
423
async def test_get_session_group_empty_without_group(db_session: AsyncSession):
    """verify a session that was never grouped reports no linked accounts."""
    from backend._internal.auth import get_session_group

    sid = await create_session(
        "did:plc:solo", "solo.bsky.social", {"access_token": "t1"}
    )

    assert await get_session_group(sid) == []
434
435
async def test_get_session_group_returns_linked_accounts(db_session: AsyncSession):
    """verify all sessions sharing a group_id are reported as one account group."""
    from backend._internal.auth import get_or_create_group_id, get_session_group

    primary = await create_session(
        "did:plc:user1", "user1.bsky.social", {"access_token": "t1"}
    )
    secondary = await create_session(
        "did:plc:user2", "user2.bsky.social", {"access_token": "t2"}
    )

    # put the second session into the first session's group by editing its row
    shared_group = await get_or_create_group_id(primary)
    secondary_row = (
        await db_session.execute(
            select(UserSession).where(UserSession.session_id == secondary)
        )
    ).scalar_one()
    secondary_row.group_id = shared_group
    await db_session.commit()

    members = await get_session_group(primary)
    assert len(members) == 2
    assert {member.did for member in members} == {"did:plc:user1", "did:plc:user2"}
461
462
async def test_switch_active_account_validates_group(db_session: AsyncSession):
    """verify switch_active_account rejects sessions not in same group.

    ``session1`` is given a group and ``session2`` is left ungrouped, so
    switching between them must be refused with HTTP 403.
    """
    from backend._internal.auth import get_or_create_group_id, switch_active_account

    session1 = await create_session(
        "did:plc:s1", "s1.bsky.social", {"access_token": "t1"}
    )
    session2 = await create_session(
        "did:plc:s2", "s2.bsky.social", {"access_token": "t2"}
    )

    await get_or_create_group_id(session1)

    # session2 is not in the group, so the switch must fail. pytest.raises
    # already guarantees the exception type; only the status code needs checking.
    with pytest.raises(HTTPException) as exc_info:
        await switch_active_account(session1, session2)
    assert exc_info.value.status_code == 403
481
482
async def test_switch_active_account_success(db_session: AsyncSession):
    """verify switching to another session in the same group returns that session's id."""
    from backend._internal.auth import get_or_create_group_id, switch_active_account

    current = await create_session(
        "did:plc:sw1", "sw1.bsky.social", {"access_token": "t1"}
    )
    other = await create_session(
        "did:plc:sw2", "sw2.bsky.social", {"access_token": "t2"}
    )

    # join `other` to the group owned by `current`
    group = await get_or_create_group_id(current)
    other_row = (
        await db_session.execute(
            select(UserSession).where(UserSession.session_id == other)
        )
    ).scalar_one()
    other_row.group_id = group
    await db_session.commit()

    assert await switch_active_account(current, other) == other
505
506
async def test_remove_account_from_group_last_account(db_session: AsyncSession):
    """verify removing the only account yields None and deletes that session."""
    from backend._internal.auth import remove_account_from_group

    only_session = await create_session(
        "did:plc:last", "last.bsky.social", {"access_token": "t1"}
    )

    successor = await remove_account_from_group(only_session)
    assert successor is None

    # the removed session must also be gone from storage
    assert await get_session(only_session) is None
520
521
async def test_remove_account_from_group_returns_next(db_session: AsyncSession):
    """verify removing one grouped account hands back a remaining account's session."""
    from backend._internal.auth import get_or_create_group_id, remove_account_from_group

    leaving = await create_session(
        "did:plc:rem1", "rem1.bsky.social", {"access_token": "t1"}
    )
    staying = await create_session(
        "did:plc:rem2", "rem2.bsky.social", {"access_token": "t2"}
    )

    # link both sessions into one group
    group = await get_or_create_group_id(leaving)
    staying_row = (
        await db_session.execute(
            select(UserSession).where(UserSession.session_id == staying)
        )
    ).scalar_one()
    staying_row.group_id = group
    await db_session.commit()

    assert await remove_account_from_group(leaving) == staying

    # only the removed account's session is deleted
    assert await get_session(leaving) is None
    assert await get_session(staying) is not None
548
549
async def test_pending_add_account_crud(db_session: AsyncSession):
    """verify a pending add-account record can be saved, read back, and deleted."""
    from backend._internal.auth import (
        delete_pending_add_account,
        get_pending_add_account,
        save_pending_add_account,
    )

    oauth_state, target_group = "test-oauth-state-123", "test-group-id-456"

    await save_pending_add_account(oauth_state, target_group)

    stored = await get_pending_add_account(oauth_state)
    assert stored is not None
    assert stored.state == oauth_state
    assert stored.group_id == target_group

    # once deleted, the state no longer resolves
    await delete_pending_add_account(oauth_state)
    assert await get_pending_add_account(oauth_state) is None