music on atproto
plyr.fm
1"""test OAuth authentication and session management."""
2
3import json
4from datetime import UTC, datetime, timedelta
5from unittest.mock import patch
6
7from sqlalchemy import select
8from sqlalchemy.ext.asyncio import AsyncSession
9
10from backend._internal.auth import (
11 _decrypt_data,
12 _encrypt_data,
13 consume_exchange_token,
14 create_exchange_token,
15 create_session,
16 delete_session,
17 get_public_jwks,
18 get_session,
19 is_confidential_client,
20 update_session_tokens,
21)
22from backend.models import ExchangeToken, UserSession
23
24
def test_encryption_roundtrip():
    """encrypting then decrypting a string yields the original plaintext."""
    plaintext = "sensitive oauth data"

    ciphertext = _encrypt_data(plaintext)
    recovered = _decrypt_data(ciphertext)

    assert recovered == plaintext
    # the ciphertext must differ from the input, otherwise nothing was encrypted
    assert ciphertext != plaintext
34
35
def test_encryption_of_json_data():
    """a json-serialized oauth payload survives an encrypt/decrypt cycle intact."""
    payload = {
        "did": "did:plc:test123",
        "handle": "test.bsky.social",
        "access_token": "secret_token_123",
        "refresh_token": "secret_refresh_456",
        "dpop_private_key_pem": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBg==\n-----END PRIVATE KEY-----",
    }

    roundtripped = _decrypt_data(_encrypt_data(json.dumps(payload)))

    assert roundtripped is not None
    assert json.loads(roundtripped) == payload
52
53
async def test_create_session_with_encryption(db_session: AsyncSession):
    """verify session creation encrypts OAuth data at rest.

    the persisted UserSession row must not contain any plaintext secret from
    the oauth payload, while get_session must still hand back the original
    token values.
    """
    did = "did:plc:session123"
    handle = "session.bsky.social"
    oauth_data = {
        "did": did,
        "handle": handle,
        "access_token": "secret_token",
        "refresh_token": "secret_refresh",
        "dpop_private_key_pem": "secret_key",
    }
    plaintext_secrets = ("secret_token", "secret_refresh", "secret_key")

    session_id = await create_session(did, handle, oauth_data)

    # read every column of the stored row (Core select on the table) and make
    # sure none of them exposes a secret; str() also covers dict/JSON columns
    result = await db_session.execute(
        select(UserSession.__table__).where(UserSession.session_id == session_id)
    )
    stored_row = result.one()
    for value in stored_row:
        if value is None:
            continue
        stored_text = str(value)
        for secret in plaintext_secrets:
            assert secret not in stored_text

    # retrieve session and verify it round-trips to the original tokens
    session = await get_session(session_id)
    assert session is not None
    assert session.did == did
    assert session.handle == handle
    assert session.oauth_session["access_token"] == "secret_token"
    assert session.oauth_session["refresh_token"] == "secret_refresh"
75
76
async def test_get_session_decrypts_data(db_session: AsyncSession):
    """get_session returns the plaintext tokens that create_session stored."""
    did, handle = "did:plc:decrypt123", "decrypt.bsky.social"
    oauth_data = {
        "did": did,
        "handle": handle,
        "access_token": "secret_token_xyz",
        "refresh_token": "secret_refresh_xyz",
    }

    new_id = await create_session(did, handle, oauth_data)
    fetched = await get_session(new_id)

    assert fetched is not None
    assert fetched.did == did
    assert fetched.handle == handle
    tokens = fetched.oauth_session
    assert tokens["access_token"] == "secret_token_xyz"
    assert tokens["refresh_token"] == "secret_refresh_xyz"
98
99
async def test_get_session_returns_none_for_invalid_id(db_session: AsyncSession):
    """looking up an unknown session id yields None."""
    missing_id = "invalid_session_id_that_does_not_exist"
    assert await get_session(missing_id) is None
104
105
async def test_update_session_tokens(db_session: AsyncSession):
    """update_session_tokens replaces the stored tokens; get_session sees the new ones."""
    session_id = await create_session(
        "did:plc:update123",
        "update.bsky.social",
        {"access_token": "original_token", "refresh_token": "original_refresh"},
    )

    replacement = {"access_token": "new_token", "refresh_token": "new_refresh"}
    await update_session_tokens(session_id, replacement)

    refreshed = await get_session(session_id)
    assert refreshed is not None
    assert refreshed.oauth_session["access_token"] == "new_token"
    assert refreshed.oauth_session["refresh_token"] == "new_refresh"
