music on atproto
plyr.fm
1"""tests for scope upgrade OAuth flow."""
2
3from collections.abc import Generator
4from unittest.mock import AsyncMock, patch
5
6import pytest
7from fastapi import FastAPI
8from httpx import ASGITransport, AsyncClient
9from sqlalchemy.ext.asyncio import AsyncSession
10
11from backend._internal import Session, require_auth
12from backend.main import app
13
14
class MockSession(Session):
    """Fake authenticated session so tests can skip the real auth flow.

    Every attribute holds a fixed test value except ``did``, which callers
    may override. ``oauth_session`` repeats the same identity and token
    values alongside placeholder OAuth/DPoP fields.
    """

    def __init__(self, did: str = "did:test:user123"):
        handle = "testuser.bsky.social"
        access_token = "test_token"
        refresh_token = "test_refresh"

        self.did = did
        self.handle = handle
        self.session_id = "test_session_id_for_upgrade"
        self.access_token = access_token
        self.refresh_token = refresh_token
        # placeholder OAuth session data; the scope string carries no teal
        # scopes (presumably so the session looks "pre-upgrade" — confirm)
        self.oauth_session = {
            "did": did,
            "handle": handle,
            "pds_url": "https://test.pds",
            "authserver_iss": "https://auth.test",
            "scope": "atproto transition:generic",
            "access_token": access_token,
            "refresh_token": refresh_token,
            "dpop_private_key_pem": "fake_key",
            "dpop_authserver_nonce": "",
            "dpop_pds_nonce": "",
        }
36
37
@pytest.fixture
def test_app(db_session: AsyncSession) -> Generator[FastAPI, None, None]:
    """Yield the shared app with ``require_auth`` overridden to a MockSession.

    ``db_session`` is requested only for its side effect; presumably it sets
    up the test database (it is defined outside this file — confirm in
    conftest).

    Teardown removes only the override this fixture installed and restores
    any ``require_auth`` override that existed beforehand. It deliberately
    avoids clearing ``app.dependency_overrides`` wholesale, which would also
    drop overrides that other fixtures registered on the shared ``app``.
    """

    async def mock_require_auth() -> Session:
        return MockSession()

    # remember any pre-existing override so teardown can put it back
    previous = app.dependency_overrides.get(require_auth)
    app.dependency_overrides[require_auth] = mock_require_auth

    yield app

    if previous is None:
        app.dependency_overrides.pop(require_auth, None)
    else:
        app.dependency_overrides[require_auth] = previous
50
51
async def test_start_scope_upgrade_flow(test_app: FastAPI, db_session: AsyncSession):
    """Starting the upgrade returns the authorization URL from the OAuth helper."""
    expected_url = "https://auth.example.com/authorize?scope=teal"

    with patch(
        "backend.api.auth.start_oauth_flow_with_scopes",
        new_callable=AsyncMock,
        return_value=(expected_url, "test_state"),
    ) as mock_oauth:
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            response = await client.post(
                "/auth/scope-upgrade/start", json={"include_teal": True}
            )

    # the mock object keeps its call record after the patch is undone
    assert response.status_code == 200
    body = response.json()
    assert "auth_url" in body
    assert body["auth_url"].startswith("https://auth.example.com")
    mock_oauth.assert_called_once_with("testuser.bsky.social", include_teal=True)
75
76
async def test_start_scope_upgrade_default_includes_teal(
    test_app: FastAPI, db_session: AsyncSession
):
    """An empty request body still asks the OAuth helper for teal scopes."""
    with patch(
        "backend.api.auth.start_oauth_flow_with_scopes",
        new_callable=AsyncMock,
        return_value=("https://auth.example.com/authorize", "test_state"),
    ) as mock_oauth:
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            # no include_teal key: the endpoint's default must apply
            response = await client.post("/auth/scope-upgrade/start", json={})

    assert response.status_code == 200
    mock_oauth.assert_called_once_with("testuser.bsky.social", include_teal=True)
96
97
async def test_scope_upgrade_requires_auth(db_session: AsyncSession):
    """Without the auth override, starting a scope upgrade is rejected with 401."""
    # talks to the bare app rather than the test_app fixture, so the real
    # require_auth dependency runs (assuming no override has been left behind)
    transport = ASGITransport(app=app)
    payload = {"include_teal": True}

    async with AsyncClient(transport=transport, base_url="http://test") as client:
        response = await client.post("/auth/scope-upgrade/start", json=payload)

    assert response.status_code == 401
109
110
async def test_scope_upgrade_saves_pending_record(
    test_app: FastAPI, db_session: AsyncSession
):
    """Starting a scope upgrade persists a pending record under the OAuth state."""
    from backend._internal import get_pending_scope_upgrade

    oauth_target = "backend.api.auth.start_oauth_flow_with_scopes"
    with patch(
        oauth_target,
        new_callable=AsyncMock,
        return_value=("https://auth.example.com/authorize", "test_state"),
    ):
        async with AsyncClient(
            transport=ASGITransport(app=test_app), base_url="http://test"
        ) as client:
            response = await client.post(
                "/auth/scope-upgrade/start", json={"include_teal": True}
            )

        assert response.status_code == 200

        # the state string returned by the mocked helper is the lookup key
        pending = await get_pending_scope_upgrade("test_state")

    assert pending is not None
    # identity values come from MockSession, which require_auth is overridden to
    assert pending.did == "did:test:user123"
    assert pending.old_session_id == "test_session_id_for_upgrade"
    assert pending.requested_scopes == "teal"