1"""pytest configuration for relay tests."""
2
3import os
4from collections.abc import AsyncGenerator, Generator
5from contextlib import asynccontextmanager
6from datetime import UTC, datetime
7from urllib.parse import urlsplit, urlunsplit
8
9import asyncpg
10import pytest
11import redis as sync_redis_lib
12import sqlalchemy as sa
13from fastapi import FastAPI
14from fastapi.testclient import TestClient
15from sqlalchemy.ext.asyncio import (
16 AsyncConnection,
17 AsyncEngine,
18 AsyncSession,
19 create_async_engine,
20)
21from sqlalchemy.orm import sessionmaker
22
23from backend.config import settings
24from backend.models import Base
25from backend.storage.r2 import R2Storage
26from backend.utilities.redis import clear_client_cache
27
28
29class MockStorage(R2Storage):
30 """Mock storage for tests - no R2 credentials needed."""
31
32 def __init__(self):
33 # skip R2Storage.__init__ which requires credentials
34 pass
35
36 async def save(self, file_obj, filename: str, progress_callback=None) -> str:
37 """Mock save - returns a fake file_id."""
38 return "mock_file_id_123"
39
40 async def get_url(
41 self,
42 file_id: str,
43 *,
44 file_type: str | None = None,
45 extension: str | None = None,
46 ) -> str | None:
47 """Mock get_url - returns a fake URL."""
48 return f"https://mock.r2.dev/{file_id}"
49
50 async def delete(self, file_id: str, file_type: str | None = None) -> bool:
51 """Mock delete."""
52 return True
53
54
55def pytest_configure(config):
56 """Set mock storage before any test modules are imported."""
57 import backend.storage
58
59 # set _storage directly to prevent R2Storage initialization
60 backend.storage._storage = MockStorage() # type: ignore[assignment]
61
62
63def _database_from_url(url: str) -> str:
64 """extract database name from connection URL."""
65 _, _, path, _, _ = urlsplit(url)
66 return path.strip("/")
67
68
69def _postgres_admin_url(database_url: str) -> str:
70 """convert async database URL to sync postgres URL for admin operations."""
71 scheme, netloc, _, query, fragment = urlsplit(database_url)
72 # asyncpg -> postgres for direct connection
73 scheme = scheme.replace("+asyncpg", "").replace("postgresql", "postgres")
74 return urlunsplit((scheme, netloc, "/postgres", query, fragment))
75
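# for illustration, with hypothetical credentials the conversion above maps:
#   postgresql+asyncpg://user:pass@localhost:5432/relay
#     -> postgres://user:pass@localhost:5432/postgres
# (the path is always replaced with /postgres, the default admin database)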

@asynccontextmanager
async def session_context(engine: AsyncEngine) -> AsyncGenerator[AsyncSession, None]:
    """create a database session context."""
    async_session_maker = sessionmaker(
        bind=engine,
        class_=AsyncSession,
        expire_on_commit=False,
    )
    async with async_session_maker() as session:
        yield session


async def _create_clear_database_procedure(
    connection: AsyncConnection,
) -> None:
    """creates a stored procedure in the test database used for quickly clearing
    the database between tests.
    """
    tables = list(reversed(Base.metadata.sorted_tables))

    def schema(table: sa.Table) -> str:
        return table.schema or "public"

    def timestamp_column(table: sa.Table) -> str | None:
        """find the timestamp column to use for filtering"""
        if "created_at" in table.columns:
            return "created_at"
        elif "updated_at" in table.columns:
            return "updated_at"
        else:
            # if no timestamp column, delete all rows
            return None

    delete_statements = []
    for table in tables:
        ts_col = timestamp_column(table)
        if ts_col:
            delete_statements.append(
                f"""
                BEGIN
                    DELETE FROM {schema(table)}.{table.name}
                    WHERE {ts_col} > _test_start_time;
                EXCEPTION WHEN OTHERS THEN
                    RAISE EXCEPTION 'Error clearing table {schema(table)}.{table.name}: %', SQLERRM;
                END;
                """
            )
        else:
            # no timestamp column - delete all rows
            delete_statements.append(
                f"""
                BEGIN
                    DELETE FROM {schema(table)}.{table.name};
                EXCEPTION WHEN OTHERS THEN
                    RAISE EXCEPTION 'Error clearing table {schema(table)}.{table.name}: %', SQLERRM;
                END;
                """
            )

    deletes = "\n".join(delete_statements)

    signature = "clear_database(_test_start_time timestamptz)"
    procedure_body = f"""
    CREATE PROCEDURE {signature}
    LANGUAGE PLPGSQL
    AS $$
    BEGIN
    {deletes}
    END;
    $$;
    """

    await connection.execute(sa.text(f"DROP PROCEDURE IF EXISTS {signature};"))
    await connection.execute(sa.text(procedure_body))

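# for illustration (not generated output): with a hypothetical table
# public.tracks that has a created_at column, the procedure built above
# would look roughly like:
#
#   CREATE PROCEDURE clear_database(_test_start_time timestamptz)
#   LANGUAGE PLPGSQL
#   AS $$
#   BEGIN
#       BEGIN
#           DELETE FROM public.tracks
#           WHERE created_at > _test_start_time;
#       EXCEPTION WHEN OTHERS THEN
#           RAISE EXCEPTION 'Error clearing table public.tracks: %', SQLERRM;
#       END;
#   END;
#   $$;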

async def _truncate_tables(connection: AsyncConnection) -> None:
    """truncate all tables to ensure a clean slate at start of session."""
    # get all table names from metadata
    tables = [table.name for table in Base.metadata.sorted_tables]
    if not tables:
        return

    # truncate all tables with cascade to handle foreign keys
    # restart identity resets auto-increment counters
    stmt = f"TRUNCATE TABLE {', '.join(tables)} RESTART IDENTITY CASCADE;"
    await connection.execute(sa.text(stmt))

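# for illustration, with hypothetical tables "users" and "tracks" the
# statement above expands to:
#   TRUNCATE TABLE users, tracks RESTART IDENTITY CASCADE;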

async def _setup_template_database(template_url: str) -> None:
    """initialize database schema and helper procedure on template database."""
    engine = create_async_engine(template_url, echo=False)
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
            await _truncate_tables(conn)
            await _create_clear_database_procedure(conn)
    finally:
        await engine.dispose()


async def _ensure_template_database(base_url: str) -> str:
    """ensure template database exists and is migrated.

    uses advisory lock to coordinate between xdist workers.
    returns the template database name.
    """
    base_db_name = _database_from_url(base_url)
    template_db_name = f"{base_db_name}_template"
    postgres_url = _postgres_admin_url(base_url)

    conn = await asyncpg.connect(postgres_url)
    try:
        # advisory lock prevents race condition between workers
        await conn.execute("SELECT pg_advisory_lock(hashtext($1))", template_db_name)

        # check if template exists
        exists = await conn.fetchval(
            "SELECT 1 FROM pg_database WHERE datname = $1", template_db_name
        )

        if not exists:
            # create template database
            await conn.execute(f'CREATE DATABASE "{template_db_name}"')

            # build URL for template and set it up
            scheme, netloc, _, query, fragment = urlsplit(base_url)
            template_url = urlunsplit(
                (scheme, netloc, f"/{template_db_name}", query, fragment)
            )
            await _setup_template_database(template_url)

        # release lock (other workers waiting will see template exists)
        await conn.execute("SELECT pg_advisory_unlock(hashtext($1))", template_db_name)

        return template_db_name
    finally:
        await conn.close()

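# note: pg_advisory_lock takes a bigint key, so hashtext() derives a stable
# integer key from the template database name. the lock is session-scoped,
# hence the explicit pg_advisory_unlock before the connection closes.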

async def _create_worker_database_from_template(
    base_url: str, worker_id: str, template_db_name: str
) -> str:
    """create worker database by cloning the template (instant file copy)."""
    base_db_name = _database_from_url(base_url)
    worker_db_name = f"{base_db_name}_{worker_id}"
    postgres_url = _postgres_admin_url(base_url)

    conn = await asyncpg.connect(postgres_url)
    try:
        # kill connections to worker db (if it exists from previous run)
        await conn.execute(
            """
            SELECT pg_terminate_backend(pid)
            FROM pg_stat_activity
            WHERE datname = $1 AND pid <> pg_backend_pid()
            """,
            worker_db_name,
        )

        # kill connections to template db (required for cloning)
        await conn.execute(
            """
            SELECT pg_terminate_backend(pid)
            FROM pg_stat_activity
            WHERE datname = $1 AND pid <> pg_backend_pid()
            """,
            template_db_name,
        )

        # drop and recreate from template (instant - just file copy)
        await conn.execute(f'DROP DATABASE IF EXISTS "{worker_db_name}"')
        await conn.execute(
            f'CREATE DATABASE "{worker_db_name}" WITH TEMPLATE "{template_db_name}"'
        )

        return worker_db_name
    finally:
        await conn.close()

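# note: CREATE DATABASE ... WITH TEMPLATE fails if any other session is
# connected to the template database, which is why connections are
# terminated above before cloning.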

@pytest.fixture(scope="session")
def test_database_url(worker_id: str) -> str:
    """generate a unique test database URL for each pytest worker.

    uses template database pattern for fast parallel test execution:
    1. first worker creates template db with migrations (once)
    2. each worker clones from template (instant file copy)

    also patches settings.database.url so all production code uses test db.
    """
    import asyncio

    base_url = settings.database.url

    # single worker - just use base database
    if worker_id == "master":
        asyncio.run(_setup_database_direct(base_url))
        return base_url

    # xdist workers - use template pattern
    template_db_name = asyncio.run(_ensure_template_database(base_url))
    asyncio.run(
        _create_worker_database_from_template(base_url, worker_id, template_db_name)
    )

    # build URL for worker database
    scheme, netloc, _, query, fragment = urlsplit(base_url)
    base_db_name = _database_from_url(base_url)
    worker_db_name = f"{base_db_name}_{worker_id}"
    worker_url = urlunsplit((scheme, netloc, f"/{worker_db_name}", query, fragment))

    # patch settings so all production code uses this URL
    # this is safe because each xdist worker is a separate process
    settings.database.url = worker_url
    os.environ["DATABASE_URL"] = worker_url

    return worker_url

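# for illustration, with a hypothetical base database "relay_test", worker
# gw1 is handed a clone named "relay_test_gw1" and a matching URL, e.g.
#   postgresql+asyncpg://user:pass@localhost:5432/relay_test_gw1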

async def _setup_database_direct(database_url: str) -> None:
    """set up database directly (for single worker mode)."""
    engine = create_async_engine(database_url, echo=False)
    try:
        async with engine.begin() as conn:
            await conn.run_sync(Base.metadata.create_all)
            await _truncate_tables(conn)
            await _create_clear_database_procedure(conn)
    finally:
        await engine.dispose()


@pytest.fixture(scope="session")
def _database_setup(test_database_url: str) -> None:
    """marker fixture - database is set up by test_database_url fixture."""
    _ = test_database_url  # ensure dependency chain


@pytest.fixture()
async def _engine(
    test_database_url: str, _database_setup: None
) -> AsyncGenerator[AsyncEngine, None]:
    """create a database engine for each test (to avoid event loop issues)."""
    from backend.utilities.database import ENGINES

    # clear any cached engines from previous tests
    for cached_engine in list(ENGINES.values()):
        await cached_engine.dispose()
    ENGINES.clear()

    engine = create_async_engine(
        test_database_url,
        echo=False,
        pool_size=2,
        max_overflow=0,
    )
    try:
        yield engine
    finally:
        await engine.dispose()
        # clean up cached engines
        for cached_engine in list(ENGINES.values()):
            await cached_engine.dispose()
        ENGINES.clear()


@pytest.fixture()
async def _clear_db(_engine: AsyncEngine) -> AsyncGenerator[None, None]:
    """clear the database after each test."""
    start_time = datetime.now(UTC)

    try:
        yield
    finally:
        # clear the database after the test
        async with _engine.begin() as conn:
            await conn.execute(
                sa.text("CALL clear_database(:start_time)"),
                {"start_time": start_time},
            )


@pytest.fixture
async def db_session(
    _engine: AsyncEngine, _clear_db: None
) -> AsyncGenerator[AsyncSession, None]:
    """provide a database session for each test.

    the _clear_db fixture is used as a dependency to ensure proper cleanup order.
    """
    async with session_context(engine=_engine) as session:
        yield session

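# example usage in a test (a minimal sketch; the query is illustrative):
#
#   async def test_select_one(db_session: AsyncSession) -> None:
#       result = await db_session.execute(sa.text("SELECT 1"))
#       assert result.scalar() == 1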

@pytest.fixture(scope="session")
def fastapi_app() -> FastAPI:
    """provides the FastAPI app instance (session-scoped for performance)."""
    from backend.main import app as main_app

    return main_app


@pytest.fixture(scope="session")
def client(fastapi_app: FastAPI) -> Generator[TestClient, None, None]:
    """provides a TestClient for testing the FastAPI application.

    session-scoped to avoid the overhead of starting the full lifespan
    (database init, services, docket worker) for each test.
    """
    with TestClient(fastapi_app) as tc:
        yield tc

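# example usage in a test (the endpoint path is hypothetical):
#
#   def test_health_check(client: TestClient) -> None:
#       response = client.get("/health")
#       assert response.status_code == 200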

def _redis_db_for_worker(worker_id: str) -> int:
    """determine redis database number based on xdist worker id.

    uses different DB numbers for each worker to isolate parallel tests:
    - master/gw0: db 1
    - gw1: db 2
    - gw2: db 3
    - etc.

    db 0 is reserved for local development.
    """
    if worker_id == "master" or not worker_id:
        return 1
    if "gw" in worker_id:
        return 1 + int(worker_id.replace("gw", ""))
    return 1

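# worked examples of the mapping above:
#   "master" -> 1, "gw0" -> 1, "gw1" -> 2, "gw2" -> 3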

def _redis_url_with_db(base_url: str, db: int) -> str:
    """replace database number in redis URL."""
    # redis://host:port/db -> redis://host:port/{new_db}
    if "/" in base_url.rsplit(":", 1)[-1]:
        # has db number, replace it
        base = base_url.rsplit("/", 1)[0]
        return f"{base}/{db}"
    else:
        # no db number, append it
        return f"{base_url}/{db}"

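# worked examples (derived from the logic above):
#   _redis_url_with_db("redis://localhost:6379/0", 2) -> "redis://localhost:6379/2"
#   _redis_url_with_db("redis://localhost:6379", 2)   -> "redis://localhost:6379/2"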

@pytest.fixture(scope="session", autouse=True)
def redis_database(worker_id: str) -> Generator[None, None, None]:
    """use isolated redis databases for parallel test execution.

    each xdist worker gets its own redis database to prevent cache pollution
    between tests running in parallel. flushes the db before and after tests.

    if redis is not available, silently skips - tests that actually need redis
    will fail on their own with a more specific error.
    """
    # skip if no redis configured
    if not settings.docket.url:
        yield
        return

    db = _redis_db_for_worker(worker_id)
    new_url = _redis_url_with_db(settings.docket.url, db)

    # patch settings for this worker process
    settings.docket.url = new_url
    os.environ["DOCKET_URL"] = new_url

    # clear any cached clients (they have old URL)
    clear_client_cache()

    # try to flush db before tests - if redis unavailable, skip silently
    try:
        client = sync_redis_lib.Redis.from_url(new_url, socket_connect_timeout=1)
        client.flushdb()
        client.close()
    except sync_redis_lib.ConnectionError:
        # redis not available - tests that need it will fail with specific errors
        yield
        return

    yield

    # flush db after tests and clear cached clients
    clear_client_cache()
    try:
        client = sync_redis_lib.Redis.from_url(new_url, socket_connect_timeout=1)
        client.flushdb()
        client.close()
    except sync_redis_lib.ConnectionError:
        pass  # redis went away during tests, nothing to clean up