126
127
async def test_delete_session(db_session: AsyncSession):
    """a deleted session can no longer be retrieved."""
    session_id = await create_session(
        "did:plc:delete123", "delete.bsky.social", {"access_token": "token"}
    )

    # precondition: the session is retrievable before deletion
    assert await get_session(session_id) is not None

    await delete_session(session_id)

    assert await get_session(session_id) is None
146
147
async def test_create_exchange_token(db_session: AsyncSession):
    """verify exchange token creation persists a redeemable, unexpired token.

    checks the created artifact directly (non-empty token string, a stored
    ExchangeToken row, an expiry in the future), then confirms the token
    redeems to the originating session exactly once.
    """
    did = "did:plc:exchange123"
    handle = "exchange.bsky.social"
    oauth_data = {"access_token": "token"}

    session_id = await create_session(did, handle, oauth_data)

    token = await create_exchange_token(session_id)

    # the token handed to the client is a non-empty string
    assert isinstance(token, str)
    assert token

    # the token must be persisted so a later request can redeem it
    lookup = await db_session.execute(
        select(ExchangeToken).where(ExchangeToken.token == token)
    )
    stored = lookup.scalar_one_or_none()
    assert stored is not None
    assert stored.expires_at is not None
    # NOTE(review): the DB may return a naive timestamp (the session expiry
    # tests treat those as UTC); normalize before comparing with an aware now
    expires_at = stored.expires_at
    if expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=UTC)
    assert expires_at > datetime.now(UTC)

    # redeeming proves the token is bound to the right session
    result = await consume_exchange_token(token)
    assert result is not None
    returned_session_id, is_dev_token = result
    assert returned_session_id == session_id
    assert is_dev_token is False

    # verify token can't be reused
    second_attempt = await consume_exchange_token(token)
    assert second_attempt is None
169
170
async def test_consume_exchange_token(db_session: AsyncSession):
    """redeeming a fresh exchange token returns its session id and a non-dev flag."""
    session_id = await create_session(
        "did:plc:consume123", "consume.bsky.social", {"access_token": "token"}
    )
    token = await create_exchange_token(session_id)

    first = await consume_exchange_token(token)
    assert first is not None
    consumed_session_id, is_dev_token = first
    assert consumed_session_id == session_id
    assert is_dev_token is False

    # a second redemption of the same token must fail
    assert await consume_exchange_token(token) is None
190
191
async def test_exchange_token_cannot_be_reused(db_session: AsyncSession):
    """an exchange token is single-use: the second redemption returns None."""
    session_id = await create_session(
        "did:plc:reuse123", "reuse.bsky.social", {"access_token": "token"}
    )
    token = await create_exchange_token(session_id)

    initial = await consume_exchange_token(token)
    assert initial is not None
    redeemed_for, _ = initial
    assert redeemed_for == session_id

    repeat = await consume_exchange_token(token)
    assert repeat is None
210
211
async def test_exchange_token_returns_none_for_invalid_token(db_session: AsyncSession):
    """redeeming a token that was never issued yields None."""
    unknown_token = "invalid_token_that_does_not_exist"
    assert await consume_exchange_token(unknown_token) is None
216
217
async def test_exchange_token_expires(db_session: AsyncSession):
    """an exchange token whose expiry lies in the past cannot be redeemed."""
    # the fixture parameter shadows the module's db_session name, so import the
    # session factory under an alias and use it as an independent session
    from backend.utilities.database import db_session as open_db_session

    session_id = await create_session(
        "did:plc:expire123", "expire.bsky.social", {"access_token": "token"}
    )
    token = await create_exchange_token(session_id)

    # backdate the stored expiry by one second and commit it
    async with open_db_session() as db:
        lookup = await db.execute(
            select(ExchangeToken).where(ExchangeToken.token == token)
        )
        row = lookup.scalar_one_or_none()
        assert row is not None

        row.expires_at = datetime.now(UTC) - timedelta(seconds=1)
        await db.commit()

    assert await consume_exchange_token(token) is None
245
246
async def test_session_isolation(db_session: AsyncSession):
    """each test starts with no sessions and no exchange tokens in the database."""
    # rows written by other tests must not leak into this one
    for model in (UserSession, ExchangeToken):
        rows = (await db_session.execute(select(model))).scalars().all()
        assert len(rows) == 0
