"""tests for scope upgrade OAuth flow."""

from collections.abc import Generator
from typing import Any
from unittest.mock import AsyncMock, patch

import pytest
from fastapi import FastAPI
from httpx import ASGITransport, AsyncClient, Response
from sqlalchemy.ext.asyncio import AsyncSession

from backend._internal import Session, get_pending_scope_upgrade, require_auth
from backend.main import app

# endpoint exercised by every test in this module
_START_URL = "/auth/scope-upgrade/start"
# patched where the route module looks the name up, not where it is defined
_OAUTH_FLOW_TARGET = "backend.api.auth.start_oauth_flow_with_scopes"


class MockSession(Session):
    """mock session for auth bypass in tests.

    deliberately does not call ``Session.__init__``; the attributes are set
    directly so no real OAuth session state is required.
    """

    def __init__(self, did: str = "did:test:user123"):
        self.did = did
        self.handle = "testuser.bsky.social"
        self.session_id = "test_session_id_for_upgrade"
        self.access_token = "test_token"
        self.refresh_token = "test_refresh"
        self.oauth_session = {
            "did": did,
            "handle": "testuser.bsky.social",
            "pds_url": "https://test.pds",
            "authserver_iss": "https://auth.test",
            "scope": "atproto transition:generic",
            "access_token": "test_token",
            "refresh_token": "test_refresh",
            "dpop_private_key_pem": "fake_key",
            "dpop_authserver_nonce": "",
            "dpop_pds_nonce": "",
        }


@pytest.fixture
def test_app(db_session: AsyncSession) -> Generator[FastAPI, None, None]:
    """create test app with mocked auth.

    only the ``require_auth`` override is installed and later removed (or the
    previous one restored); overrides registered by other fixtures on the
    shared ``app`` are left untouched.
    """

    async def mock_require_auth() -> Session:
        return MockSession()

    previous = app.dependency_overrides.get(require_auth)
    app.dependency_overrides[require_auth] = mock_require_auth

    yield app

    # undo just our override instead of clearing every override on the app
    if previous is None:
        app.dependency_overrides.pop(require_auth, None)
    else:
        app.dependency_overrides[require_auth] = previous


async def _post_scope_upgrade_start(
    target_app: FastAPI, body: dict[str, Any]
) -> Response:
    """POST ``body`` as JSON to the scope-upgrade start endpoint of ``target_app``."""
    async with AsyncClient(
        transport=ASGITransport(app=target_app), base_url="http://test"
    ) as client:
        return await client.post(_START_URL, json=body)


async def test_start_scope_upgrade_flow(test_app: FastAPI, db_session: AsyncSession):
    """test starting the scope upgrade OAuth flow."""
    with patch(_OAUTH_FLOW_TARGET, new_callable=AsyncMock) as mock_oauth:
        mock_oauth.return_value = (
            "https://auth.example.com/authorize?scope=teal",
            "test_state",
        )

        response = await _post_scope_upgrade_start(test_app, {"include_teal": True})

        assert response.status_code == 200
        data = response.json()
        assert "auth_url" in data
        assert data["auth_url"].startswith("https://auth.example.com")
        mock_oauth.assert_called_once_with("testuser.bsky.social", include_teal=True)


async def test_start_scope_upgrade_default_includes_teal(
    test_app: FastAPI, db_session: AsyncSession
):
    """test that scope upgrade defaults to including teal scopes."""
    with patch(_OAUTH_FLOW_TARGET, new_callable=AsyncMock) as mock_oauth:
        mock_oauth.return_value = ("https://auth.example.com/authorize", "test_state")

        # empty body - should default to include_teal=True
        response = await _post_scope_upgrade_start(test_app, {})

        assert response.status_code == 200
        mock_oauth.assert_called_once_with("testuser.bsky.social", include_teal=True)


async def test_scope_upgrade_requires_auth(db_session: AsyncSession):
    """test that scope upgrade requires authentication."""
    # the OAuth starter is patched so that, if auth were ever bypassed, the test
    # fails on the assertions below rather than running the real flow
    with patch(_OAUTH_FLOW_TARGET, new_callable=AsyncMock) as mock_oauth:
        response = await _post_scope_upgrade_start(app, {"include_teal": True})

    assert response.status_code == 401
    mock_oauth.assert_not_called()


async def test_scope_upgrade_saves_pending_record(
    test_app: FastAPI, db_session: AsyncSession
):
    """test that starting scope upgrade saves pending record."""
    with patch(_OAUTH_FLOW_TARGET, new_callable=AsyncMock) as mock_oauth:
        mock_oauth.return_value = ("https://auth.example.com/authorize", "test_state")

        response = await _post_scope_upgrade_start(test_app, {"include_teal": True})

        assert response.status_code == 200

        # verify pending record was saved, keyed by the state the mock returned
        pending = await get_pending_scope_upgrade("test_state")
        assert pending is not None
        assert pending.did == "did:test:user123"
        assert pending.old_session_id == "test_session_id_for_upgrade"
        assert pending.requested_scopes == "teal"