feat: migrate ATProto sync and teal scrobbling to docket (#539)

* feat: migrate ATProto sync and teal scrobbling to docket

- add sync_atproto and scrobble_to_teal as docket background tasks
- remove all fallback/bifurcation code - Redis is always required
- simplify background.py (remove is_docket_enabled)
- update auth.py and playback.py to use new schedulers
- add Redis service to CI workflow
- update tests for the simplified docket-only flow (see the sketch below)
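
A minimal sketch of the docket-only scheduling shape the `schedule_*` helpers now share (names mirror the diff below; this is an illustration, not the full module):

```python
# sketch: docket-only scheduling - no is_docket_enabled() branch and no
# asyncio.create_task fallback; sync_atproto is the registered task function
from backend._internal.background import get_docket
from backend._internal.background_tasks import sync_atproto


async def schedule_atproto_sync(session_id: str, user_did: str) -> None:
    # get_docket() raises RuntimeError if docket/Redis isn't initialized,
    # which is now a hard requirement rather than an optional feature
    docket = get_docket()
    await docket.add(sync_atproto)(session_id, user_did)
```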

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: update test for schedule_atproto_sync, fix cookies deprecation

- test_list_record_sync: patch schedule_atproto_sync instead of removed sync_atproto_records
- test_hidden_tags_filter: move cookies to the client constructor to fix an httpx deprecation warning (sketch below)
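
For reference, the shape of the cookies change (a sketch, not the exact test; `fetch_tracks` is a hypothetical helper and `test_app` is the FastAPI app fixture):

```python
from httpx import ASGITransport, AsyncClient


async def fetch_tracks(test_app) -> None:
    # before: per-request cookies now trigger an httpx DeprecationWarning
    #   await client.get("/tracks/", cookies={"session_id": "test_session"})
    # after: cookies are set once when the client is constructed
    async with AsyncClient(
        transport=ASGITransport(app=test_app),
        base_url="http://test",
        cookies={"session_id": "test_session"},
    ) as client:
        await client.get("/tracks/")
```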

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* feat: template database pattern for fast parallel test execution

- use a template db + clone per xdist worker (instant file copy instead of re-running migrations)
- advisory locks coordinate template creation between workers (condensed sketch below)
- patch settings.database.url per worker so production code uses the worker's database
- pattern borrowed from prefecthq/nebula
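
The core of the pattern, condensed (the full version with connection termination and cleanup is in the backend/tests/conftest.py diff below; `setup_template` stands in for the real schema-creation step):

```python
# condensed sketch: template creation + per-worker cloning, coordinated by an
# advisory lock so only one xdist worker builds the template
import asyncpg


async def ensure_template(conn: asyncpg.Connection, template: str, setup_template) -> None:
    # other workers block here until the first one finishes building the template
    await conn.execute("SELECT pg_advisory_lock(hashtext($1))", template)
    try:
        exists = await conn.fetchval(
            "SELECT 1 FROM pg_database WHERE datname = $1", template
        )
        if not exists:
            await conn.execute(f'CREATE DATABASE "{template}"')
            await setup_template(template)  # run create_all / migrations once
    finally:
        await conn.execute("SELECT pg_advisory_unlock(hashtext($1))", template)


async def clone_for_worker(conn: asyncpg.Connection, template: str, worker_db: str) -> None:
    # cloning from a template is a file-level copy - much faster than migrating
    await conn.execute(f'DROP DATABASE IF EXISTS "{worker_db}"')
    await conn.execute(f'CREATE DATABASE "{worker_db}" WITH TEMPLATE "{template}"')
```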

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* ci: enable parallel test execution with xdist

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: ensure test_app fixture depends on db_session for xdist

fixes a race condition where tests that used test_app without db_session
could run before the worker's database URL had been patched (see the sketch below)
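
The fix is just a fixture dependency (sketch; `app` is the FastAPI instance the test module already imports):

```python
# sketch: test_app now depends on db_session, so the session-scoped database
# fixtures (including the per-worker URL patch) are guaranteed to run first
@pytest.fixture
def test_app(db_session: AsyncSession) -> FastAPI:
    _ = db_session  # unused, but forces the database fixtures to run first
    return app
```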

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>

authored by zzstoatzz.io and Claude, committed by GitHub (1a4e897c, 45b10418)

+12 -1
.github/workflows/test-backend.yml
```diff
@@ ... @@
       --health-timeout 5s
       --health-retries 5

+  redis:
+    image: redis:7-alpine
+    ports:
+      - 6379:6379
+    options: >-
+      --health-cmd "redis-cli ping"
+      --health-interval 10s
+      --health-timeout 5s
+      --health-retries 5
+
 steps:
   - uses: actions/checkout@v5
@@ ... @@
   - name: run tests
     env:
       DATABASE_URL: postgresql+asyncpg://relay_test:relay_test@localhost:5432/relay_test
-    run: cd backend && uv run pytest tests/
+      DOCKET_URL: redis://localhost:6379
+    run: cd backend && uv run pytest tests/ -n auto

   - name: prune uv cache
     if: always()
```
+27 -35
backend/src/backend/_internal/background.py
```diff
@@ ... @@
 
 provides a docket instance for scheduling background tasks and a worker
 that runs alongside the FastAPI server. requires DOCKET_URL to be set
-to a Redis URL for durable execution across multiple machines.
+to a Redis URL.
 
 usage:
-    from backend._internal.background import get_docket, is_docket_enabled
+    from backend._internal.background import get_docket
 
-    if is_docket_enabled():
-        docket = get_docket()
-        await docket.add(my_task_function)(arg1, arg2)
-    else:
-        # fallback to direct execution or FastAPI BackgroundTasks
-        await my_task_function(arg1, arg2)
+    docket = get_docket()
+    await docket.add(my_task_function)(arg1, arg2)
 """
 
 import asyncio
@@ ... @@
 
 logger = logging.getLogger(__name__)
 
-# global docket instance - initialized in lifespan (None if disabled)
+# global docket instance - initialized in lifespan
 _docket: Docket | None = None
-_docket_enabled: bool = False
-
-
-def is_docket_enabled() -> bool:
-    """check if docket is enabled and initialized."""
-    return _docket_enabled and _docket is not None
 
 
 def get_docket() -> Docket:
     """get the global docket instance.
 
     raises:
-        RuntimeError: if docket is not initialized or disabled
+        RuntimeError: if docket is not initialized
     """
-    if not _docket_enabled:
-        raise RuntimeError("docket is disabled - set DOCKET_URL to enable")
     if _docket is None:
         raise RuntimeError("docket not initialized - is the server running?")
     return _docket
 
 
 @asynccontextmanager
-async def background_worker_lifespan() -> AsyncGenerator[Docket | None, None]:
+async def background_worker_lifespan() -> AsyncGenerator[Docket, None]:
     """lifespan context manager for docket and its worker.
 
-    if DOCKET_URL is not set, docket is disabled and this yields None.
-    when enabled, initializes the docket connection and starts an in-process
+    initializes the docket connection and starts an in-process
     worker that processes background tasks.
 
     yields:
-        Docket | None: the initialized docket instance, or None if disabled
+        Docket: the initialized docket instance
     """
-    global _docket, _docket_enabled
-
-    # check if docket should be enabled
-    if not settings.docket.url:
-        logger.info("docket disabled (DOCKET_URL not set)")
-        _docket_enabled = False
-        yield None
-        return
+    global _docket
 
-    _docket_enabled = True
     logger.info(
         "initializing docket",
         extra={"docket_name": settings.docket.name, "url": settings.docket.url},
@@ ... @@
             await worker_task
         except asyncio.CancelledError:
             logger.debug("docket worker task cancelled")
-        # clear globals after worker is fully stopped
+        # clear global after worker is fully stopped
         _docket = None
-        _docket_enabled = False
         logger.info("docket worker stopped")
 
 
@@ ... @@
     add new task imports here as they're created.
     """
     # import task functions here to avoid circular imports
-    from backend._internal.background_tasks import process_export, scan_copyright
+    from backend._internal.background_tasks import (
+        process_export,
+        scan_copyright,
+        scrobble_to_teal,
+        sync_atproto,
+    )
 
     docket.register(scan_copyright)
     docket.register(process_export)
+    docket.register(sync_atproto)
+    docket.register(scrobble_to_teal)
 
     logger.info(
         "registered background tasks",
-        extra={"tasks": ["scan_copyright", "process_export"]},
+        extra={
+            "tasks": [
+                "scan_copyright",
+                "process_export",
+                "sync_atproto",
+                "scrobble_to_teal",
+            ]
+        },
     )
```
+111 -45
backend/src/backend/_internal/background_tasks.py
```diff
@@ ... @@
 
 these functions are registered with docket and executed by workers.
 they should be self-contained and handle their own database sessions.
+
+requires DOCKET_URL to be set (Redis is always available).
 """
 
-import asyncio
 import logging
 import os
 import tempfile
@@ ... @@
 import aiofiles
 import logfire
 
-from backend._internal.background import get_docket, is_docket_enabled
+from backend._internal.background import get_docket
 
 logger = logging.getLogger(__name__)
 
 
 async def scan_copyright(track_id: int, audio_url: str) -> None:
     """scan a track for potential copyright matches.
-
-    this is the docket version of the copyright scan task. when docket
-    is enabled (DOCKET_URL set), this provides durability and retries
-    compared to fire-and-forget asyncio.create_task().
 
     args:
         track_id: database ID of the track to scan
@@ ... @@
 
 
 async def schedule_copyright_scan(track_id: int, audio_url: str) -> None:
-    """schedule a copyright scan, using docket if enabled, else asyncio.
-
-    this is the entry point for scheduling copyright scans. it handles
-    the docket vs asyncio fallback logic in one place.
-    """
-    from backend._internal.moderation import scan_track_for_copyright
-
-    if is_docket_enabled():
-        try:
-            docket = get_docket()
-            await docket.add(scan_copyright)(track_id, audio_url)
-            logfire.info("scheduled copyright scan via docket", track_id=track_id)
-            return
-        except Exception as e:
-            logfire.warning(
-                "docket scheduling failed, falling back to asyncio",
-                track_id=track_id,
-                error=str(e),
-            )
-
-    # fallback: fire-and-forget
-    asyncio.create_task(scan_track_for_copyright(track_id, audio_url))  # noqa: RUF006
+    """schedule a copyright scan via docket."""
+    docket = get_docket()
+    await docket.add(scan_copyright)(track_id, audio_url)
+    logfire.info("scheduled copyright scan", track_id=track_id)
 
 
 async def process_export(export_id: str, artist_did: str) -> None:
@@ ... @@
 
 
 async def schedule_export(export_id: str, artist_did: str) -> None:
-    """schedule an export, using docket if enabled, else asyncio.
-
-    this is the entry point for scheduling exports. it handles
-    the docket vs asyncio fallback logic in one place.
-    """
-    if is_docket_enabled():
-        try:
-            docket = get_docket()
-            await docket.add(process_export)(export_id, artist_did)
-            logfire.info("scheduled export via docket", export_id=export_id)
-            return
-        except Exception as e:
-            logfire.warning(
-                "docket scheduling failed, falling back to asyncio",
-                export_id=export_id,
-                error=str(e),
-            )
-
-    # fallback: fire-and-forget
-    asyncio.create_task(process_export(export_id, artist_did))  # noqa: RUF006
+    """schedule an export via docket."""
+    docket = get_docket()
+    await docket.add(process_export)(export_id, artist_did)
+    logfire.info("scheduled export", export_id=export_id)
+
+
+async def sync_atproto(session_id: str, user_did: str) -> None:
+    """sync ATProto records (profile, albums, liked tracks) for a user.
+
+    this runs after login or scope upgrade to ensure the user's PDS
+    has up-to-date records for their plyr.fm data.
+
+    args:
+        session_id: the user's session ID for authentication
+        user_did: the user's DID
+    """
+    from backend._internal.atproto.sync import sync_atproto_records
+    from backend._internal.auth import get_session
+
+    auth_session = await get_session(session_id)
+    if not auth_session:
+        logger.warning(f"sync_atproto: session {session_id[:8]}... not found")
+        return
+
+    await sync_atproto_records(auth_session, user_did)
+
+
+async def schedule_atproto_sync(session_id: str, user_did: str) -> None:
+    """schedule an ATProto sync via docket."""
+    docket = get_docket()
+    await docket.add(sync_atproto)(session_id, user_did)
+    logfire.info("scheduled atproto sync", user_did=user_did)
+
+
+async def scrobble_to_teal(
+    session_id: str,
+    track_id: int,
+    track_title: str,
+    artist_name: str,
+    duration: int | None,
+    album_name: str | None,
+) -> None:
+    """scrobble a play to teal.fm (creates play record + updates status).
+
+    args:
+        session_id: the user's session ID for authentication
+        track_id: database ID of the track
+        track_title: title of the track
+        artist_name: name of the artist
+        duration: track duration in seconds
+        album_name: album name (optional)
+    """
+    from backend._internal.atproto.teal import (
+        create_teal_play_record,
+        update_teal_status,
+    )
+    from backend._internal.auth import get_session
+    from backend.config import settings
+
+    auth_session = await get_session(session_id)
+    if not auth_session:
+        logger.warning(f"teal scrobble: session {session_id[:8]}... not found")
+        return
+
+    origin_url = f"{settings.frontend.url}/track/{track_id}"
+
+    try:
+        # create play record (scrobble)
+        play_uri = await create_teal_play_record(
+            auth_session=auth_session,
+            track_name=track_title,
+            artist_name=artist_name,
+            duration=duration,
+            album_name=album_name,
+            origin_url=origin_url,
+        )
+        logger.info(f"teal play record created: {play_uri}")
+
+        # update status (now playing)
+        status_uri = await update_teal_status(
+            auth_session=auth_session,
+            track_name=track_title,
+            artist_name=artist_name,
+            duration=duration,
+            album_name=album_name,
+            origin_url=origin_url,
+        )
+        logger.info(f"teal status updated: {status_uri}")
+
+    except Exception as e:
+        logger.error(f"teal scrobble failed for track {track_id}: {e}", exc_info=True)
+
+
+async def schedule_teal_scrobble(
+    session_id: str,
+    track_id: int,
+    track_title: str,
+    artist_name: str,
+    duration: int | None,
+    album_name: str | None,
+) -> None:
+    """schedule a teal scrobble via docket."""
+    docket = get_docket()
+    await docket.add(scrobble_to_teal)(
+        session_id, track_id, track_title, artist_name, duration, album_name
+    )
+    logfire.info("scheduled teal scrobble", track_id=track_id)
```
+5 -44
backend/src/backend/api/auth.py
```diff
@@ ... @@
 """authentication api endpoints."""
 
-import asyncio
 import logging
 from typing import Annotated
 
@@ ... @@
     start_oauth_flow,
     start_oauth_flow_with_scopes,
 )
-from backend._internal.atproto import sync_atproto_records
+from backend._internal.background_tasks import schedule_atproto_sync
 from backend.config import settings
 from backend.utilities.rate_limit import limiter
 
 logger = logging.getLogger(__name__)
-
-# hold references to background tasks to prevent GC before completion
-_background_tasks: set[asyncio.Task[None]] = set()
-
-
-def _create_background_task(coro) -> asyncio.Task:
-    """create a background task with proper lifecycle management."""
-    task = asyncio.create_task(coro)
-    _background_tasks.add(task)
-    task.add_done_callback(_background_tasks.discard)
-    return task
-
 
 router = APIRouter(prefix="/auth", tags=["auth"])
 
@@ ... @@
         # create exchange token - NOT marked as dev token so cookie gets set
         exchange_token = await create_exchange_token(session_id)
 
-        # fire-and-forget: sync ATProto records with new scopes
-        auth_session = Session(
-            session_id=session_id,
-            did=did,
-            handle=handle,
-            oauth_session=oauth_session,
-        )
-
-        async def _sync_on_scope_upgrade():
-            try:
-                await sync_atproto_records(auth_session, did)
-            except Exception as e:
-                logger.error(f"background sync failed for {did}: {e}", exc_info=True)
-
-        _create_background_task(_sync_on_scope_upgrade())
+        # schedule ATProto sync via docket
+        await schedule_atproto_sync(session_id, did)
 
         return RedirectResponse(
             url=f"{settings.frontend.url}/settings?exchange_token={exchange_token}&scope_upgraded=true",
@@ ... @@
     # check if artist profile exists
     has_profile = await check_artist_profile_exists(did)
 
-    # fire-and-forget: sync ATProto records on login
-    auth_session = Session(
-        session_id=session_id,
-        did=did,
-        handle=handle,
-        oauth_session=oauth_session,
-    )
-
-    async def _sync_on_login():
-        try:
-            await sync_atproto_records(auth_session, did)
-        except Exception as e:
-            logger.error(f"background sync failed for {did}: {e}", exc_info=True)
-
-    _create_background_task(_sync_on_login())
+    # schedule ATProto sync via docket
+    await schedule_atproto_sync(session_id, did)
 
     # redirect to profile setup if needed, otherwise to portal
     redirect_path = "/portal" if has_profile else "/profile/setup"
```
+3 -48
backend/src/backend/api/tracks/playback.py
```diff
@@ ... @@
 import logging
 from typing import Annotated
 
-from fastapi import BackgroundTasks, Depends, HTTPException
+from fastapi import Depends, HTTPException
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import selectinload
 
 from backend._internal import Session, get_optional_session
-from backend._internal.atproto.teal import create_teal_play_record, update_teal_status
-from backend._internal.auth import get_session
+from backend._internal.background_tasks import schedule_teal_scrobble
 from backend.config import settings
 from backend.models import Artist, Track, TrackLike, UserPreferences, get_db
 from backend.schemas import TrackResponse
@@ ... @@
 )
 
 
-async def _scrobble_to_teal(
-    session_id: str,
-    track_id: int,
-    track_title: str,
-    artist_name: str,
-    duration: int | None,
-    album_name: str | None,
-) -> None:
-    """scrobble a play to teal.fm (creates play record + updates status)."""
-    if not (auth_session := await get_session(session_id)):
-        logger.warning(f"teal scrobble failed: session {session_id[:8]}... not found")
-        return
-
-    origin_url = f"{settings.frontend.url}/track/{track_id}"
-
-    try:
-        # create play record (scrobble)
-        play_uri = await create_teal_play_record(
-            auth_session=auth_session,
-            track_name=track_title,
-            artist_name=artist_name,
-            duration=duration,
-            album_name=album_name,
-            origin_url=origin_url,
-        )
-        logger.info(f"teal play record created: {play_uri}")
-
-        # update status (now playing)
-        status_uri = await update_teal_status(
-            auth_session=auth_session,
-            track_name=track_title,
-            artist_name=artist_name,
-            duration=duration,
-            album_name=album_name,
-            origin_url=origin_url,
-        )
-        logger.info(f"teal status updated: {status_uri}")
-
-    except Exception as e:
-        logger.error(f"teal scrobble failed for track {track_id}: {e}", exc_info=True)
-
-
 @router.post("/{track_id}/play")
 async def increment_play_count(
     track_id: int,
     db: Annotated[AsyncSession, Depends(get_db)],
-    background_tasks: BackgroundTasks,
     session: Session | None = Depends(get_optional_session),
 ) -> dict:
     """Increment play count for a track (called after 30 seconds of playback).
@@ ... @@
     # check if session has teal scopes
     scope = session.oauth_session.get("scope", "")
     if settings.teal.play_collection in scope:
-        background_tasks.add_task(
-            _scrobble_to_teal,
+        await schedule_teal_scrobble(
             session_id=session.session_id,
             track_id=track_id,
             track_title=track.title,
```
+16 -18
backend/tests/api/test_hidden_tags_filter.py
```diff
@@ ... @@
 ):
     """test that discovery feed (no artist_did) filters hidden tags."""
     async with AsyncClient(
-        transport=ASGITransport(app=test_app), base_url="http://test"
+        transport=ASGITransport(app=test_app),
+        base_url="http://test",
+        cookies={"session_id": "test_session"},
     ) as client:
-        response = await client.get(
-            "/tracks/",
-            cookies={"session_id": "test_session"},
-        )
+        response = await client.get("/tracks/")
 
     assert response.status_code == 200
     tracks = response.json()["tracks"]
@@ ... @@
 ):
     """test that artist page (with artist_did) shows all tracks including hidden."""
     async with AsyncClient(
-        transport=ASGITransport(app=test_app), base_url="http://test"
+        transport=ASGITransport(app=test_app),
+        base_url="http://test",
+        cookies={"session_id": "test_session"},
     ) as client:
-        response = await client.get(
-            f"/tracks/?artist_did={artist.did}",
-            cookies={"session_id": "test_session"},
-        )
+        response = await client.get(f"/tracks/?artist_did={artist.did}")
 
     assert response.status_code == 200
     tracks = response.json()["tracks"]
@@ ... @@
 ):
     """test that filter_hidden_tags=true forces filtering even on artist page."""
     async with AsyncClient(
-        transport=ASGITransport(app=test_app), base_url="http://test"
+        transport=ASGITransport(app=test_app),
+        base_url="http://test",
+        cookies={"session_id": "test_session"},
     ) as client:
         response = await client.get(
-            f"/tracks/?artist_did={artist.did}&filter_hidden_tags=true",
-            cookies={"session_id": "test_session"},
+            f"/tracks/?artist_did={artist.did}&filter_hidden_tags=true"
         )
 
     assert response.status_code == 200
@@ ... @@
 ):
     """test that filter_hidden_tags=false disables filtering on discovery feed."""
     async with AsyncClient(
-        transport=ASGITransport(app=test_app), base_url="http://test"
+        transport=ASGITransport(app=test_app),
+        base_url="http://test",
+        cookies={"session_id": "test_session"},
     ) as client:
-        response = await client.get(
-            "/tracks/?filter_hidden_tags=false",
-            cookies={"session_id": "test_session"},
-        )
+        response = await client.get("/tracks/?filter_hidden_tags=false")
 
     assert response.status_code == 200
     tracks = response.json()["tracks"]
```
+6 -11
backend/tests/api/test_list_record_sync.py
```diff
@@ ... @@
 """tests for ATProto list record sync on login."""
 
-import asyncio
 from collections.abc import Generator
 from unittest.mock import AsyncMock, patch
@@ ... @@
             return_value=True,
         ),
         patch(
-            "backend.api.auth.sync_atproto_records",
+            "backend.api.auth.schedule_atproto_sync",
             new_callable=AsyncMock,
-        ) as mock_sync,
+        ) as mock_schedule_sync,
     ):
         async with AsyncClient(
             transport=ASGITransport(app=test_app), base_url="http://test"
@@ ... @@
                 follow_redirects=False,
             )
 
-        # give background tasks time to run
-        await asyncio.sleep(0.1)
-
         assert response.status_code == 303
         assert "exchange_token=test_exchange_token" in response.headers["location"]
 
-        # verify sync was triggered in background
-        mock_sync.assert_called_once()
-        call_args = mock_sync.call_args
-        # first arg is the session, second is the DID
-        assert call_args[0][1] == "did:plc:testartist123"
+        # verify sync was scheduled via docket
+        mock_schedule_sync.assert_called_once_with(
+            "test_session_id", "did:plc:testartist123"
+        )
```
+3 -2
backend/tests/api/test_oembed.py
```diff
@@ ... @@
 
 
 @pytest.fixture
-def test_app() -> FastAPI:
-    """get test app."""
+def test_app(db_session: AsyncSession) -> FastAPI:
+    """get test app with db session dependency to ensure correct database URL."""
+    _ = db_session  # ensures database fixtures run first
     return app
 
 
```
+149 -15
backend/tests/conftest.py
```diff
@@ ... @@
 from collections.abc import AsyncGenerator, Generator
 from contextlib import asynccontextmanager
 from datetime import UTC, datetime
+from urllib.parse import urlsplit, urlunsplit
 
+import asyncpg
 import pytest
 import sqlalchemy as sa
 from fastapi import FastAPI
@@ ... @@
 
 # set _storage directly to prevent R2Storage initialization
 backend.storage._storage = MockStorage()  # type: ignore[assignment]
+
+
+def _database_from_url(url: str) -> str:
+    """extract database name from connection URL."""
+    _, _, path, _, _ = urlsplit(url)
+    return path.strip("/")
+
+
+def _postgres_admin_url(database_url: str) -> str:
+    """convert async database URL to sync postgres URL for admin operations."""
+    scheme, netloc, _, query, fragment = urlsplit(database_url)
+    # asyncpg -> postgres for direct connection
+    scheme = scheme.replace("+asyncpg", "").replace("postgresql", "postgres")
+    return urlunsplit((scheme, netloc, "/postgres", query, fragment))
 
 
 @asynccontextmanager
@@ ... @@
             await connection.execute(sa.text(stmt))
 
 
+async def _setup_template_database(template_url: str) -> None:
+    """initialize database schema and helper procedure on template database."""
+    engine = create_async_engine(template_url, echo=False)
+    try:
+        async with engine.begin() as conn:
+            await conn.run_sync(Base.metadata.create_all)
+            await _truncate_tables(conn)
+            await _create_clear_database_procedure(conn)
+    finally:
+        await engine.dispose()
+
+
+async def _ensure_template_database(base_url: str) -> str:
+    """ensure template database exists and is migrated.
+
+    uses advisory lock to coordinate between xdist workers.
+    returns the template database name.
+    """
+    base_db_name = _database_from_url(base_url)
+    template_db_name = f"{base_db_name}_template"
+    postgres_url = _postgres_admin_url(base_url)
+
+    conn = await asyncpg.connect(postgres_url)
+    try:
+        # advisory lock prevents race condition between workers
+        await conn.execute("SELECT pg_advisory_lock(hashtext($1))", template_db_name)
+
+        # check if template exists
+        exists = await conn.fetchval(
+            "SELECT 1 FROM pg_database WHERE datname = $1", template_db_name
+        )
+
+        if not exists:
+            # create template database
+            await conn.execute(f'CREATE DATABASE "{template_db_name}"')
+
+            # build URL for template and set it up
+            scheme, netloc, _, query, fragment = urlsplit(base_url)
+            template_url = urlunsplit(
+                (scheme, netloc, f"/{template_db_name}", query, fragment)
+            )
+            await _setup_template_database(template_url)
+
+        # release lock (other workers waiting will see template exists)
+        await conn.execute("SELECT pg_advisory_unlock(hashtext($1))", template_db_name)
+
+        return template_db_name
+    finally:
+        await conn.close()
+
+
+async def _create_worker_database_from_template(
+    base_url: str, worker_id: str, template_db_name: str
+) -> str:
+    """create worker database by cloning the template (instant file copy)."""
+    base_db_name = _database_from_url(base_url)
+    worker_db_name = f"{base_db_name}_{worker_id}"
+    postgres_url = _postgres_admin_url(base_url)
+
+    conn = await asyncpg.connect(postgres_url)
+    try:
+        # kill connections to worker db (if it exists from previous run)
+        await conn.execute(
+            """
+            SELECT pg_terminate_backend(pid)
+            FROM pg_stat_activity
+            WHERE datname = $1 AND pid <> pg_backend_pid()
+            """,
+            worker_db_name,
+        )
+
+        # kill connections to template db (required for cloning)
+        await conn.execute(
+            """
+            SELECT pg_terminate_backend(pid)
+            FROM pg_stat_activity
+            WHERE datname = $1 AND pid <> pg_backend_pid()
+            """,
+            template_db_name,
+        )
+
+        # drop and recreate from template (instant - just file copy)
+        await conn.execute(f'DROP DATABASE IF EXISTS "{worker_db_name}"')
+        await conn.execute(
+            f'CREATE DATABASE "{worker_db_name}" WITH TEMPLATE "{template_db_name}"'
+        )
+
+        return worker_db_name
+    finally:
+        await conn.close()
+
+
 @pytest.fixture(scope="session")
 def test_database_url(worker_id: str) -> str:
     """generate a unique test database URL for each pytest worker.
 
-    reads from settings.database.url and appends worker suffix if needed.
+    uses template database pattern for fast parallel test execution:
+    1. first worker creates template db with migrations (once)
+    2. each worker clones from template (instant file copy)
+
+    also patches settings.database.url so all production code uses test db.
     """
+    import asyncio
+    import os
+
     base_url = settings.database.url
 
-    # for parallel test execution, append worker id to database name
+    # single worker - just use base database
     if worker_id == "master":
+        asyncio.run(_setup_database_direct(base_url))
         return base_url
 
-    # worker_id will be "gw0", "gw1", etc for xdist workers
-    return f"{base_url}_{worker_id}"
+    # xdist workers - use template pattern
+    template_db_name = asyncio.run(_ensure_template_database(base_url))
+    asyncio.run(
+        _create_worker_database_from_template(base_url, worker_id, template_db_name)
+    )
 
+    # build URL for worker database
+    scheme, netloc, _, query, fragment = urlsplit(base_url)
+    base_db_name = _database_from_url(base_url)
+    worker_db_name = f"{base_db_name}_{worker_id}"
+    worker_url = urlunsplit((scheme, netloc, f"/{worker_db_name}", query, fragment))
 
-async def _setup_database(test_database_url: str) -> None:
-    """initialize database schema and helper procedure."""
-    engine = create_async_engine(test_database_url, echo=False)
+    # patch settings so all production code uses this URL
+    # this is safe because each xdist worker is a separate process
+    settings.database.url = worker_url
+    os.environ["DATABASE_URL"] = worker_url
+
+    return worker_url
+
+
+async def _setup_database_direct(database_url: str) -> None:
+    """set up database directly (for single worker mode)."""
+    engine = create_async_engine(database_url, echo=False)
     try:
         async with engine.begin() as conn:
             await conn.run_sync(Base.metadata.create_all)
@@ ... @@
 
 @pytest.fixture(scope="session")
 def _database_setup(test_database_url: str) -> None:
-    """create tables and stored procedures once per test session."""
-    import asyncio
-
-    asyncio.run(_setup_database(test_database_url))
+    """marker fixture - database is set up by test_database_url fixture."""
+    _ = test_database_url  # ensure dependency chain
 
 
 @pytest.fixture()
@@ ... @@
     test_database_url: str, _database_setup: None
 ) -> AsyncGenerator[AsyncEngine, None]:
     """create a database engine for each test (to avoid event loop issues)."""
+    from backend.utilities.database import ENGINES
+
+    # clear any cached engines from previous tests
+    for cached_engine in list(ENGINES.values()):
+        await cached_engine.dispose()
+    ENGINES.clear()
+
     engine = create_async_engine(
         test_database_url,
         echo=False,
@@ ... @@
         yield engine
     finally:
         await engine.dispose()
-        # also dispose any engines cached by production code (database.py)
-        # to prevent connection accumulation across tests
-        from backend.utilities.database import ENGINES
-
+        # clean up cached engines
         for cached_engine in list(ENGINES.values()):
             await cached_engine.dispose()
         ENGINES.clear()
```
+50 -72
backend/tests/test_background_tasks.py
```diff
@@ ... @@
 import backend._internal.background_tasks as bg_tasks
 
 
-async def test_schedule_export_uses_docket_when_enabled() -> None:
-    """when docket is enabled, schedule_export should add task to docket."""
+async def test_schedule_export_uses_docket() -> None:
+    """schedule_export should add task to docket."""
     calls: list[tuple[str, str]] = []
 
     async def mock_schedule(export_id: str, artist_did: str) -> None:
@@ ... @@
     mock_docket.add = MagicMock(return_value=mock_schedule)
 
     with (
-        patch.object(bg_tasks, "is_docket_enabled", return_value=True),
         patch.object(bg_tasks, "get_docket", return_value=mock_docket),
         patch.object(bg_tasks, "process_export", MagicMock()),
     ):
@@ ... @@
     assert calls == [("export-123", "did:plc:testuser")]
 
 
-async def test_schedule_export_falls_back_to_asyncio_when_disabled() -> None:
-    """when docket is disabled, schedule_export should use asyncio.create_task."""
-    create_task_calls: list[object] = []
-
-    def capture_create_task(coro: object) -> MagicMock:
-        create_task_calls.append(coro)
-        return MagicMock()
-
-    process_export_calls: list[tuple[str, str]] = []
-
-    def mock_process_export(export_id: str, artist_did: str) -> object:
-        process_export_calls.append((export_id, artist_did))
-        return MagicMock()  # return non-coroutine to avoid unawaited warning
-
-    with (
-        patch.object(bg_tasks, "is_docket_enabled", return_value=False),
-        patch.object(bg_tasks, "process_export", mock_process_export),
-        patch.object(bg_tasks.asyncio, "create_task", capture_create_task),
-    ):
-        await bg_tasks.schedule_export("export-456", "did:plc:testuser")
-
-    assert len(create_task_calls) == 1
-    assert process_export_calls == [("export-456", "did:plc:testuser")]
-
-
-async def test_schedule_export_falls_back_on_docket_error() -> None:
-    """if docket scheduling fails, should fall back to asyncio."""
-    mock_docket = MagicMock()
-    mock_docket.add.side_effect = Exception("redis connection failed")
-
-    create_task_calls: list[object] = []
-
-    def capture_create_task(coro: object) -> MagicMock:
-        create_task_calls.append(coro)
-        return MagicMock()
-
-    process_export_calls: list[tuple[str, str]] = []
-
-    def mock_process_export(export_id: str, artist_did: str) -> object:
-        process_export_calls.append((export_id, artist_did))
-        return MagicMock()
-
-    with (
-        patch.object(bg_tasks, "is_docket_enabled", return_value=True),
-        patch.object(bg_tasks, "get_docket", return_value=mock_docket),
-        patch.object(bg_tasks, "process_export", mock_process_export),
-        patch.object(bg_tasks.asyncio, "create_task", capture_create_task),
-    ):
-        await bg_tasks.schedule_export("export-789", "did:plc:testuser")
-
-    assert len(create_task_calls) == 1
-
-
-async def test_schedule_copyright_scan_uses_docket_when_enabled() -> None:
-    """when docket is enabled, schedule_copyright_scan should add task to docket."""
+async def test_schedule_copyright_scan_uses_docket() -> None:
+    """schedule_copyright_scan should add task to docket."""
     calls: list[tuple[int, str]] = []
 
     async def mock_schedule(track_id: int, audio_url: str) -> None:
         calls.append((track_id, audio_url))
 
     mock_docket = MagicMock()
     mock_docket.add = MagicMock(return_value=mock_schedule)
 
     with (
-        patch.object(bg_tasks, "is_docket_enabled", return_value=True),
         patch.object(bg_tasks, "get_docket", return_value=mock_docket),
         patch.object(bg_tasks, "scan_copyright", MagicMock()),
     ):
         await bg_tasks.schedule_copyright_scan(123, "https://example.com/audio.mp3")
 
     mock_docket.add.assert_called_once()
     assert calls == [(123, "https://example.com/audio.mp3")]
 
 
-async def test_schedule_copyright_scan_falls_back_to_asyncio_when_disabled() -> None:
-    """when docket is disabled, schedule_copyright_scan should use asyncio."""
-    create_task_calls: list[object] = []
+async def test_schedule_atproto_sync_uses_docket() -> None:
+    """schedule_atproto_sync should add task to docket."""
+    calls: list[tuple[str, str]] = []
 
-    def capture_create_task(coro: object) -> MagicMock:
-        create_task_calls.append(coro)
-        return MagicMock()
+    async def mock_schedule(session_id: str, user_did: str) -> None:
+        calls.append((session_id, user_did))
 
-    scan_calls: list[tuple[int, str]] = []
+    mock_docket = MagicMock()
+    mock_docket.add = MagicMock(return_value=mock_schedule)
 
-    def mock_scan(track_id: int, audio_url: str) -> object:
-        scan_calls.append((track_id, audio_url))
-        return MagicMock()
+    with (
+        patch.object(bg_tasks, "get_docket", return_value=mock_docket),
+        patch.object(bg_tasks, "sync_atproto", MagicMock()),
+    ):
+        await bg_tasks.schedule_atproto_sync("session-abc", "did:plc:testuser")
+
+    mock_docket.add.assert_called_once()
+    assert calls == [("session-abc", "did:plc:testuser")]
+
+
+async def test_schedule_teal_scrobble_uses_docket() -> None:
+    """schedule_teal_scrobble should add task to docket."""
+    calls: list[tuple] = []
+
+    async def mock_schedule(
+        session_id: str,
+        track_id: int,
+        track_title: str,
+        artist_name: str,
+        duration: int | None,
+        album_name: str | None,
+    ) -> None:
+        calls.append(
+            (session_id, track_id, track_title, artist_name, duration, album_name)
+        )
+
+    mock_docket = MagicMock()
+    mock_docket.add = MagicMock(return_value=mock_schedule)
 
     with (
-        patch.object(bg_tasks, "is_docket_enabled", return_value=False),
-        patch("backend._internal.moderation.scan_track_for_copyright", mock_scan),
-        patch.object(bg_tasks.asyncio, "create_task", capture_create_task),
+        patch.object(bg_tasks, "get_docket", return_value=mock_docket),
+        patch.object(bg_tasks, "scrobble_to_teal", MagicMock()),
     ):
-        await bg_tasks.schedule_copyright_scan(456, "https://example.com/audio.mp3")
+        await bg_tasks.schedule_teal_scrobble(
+            session_id="session-xyz",
+            track_id=42,
+            track_title="Test Track",
+            artist_name="Test Artist",
+            duration=180,
+            album_name="Test Album",
+        )
 
-    assert len(create_task_calls) == 1
-    assert scan_calls == [(456, "https://example.com/audio.mp3")]
+    mock_docket.add.assert_called_once()
+    assert calls == [
+        ("session-xyz", 42, "Test Track", "Test Artist", 180, "Test Album")
+    ]
```