257
258
async def test_create_session_with_custom_expiration(db_session: AsyncSession):
    """expires_in_days=30 stores an expiry roughly thirty days from now."""
    did = "did:plc:customexp123"

    session_id = await create_session(
        did,
        "customexp.bsky.social",
        {"access_token": "token", "refresh_token": "refresh"},
        expires_in_days=30,
    )

    session = await get_session(session_id)
    assert session is not None
    assert session.did == did

    record = (
        await db_session.execute(
            select(UserSession).where(UserSession.session_id == session_id)
        )
    ).scalar_one_or_none()
    assert record is not None
    assert record.expires_at is not None

    # treat the stored timestamp as UTC; allow one minute of slack for runtime
    drift = datetime.now(UTC) + timedelta(days=30) - record.expires_at.replace(tzinfo=UTC)
    assert abs(drift.total_seconds()) < 60
286
287
async def test_create_session_with_no_expiration(db_session: AsyncSession):
    """expires_in_days=0 stores a session whose expires_at is None."""
    did = "did:plc:noexp123"

    session_id = await create_session(
        did,
        "noexp.bsky.social",
        {"access_token": "token", "refresh_token": "refresh"},
        expires_in_days=0,
    )

    session = await get_session(session_id)
    assert session is not None
    assert session.did == did

    record = (
        await db_session.execute(
            select(UserSession).where(UserSession.session_id == session_id)
        )
    ).scalar_one_or_none()
    assert record is not None
    assert record.expires_at is None
309
310
async def test_create_session_default_expiration(db_session: AsyncSession):
    """without expires_in_days, the stored expiry is about fourteen days out."""
    session_id = await create_session(
        "did:plc:defaultexp123",
        "defaultexp.bsky.social",
        {"access_token": "token", "refresh_token": "refresh"},
    )

    record = (
        await db_session.execute(
            select(UserSession).where(UserSession.session_id == session_id)
        )
    ).scalar_one_or_none()
    assert record is not None
    assert record.expires_at is not None

    # treat the stored timestamp as UTC; allow one minute of slack for runtime
    drift = datetime.now(UTC) + timedelta(days=14) - record.expires_at.replace(tzinfo=UTC)
    assert abs(drift.total_seconds()) < 60
332
333
334# confidential client tests
335
336
def test_is_confidential_client_false_by_default():
    """is_confidential_client reports False while OAUTH_JWK is unset."""
    with patch("backend._internal.auth.settings.atproto.oauth_jwk", None):
        confidential = is_confidential_client()
    assert confidential is False
341
342
def test_is_confidential_client_true_when_configured():
    """is_confidential_client reports True once an OAUTH_JWK value is present."""
    with patch(
        "backend._internal.auth.settings.atproto.oauth_jwk",
        '{"kty":"EC","crv":"P-256","x":"test","y":"test","d":"test"}',
    ):
        confidential = is_confidential_client()
    assert confidential is True
349
350
def test_get_public_jwks_returns_none_without_config():
    """with no OAUTH_JWK configured there is no JWKS to publish."""
    with patch("backend._internal.auth.settings.atproto.oauth_jwk", None):
        published = get_public_jwks()
    assert published is None
355
356
def test_get_public_jwks_returns_public_key():
    """verify get_public_jwks publishes only the public half of the configured key.

    builds a real P-256 private key, serializes it as a JWK that includes the
    private scalar ``d`` plus a ``kid``, then checks that the published JWKS
    has exactly one key without ``d`` but with the public coordinates,
    ``kty``/``alg``/``use`` metadata, and the original ``kid``.
    """
    from cryptography.hazmat.primitives import serialization
    from cryptography.hazmat.primitives.asymmetric import ec
    from jose import jwk as jose_jwk

    # generate test key
    private_key = ec.generate_private_key(ec.SECP256R1())
    pem_bytes = private_key.private_bytes(
        encoding=serialization.Encoding.PEM,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )
    key_obj = jose_jwk.construct(pem_bytes, algorithm="ES256")
    jwk_dict = key_obj.to_dict()
    # precondition: the configured JWK really carries private material; without
    # it the '"d" not in public_key' check below would pass vacuously
    assert "d" in jwk_dict
    jwk_dict["kid"] = "test-key-id"  # add kid to test preservation
    test_jwk = json.dumps(jwk_dict)

    with patch("backend._internal.auth.settings.atproto.oauth_jwk", test_jwk):
        jwks = get_public_jwks()

    assert jwks is not None
    assert "keys" in jwks
    assert len(jwks["keys"]) == 1

    public_key = jwks["keys"][0]
    # should NOT have private key component
    assert "d" not in public_key
    # should have public key components
    assert "x" in public_key
    assert "y" in public_key
    assert public_key["kty"] == "EC"
    assert public_key["alg"] == "ES256"
    assert public_key["use"] == "sig"
    # should preserve kid from original JWK
    assert public_key["kid"] == "test-key-id"
392 # should preserve kid from original JWK
393 assert public_key["kid"] == "test-key-id"