1"""pytest configuration for relay tests.""" 2 3import os 4from collections.abc import AsyncGenerator, Generator 5from contextlib import asynccontextmanager 6from datetime import UTC, datetime 7from urllib.parse import urlsplit, urlunsplit 8 9import asyncpg 10import pytest 11import redis as sync_redis_lib 12import sqlalchemy as sa 13from fastapi import FastAPI 14from fastapi.testclient import TestClient 15from sqlalchemy.ext.asyncio import ( 16 AsyncConnection, 17 AsyncEngine, 18 AsyncSession, 19 create_async_engine, 20) 21from sqlalchemy.orm import sessionmaker 22 23from backend.config import settings 24from backend.models import Base 25from backend.storage.r2 import R2Storage 26from backend.utilities.redis import clear_client_cache 27 28 29class MockStorage(R2Storage): 30 """Mock storage for tests - no R2 credentials needed.""" 31 32 def __init__(self): 33 # skip R2Storage.__init__ which requires credentials 34 pass 35 36 async def save(self, file_obj, filename: str, progress_callback=None) -> str: 37 """Mock save - returns a fake file_id.""" 38 return "mock_file_id_123" 39 40 async def get_url( 41 self, 42 file_id: str, 43 *, 44 file_type: str | None = None, 45 extension: str | None = None, 46 ) -> str | None: 47 """Mock get_url - returns a fake URL.""" 48 return f"https://mock.r2.dev/{file_id}" 49 50 async def delete(self, file_id: str, file_type: str | None = None) -> bool: 51 """Mock delete.""" 52 return True 53 54 55def pytest_configure(config): 56 """Set mock storage before any test modules are imported.""" 57 import backend.storage 58 59 # set _storage directly to prevent R2Storage initialization 60 backend.storage._storage = MockStorage() # type: ignore[assignment] 61 62 63def _database_from_url(url: str) -> str: 64 """extract database name from connection URL.""" 65 _, _, path, _, _ = urlsplit(url) 66 return path.strip("/") 67 68 69def _postgres_admin_url(database_url: str) -> str: 70 """convert async database URL to sync postgres URL for admin operations.""" 71 scheme, netloc, _, query, fragment = urlsplit(database_url) 72 # asyncpg -> postgres for direct connection 73 scheme = scheme.replace("+asyncpg", "").replace("postgresql", "postgres") 74 return urlunsplit((scheme, netloc, "/postgres", query, fragment)) 75 76 77@asynccontextmanager 78async def session_context(engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]: 79 """create a database session context.""" 80 async_session_maker = sessionmaker( 81 bind=engine, 82 class_=AsyncSession, 83 expire_on_commit=False, 84 ) 85 async with async_session_maker() as session: 86 yield session 87 88 89async def _create_clear_database_procedure( 90 connection: AsyncConnection, 91) -> None: 92 """creates a stored procedure in the test database used for quickly clearing 93 the database between tests. 
94 """ 95 tables = list(reversed(Base.metadata.sorted_tables)) 96 97 def schema(table: sa.Table) -> str: 98 return table.schema or "public" 99 100 def timestamp_column(table: sa.Table) -> str | None: 101 """find the timestamp column to use for filtering""" 102 if "created_at" in table.columns: 103 return "created_at" 104 elif "updated_at" in table.columns: 105 return "updated_at" 106 else: 107 # if no timestamp column, delete all rows 108 return None 109 110 delete_statements = [] 111 for table in tables: 112 ts_col = timestamp_column(table) 113 if ts_col: 114 delete_statements.append( 115 f""" 116 BEGIN 117 DELETE FROM {schema(table)}.{table.name} 118 WHERE {ts_col} > _test_start_time; 119 EXCEPTION WHEN OTHERS THEN 120 RAISE EXCEPTION 'Error clearing table {schema(table)}.{table.name}: %', SQLERRM; 121 END; 122 """ 123 ) 124 else: 125 # no timestamp column - delete all rows 126 delete_statements.append( 127 f""" 128 BEGIN 129 DELETE FROM {schema(table)}.{table.name}; 130 EXCEPTION WHEN OTHERS THEN 131 RAISE EXCEPTION 'Error clearing table {schema(table)}.{table.name}: %', SQLERRM; 132 END; 133 """ 134 ) 135 136 deletes = "\n".join(delete_statements) 137 138 signature = "clear_database(_test_start_time timestamptz)" 139 procedure_body = f""" 140 CREATE PROCEDURE {signature} 141 LANGUAGE PLPGSQL 142 AS $$ 143 BEGIN 144 {deletes} 145 END; 146 $$; 147 """ 148 149 await connection.execute(sa.text(f"DROP PROCEDURE IF EXISTS {signature};")) 150 await connection.execute(sa.text(procedure_body)) 151 152 153async def _truncate_tables(connection: AsyncConnection) -> None: 154 """truncate all tables to ensure a clean slate at start of session.""" 155 # get all table names from metadata 156 tables = [table.name for table in Base.metadata.sorted_tables] 157 if not tables: 158 return 159 160 # truncate all tables with cascade to handle foreign keys 161 # restart identity resets auto-increment counters 162 stmt = f"TRUNCATE TABLE {', '.join(tables)} RESTART IDENTITY CASCADE;" 163 await connection.execute(sa.text(stmt)) 164 165 166async def _setup_template_database(template_url: str) -> None: 167 """initialize database schema and helper procedure on template database.""" 168 engine = create_async_engine(template_url, echo=False) 169 try: 170 async with engine.begin() as conn: 171 await conn.run_sync(Base.metadata.create_all) 172 await _truncate_tables(conn) 173 await _create_clear_database_procedure(conn) 174 finally: 175 await engine.dispose() 176 177 178async def _ensure_template_database(base_url: str) -> str: 179 """ensure template database exists and is migrated. 180 181 uses advisory lock to coordinate between xdist workers. 182 returns the template database name. 
183 """ 184 base_db_name = _database_from_url(base_url) 185 template_db_name = f"{base_db_name}_template" 186 postgres_url = _postgres_admin_url(base_url) 187 188 conn = await asyncpg.connect(postgres_url) 189 try: 190 # advisory lock prevents race condition between workers 191 await conn.execute("SELECT pg_advisory_lock(hashtext($1))", template_db_name) 192 193 # check if template exists 194 exists = await conn.fetchval( 195 "SELECT 1 FROM pg_database WHERE datname = $1", template_db_name 196 ) 197 198 if not exists: 199 # create template database 200 await conn.execute(f'CREATE DATABASE "{template_db_name}"') 201 202 # build URL for template and set it up 203 scheme, netloc, _, query, fragment = urlsplit(base_url) 204 template_url = urlunsplit( 205 (scheme, netloc, f"/{template_db_name}", query, fragment) 206 ) 207 await _setup_template_database(template_url) 208 209 # release lock (other workers waiting will see template exists) 210 await conn.execute("SELECT pg_advisory_unlock(hashtext($1))", template_db_name) 211 212 return template_db_name 213 finally: 214 await conn.close() 215 216 217async def _create_worker_database_from_template( 218 base_url: str, worker_id: str, template_db_name: str 219) -> str: 220 """create worker database by cloning the template (instant file copy).""" 221 base_db_name = _database_from_url(base_url) 222 worker_db_name = f"{base_db_name}_{worker_id}" 223 postgres_url = _postgres_admin_url(base_url) 224 225 conn = await asyncpg.connect(postgres_url) 226 try: 227 # kill connections to worker db (if it exists from previous run) 228 await conn.execute( 229 """ 230 SELECT pg_terminate_backend(pid) 231 FROM pg_stat_activity 232 WHERE datname = $1 AND pid <> pg_backend_pid() 233 """, 234 worker_db_name, 235 ) 236 237 # kill connections to template db (required for cloning) 238 await conn.execute( 239 """ 240 SELECT pg_terminate_backend(pid) 241 FROM pg_stat_activity 242 WHERE datname = $1 AND pid <> pg_backend_pid() 243 """, 244 template_db_name, 245 ) 246 247 # drop and recreate from template (instant - just file copy) 248 await conn.execute(f'DROP DATABASE IF EXISTS "{worker_db_name}"') 249 await conn.execute( 250 f'CREATE DATABASE "{worker_db_name}" WITH TEMPLATE "{template_db_name}"' 251 ) 252 253 return worker_db_name 254 finally: 255 await conn.close() 256 257 258@pytest.fixture(scope="session") 259def test_database_url(worker_id: str) -> str: 260 """generate a unique test database URL for each pytest worker. 261 262 uses template database pattern for fast parallel test execution: 263 1. first worker creates template db with migrations (once) 264 2. each worker clones from template (instant file copy) 265 266 also patches settings.database.url so all production code uses test db. 
267 """ 268 import asyncio 269 import os 270 271 base_url = settings.database.url 272 273 # single worker - just use base database 274 if worker_id == "master": 275 asyncio.run(_setup_database_direct(base_url)) 276 return base_url 277 278 # xdist workers - use template pattern 279 template_db_name = asyncio.run(_ensure_template_database(base_url)) 280 asyncio.run( 281 _create_worker_database_from_template(base_url, worker_id, template_db_name) 282 ) 283 284 # build URL for worker database 285 scheme, netloc, _, query, fragment = urlsplit(base_url) 286 base_db_name = _database_from_url(base_url) 287 worker_db_name = f"{base_db_name}_{worker_id}" 288 worker_url = urlunsplit((scheme, netloc, f"/{worker_db_name}", query, fragment)) 289 290 # patch settings so all production code uses this URL 291 # this is safe because each xdist worker is a separate process 292 settings.database.url = worker_url 293 os.environ["DATABASE_URL"] = worker_url 294 295 return worker_url 296 297 298async def _setup_database_direct(database_url: str) -> None: 299 """set up database directly (for single worker mode).""" 300 engine = create_async_engine(database_url, echo=False) 301 try: 302 async with engine.begin() as conn: 303 await conn.run_sync(Base.metadata.create_all) 304 await _truncate_tables(conn) 305 await _create_clear_database_procedure(conn) 306 finally: 307 await engine.dispose() 308 309 310@pytest.fixture(scope="session") 311def _database_setup(test_database_url: str) -> None: 312 """marker fixture - database is set up by test_database_url fixture.""" 313 _ = test_database_url # ensure dependency chain 314 315 316@pytest.fixture() 317async def _engine( 318 test_database_url: str, _database_setup: None 319) -> AsyncGenerator[AsyncEngine, None]: 320 """create a database engine for each test (to avoid event loop issues).""" 321 from backend.utilities.database import ENGINES 322 323 # clear any cached engines from previous tests 324 for cached_engine in list(ENGINES.values()): 325 await cached_engine.dispose() 326 ENGINES.clear() 327 328 engine = create_async_engine( 329 test_database_url, 330 echo=False, 331 pool_size=2, 332 max_overflow=0, 333 ) 334 try: 335 yield engine 336 finally: 337 await engine.dispose() 338 # clean up cached engines 339 for cached_engine in list(ENGINES.values()): 340 await cached_engine.dispose() 341 ENGINES.clear() 342 343 344@pytest.fixture() 345async def _clear_db(_engine: AsyncEngine) -> AsyncGenerator[None, None]: 346 """clear the database after each test.""" 347 start_time = datetime.now(UTC) 348 349 try: 350 yield 351 finally: 352 # clear the database after the test 353 async with _engine.begin() as conn: 354 await conn.execute( 355 sa.text("CALL clear_database(:start_time)"), 356 {"start_time": start_time}, 357 ) 358 359 360@pytest.fixture 361async def db_session( 362 _engine: AsyncEngine, _clear_db: None 363) -> AsyncGenerator[AsyncSession, None]: 364 """provide a database session for each test. 365 366 the _clear_db fixture is used as a dependency to ensure proper cleanup order. 367 """ 368 async with session_context(engine=_engine) as session: 369 yield session 370 371 372@pytest.fixture(scope="session") 373def fastapi_app() -> FastAPI: 374 """provides the FastAPI app instance (session-scoped for performance).""" 375 from backend.main import app as main_app 376 377 return main_app 378 379 380@pytest.fixture(scope="session") 381def client(fastapi_app: FastAPI) -> Generator[TestClient, None, None]: 382 """provides a TestClient for testing the FastAPI application. 

def _redis_db_for_worker(worker_id: str) -> int:
    """Determine the redis database number based on the xdist worker id.

    Uses a different DB number for each worker to isolate parallel tests:
    - master/gw0: db 1
    - gw1: db 2
    - gw2: db 3
    - etc.

    db 0 is reserved for local development.
    """
    if worker_id == "master" or not worker_id:
        return 1
    if "gw" in worker_id:
        return 1 + int(worker_id.replace("gw", ""))
    return 1


def _redis_url_with_db(base_url: str, db: int) -> str:
    """Replace the database number in a redis URL."""
    # redis://host:port/db -> redis://host:port/{new_db}
    if "/" in base_url.rsplit(":", 1)[-1]:
        # has a db number - replace it
        base = base_url.rsplit("/", 1)[0]
        return f"{base}/{db}"
    else:
        # no db number - append it
        return f"{base_url}/{db}"


@pytest.fixture(scope="session", autouse=True)
def redis_database(worker_id: str) -> Generator[None, None, None]:
    """Use isolated redis databases for parallel test execution.

    Each xdist worker gets its own redis database to prevent cache pollution
    between tests running in parallel. Flushes the db before and after tests.

    If redis is not available, this silently skips - tests that actually need
    redis will fail on their own with a more specific error.
    """
    # skip if no redis is configured
    if not settings.docket.url:
        yield
        return

    db = _redis_db_for_worker(worker_id)
    new_url = _redis_url_with_db(settings.docket.url, db)

    # patch settings for this worker process
    settings.docket.url = new_url
    os.environ["DOCKET_URL"] = new_url

    # clear any cached clients (they hold the old URL)
    clear_client_cache()

    # try to flush the db before tests - if redis is unavailable, skip silently
    try:
        client = sync_redis_lib.Redis.from_url(new_url, socket_connect_timeout=1)
        client.flushdb()
        client.close()
    except sync_redis_lib.ConnectionError:
        # redis not available - tests that need it will fail with specific errors
        yield
        return

    yield

    # flush the db after tests and clear cached clients
    clear_client_cache()
    try:
        client = sync_redis_lib.Redis.from_url(new_url, socket_connect_timeout=1)
        client.flushdb()
        client.close()
    except sync_redis_lib.ConnectionError:
        pass  # redis went away during the tests; nothing to clean up
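
# For illustration (URL hypothetical): with DOCKET_URL=redis://localhost:6379/0,
# worker gw2 flushes and uses redis://localhost:6379/3 (db = 1 + worker index).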