at main 12 kB view raw
#!/usr/bin/env -S uv run --script --quiet
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "httpx",
#     "pydantic-settings",
#     "sqlalchemy[asyncio]",
#     "asyncpg",
#     "logfire[sqlalchemy]",
# ]
# ///
"""scan all tracks for copyright using the moderation service.

usage:
    uv run scripts/scan_tracks_copyright.py --env staging
    uv run scripts/scan_tracks_copyright.py --env prod --dry-run
    uv run scripts/scan_tracks_copyright.py --env staging --limit 10
    uv run scripts/scan_tracks_copyright.py --env prod --max-duration 5

this will:
- fetch all tracks that haven't been scanned yet
- call the moderation service for each track
- store results in copyright_scans table

with --max-duration, each track's size is fetched with an HTTP HEAD request and
its duration estimated from that size; tracks whose size cannot be determined
are scanned anyway.

environment variables (set in .env or export):
    # database URLs per environment
    DEV_DATABASE_URL - dev database connection string
    STAGING_DATABASE_URL - staging database connection string
    PROD_DATABASE_URL - production database connection string

    # moderation service
    MODERATION_SERVICE_URL - URL of moderation service (default: https://plyr-moderation.fly.dev)
    MODERATION_AUTH_TOKEN - auth token for moderation service
"""

import asyncio
import os
import sys
from dataclasses import dataclass
from datetime import UTC, datetime
from pathlib import Path
from typing import Any, Literal

import httpx
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict

# add src to path so the `backend` package can be imported inside run_scan
sys.path.insert(0, str(Path(__file__).parent.parent / "backend" / "src"))


Environment = Literal["dev", "staging", "prod"]


class ScanSettings(BaseSettings):
    """settings for copyright scan script (read from the environment / .env)."""

    model_config = SettingsConfigDict(
        env_file=".env",
        case_sensitive=False,
        extra="ignore",
    )

    dev_database_url: str = Field(default="", validation_alias="DEV_DATABASE_URL")
    staging_database_url: str = Field(
        default="", validation_alias="STAGING_DATABASE_URL"
    )
    prod_database_url: str = Field(default="", validation_alias="PROD_DATABASE_URL")

    moderation_service_url: str = Field(
        default="https://plyr-moderation.fly.dev",
        validation_alias="MODERATION_SERVICE_URL",
    )
    moderation_auth_token: str = Field(
        default="", validation_alias="MODERATION_AUTH_TOKEN"
    )

    def get_database_url(self, env: Environment) -> str:
        """get database URL for environment.

        raises:
            ValueError: if no URL is configured for ``env``.
        """
        urls = {
            "dev": self.dev_database_url,
            "staging": self.staging_database_url,
            "prod": self.prod_database_url,
        }
        url = urls.get(env, "")
        if not url:
            raise ValueError(f"no database URL configured for {env}")
        return url


def setup_env(settings: ScanSettings, env: Environment) -> None:
    """setup environment variables for backend imports (sets DATABASE_URL)."""
    os.environ["DATABASE_URL"] = settings.get_database_url(env)


def _redact_db_url(db_url: str) -> str:
    """return only the host[:port] part of a database URL, or 'configured'.

    splits on the LAST '@' so a password containing a literal '@' is never
    partially echoed to the terminal.
    """
    if "@" not in db_url:
        return "configured"
    return db_url.rsplit("@", 1)[1].split("/", 1)[0]


async def get_file_size(client: httpx.AsyncClient, url: str) -> int | None:
    """get file size in bytes from an HTTP HEAD request.

    best-effort: returns None when the request fails, the server answers with a
    non-2xx status, or there is no parseable content-length header.
    """
    try:
        # follow redirects so we measure the audio file, not a redirect response
        response = await client.head(url, timeout=10.0, follow_redirects=True)
    except (httpx.HTTPError, httpx.InvalidURL):
        return None
    if not response.is_success:
        # an error page's content-length says nothing about the audio file
        return None
    content_length = response.headers.get("content-length")
    if not content_length:
        return None
    try:
        return int(content_length)
    except ValueError:
        return None


# approximate size per minute (in MiB) at a deliberately HIGH bitrate for each
# format. a high bitrate gives a LOW duration estimate, so --max-duration errs on
# the side of scanning a track rather than skipping it.
_MB_PER_MINUTE: dict[str, float] = {
    "mp3": 2.4,  # ~320 kbps
    "m4a": 1.9,  # ~256 kbps
    "aac": 1.9,  # ~256 kbps
    "wav": 10.0,  # ~1411 kbps, 16-bit 44.1kHz stereo (CD quality)
    "flac": 7.5,  # ~1000 kbps, high quality
}
# unknown formats fall back to the mp3-like estimate
_DEFAULT_MB_PER_MINUTE = _MB_PER_MINUTE["mp3"]


def estimate_duration_minutes(file_size_bytes: int, file_type: str | None) -> float:
    """estimate audio duration in minutes from file size.

    uses high bitrate estimates to avoid OVERestimating duration:
    - mp3: ~320 kbps (2.4 MB ≈ 1 minute)
    - m4a/aac: ~256 kbps (1.9 MB ≈ 1 minute)
    - wav: ~1411 kbps for 16-bit 44.1kHz stereo (10 MB ≈ 1 minute)
    - flac: ~1000 kbps high quality (7.5 MB ≈ 1 minute)

    ``file_type`` is matched case-insensitively and may carry a leading dot
    (".MP3" is treated as "mp3"); unknown or missing types use the mp3 rate.
    """
    mb = file_size_bytes / (1024 * 1024)
    key = (file_type or "").lower().lstrip(".")
    return mb / _MB_PER_MINUTE.get(key, _DEFAULT_MB_PER_MINUTE)


async def scan_track(
    client: httpx.AsyncClient,
    settings: ScanSettings,
    audio_url: str,
) -> dict[str, Any]:
    """call moderation service to scan a track.

    raises:
        httpx.HTTPStatusError: if the service answers with a non-2xx status.
    """
    # strip a trailing slash so a configured ".../" base URL doesn't yield "//scan"
    base_url = settings.moderation_service_url.rstrip("/")
    response = await client.post(
        f"{base_url}/scan",
        json={"audio_url": audio_url},
        headers={"X-Moderation-Key": settings.moderation_auth_token},
        timeout=120.0,  # scans can take a while
    )
    response.raise_for_status()
    return response.json()


@dataclass(frozen=True, slots=True)
class _TrackInfo:
    """plain snapshot of the track columns this script uses.

    copied out of the ORM objects before scanning so the loop never reads ORM
    attributes after a commit (which could trigger a lazy refresh if the
    session expires attributes on commit — depends on db_session's config).
    """

    id: object  # primary key, passed through unchanged to CopyrightScan
    title: str
    artist_handle: str
    r2_url: str
    file_type: str | None

    @classmethod
    def from_track(cls, track: Any) -> "_TrackInfo":
        """build a snapshot from a backend ``Track`` with ``artist`` loaded."""
        artist = track.artist
        return cls(
            id=track.id,
            title=track.title,
            artist_handle=artist.handle if artist is not None else "unknown",
            r2_url=track.r2_url or "",
            file_type=track.file_type,
        )


async def _estimate_track_duration(
    client: httpx.AsyncClient, track: _TrackInfo
) -> tuple[int, float] | None:
    """return (file size in bytes, estimated minutes), or None if size is unknown."""
    if not track.r2_url:
        return None
    file_size = await get_file_size(client, track.r2_url)
    if not file_size:
        return None
    return file_size, estimate_duration_minutes(file_size, track.file_type)


async def _report_dry_run(
    tracks: list[_TrackInfo], max_duration: float | None
) -> None:
    """print which tracks a real run would scan and which it would skip.

    only issues HEAD requests (when ``max_duration`` is set); nothing is sent to
    the moderation service and nothing is written to the database.
    """
    print("\n[DRY RUN] checking tracks...")
    would_scan: list[_TrackInfo] = []
    would_skip: list[tuple[_TrackInfo, int, float]] = []
    async with httpx.AsyncClient() as client:
        for track in tracks:
            if max_duration is not None:
                estimate = await _estimate_track_duration(client, track)
                if estimate is not None and estimate[1] > max_duration:
                    would_skip.append((track, *estimate))
                    continue
            would_scan.append(track)

    print(f"\nwould scan ({len(would_scan)}):")
    for track in would_scan:
        print(f" - {track.id}: {track.title} by @{track.artist_handle}")

    if would_skip:
        print(f"\nwould skip ({len(would_skip)}):")
        for track, size, duration in would_skip:
            print(
                f" - {track.id}: {track.title} "
                f"({size / (1024 * 1024):.1f} MB, ~{duration:.1f} min)"
            )


def _print_scan_result(scan_result: dict[str, Any]) -> None:
    """print the verdict for one scan, including up to three top matches."""
    score = scan_result["highest_score"]
    if not scan_result["is_flagged"]:
        print(f" ✓ clear (score: {score})")
        return
    print(f" ⚠️ FLAGGED (score: {score})")
    # .get() so a malformed match entry can't crash the run after the scan
    # has already been committed
    for match in scan_result["matches"][:3]:
        print(
            f" - {match.get('artist')} - {match.get('title')} ({match.get('score')})"
        )


async def _scan_and_record(
    db: Any,
    tracks: list[_TrackInfo],
    settings: ScanSettings,
    max_duration: float | None,
) -> None:
    """scan each track, store one CopyrightScan row per success, print a summary.

    ``db`` is the session yielded by backend's ``db_session()``; must only be
    called after ``setup_env`` so the backend import below sees DATABASE_URL.
    """
    from backend.models import CopyrightScan

    scanned = skipped = failed = flagged = 0

    async with httpx.AsyncClient() as client:
        for i, track in enumerate(tracks, 1):
            print(f"\n[{i}/{len(tracks)}] scanning: {track.title}")
            print(f" artist: @{track.artist_handle}")
            print(f" url: {track.r2_url}")

            # check duration if max_duration is set; unknown sizes are scanned
            if max_duration is not None:
                estimate = await _estimate_track_duration(client, track)
                if estimate is not None:
                    file_size, est_duration = estimate
                    print(
                        f" size: {file_size / (1024 * 1024):.1f} MB, "
                        f"est. duration: {est_duration:.1f} min"
                    )
                    if est_duration > max_duration:
                        print(f" ⏭️ skipped (>{max_duration} min)")
                        skipped += 1
                        continue

            try:
                scan_result = await scan_track(client, settings, track.r2_url)
                db.add(
                    CopyrightScan(
                        track_id=track.id,
                        scanned_at=datetime.now(UTC),
                        is_flagged=scan_result["is_flagged"],
                        highest_score=scan_result["highest_score"],
                        matches=scan_result["matches"],
                        raw_response=scan_result["raw_response"],
                    )
                )
                await db.commit()
            except httpx.HTTPStatusError as e:
                failed += 1
                print(f" ❌ HTTP error: {e.response.status_code}")
                try:
                    print(f" {e.response.json()}")
                except ValueError:  # body is not JSON
                    print(f" {e.response.text[:200]}")
                continue
            except Exception as e:
                # a failed flush/commit leaves the session unusable until it is
                # rolled back (this also discards the pending scan row); without
                # this every remaining track would fail as well
                await db.rollback()
                failed += 1
                print(f" ❌ error: {e}")
                continue

            scanned += 1
            if scan_result["is_flagged"]:
                flagged += 1
            _print_scan_result(scan_result)

    print(f"\n{'=' * 50}")
    print("✅ scan complete")
    print(f" scanned: {scanned}")
    print(f" flagged: {flagged}")
    print(f" skipped: {skipped}")
    print(f" failed: {failed}")


async def run_scan(
    env: Environment,
    dry_run: bool = False,
    limit: int | None = None,
    max_duration: float | None = None,
) -> None:
    """scan all not-yet-scanned tracks for copyright.

    args:
        env: which environment's database to use.
        dry_run: only report what would be scanned/skipped; write nothing.
        limit: scan at most this many tracks (newest first); None = no limit.
        max_duration: skip tracks whose estimated duration (minutes) exceeds
            this; None = never skip.

    exits the process with status 1 if required settings are missing.
    """
    settings = ScanSettings()

    # validate settings
    try:
        db_url = settings.get_database_url(env)
    except ValueError as e:
        print(f"❌ {e}")
        print(f"\nset {env.upper()}_DATABASE_URL in .env")
        sys.exit(1)
    print(f"✓ database: {_redact_db_url(db_url)}")

    if not settings.moderation_auth_token:
        print("❌ MODERATION_AUTH_TOKEN not set")
        sys.exit(1)

    print(f"✓ moderation service: {settings.moderation_service_url}")

    # setup env before backend imports
    setup_env(settings, env)

    # import backend after env setup
    from sqlalchemy import select
    from sqlalchemy.orm import joinedload

    from backend.models import CopyrightScan, Track
    from backend.utilities.database import db_session

    async with db_session() as db:
        # find tracks without scans. NULLs are excluded from the subquery
        # because `id NOT IN (..., NULL)` is never true in SQL and would
        # silently hide every track.
        scanned_subq = select(CopyrightScan.track_id).where(
            CopyrightScan.track_id.isnot(None)
        )
        stmt = (
            select(Track)
            .options(joinedload(Track.artist))
            .where(Track.id.notin_(scanned_subq))
            .where(Track.r2_url.isnot(None))
            .order_by(Track.created_at.desc())
        )

        # `is not None` so an explicit 0 is not mistaken for "no limit"
        if limit is not None:
            stmt = stmt.limit(limit)

        rows = await db.execute(stmt)
        tracks = [_TrackInfo.from_track(t) for t in rows.scalars().unique().all()]

        if not tracks:
            print("\n✅ all tracks have been scanned")
            return

        print(f"\n📋 found {len(tracks)} tracks to scan")
        if max_duration is not None:
            print(f"⏱️ skipping tracks > {max_duration} minutes")

        if dry_run:
            await _report_dry_run(tracks, max_duration)
            return

        await _scan_and_record(db, tracks, settings, max_duration)


def main() -> None:
    """main entry point: parse CLI arguments and run the scan."""
    import argparse

    parser = argparse.ArgumentParser(description="scan tracks for copyright")
    parser.add_argument(
        "--env",
        type=str,
        required=True,
        choices=["dev", "staging", "prod"],
        help="environment to scan",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="show what would be scanned without making changes",
    )
    parser.add_argument(
        "--limit",
        type=int,
        default=None,
        help="limit number of tracks to scan",
    )
    parser.add_argument(
        "--max-duration",
        type=float,
        default=None,
        help="skip tracks longer than this many minutes (estimated from file size)",
    )

    args = parser.parse_args()

    # non-positive values are meaningless here (a negative LIMIT is a SQL
    # error); reject them up front with a usage message
    if args.limit is not None and args.limit <= 0:
        parser.error("--limit must be a positive integer")
    if args.max_duration is not None and args.max_duration <= 0:
        parser.error("--max-duration must be a positive number")

    print(f"🔍 copyright scan - {args.env}")
    print("=" * 50)

    asyncio.run(run_scan(args.env, args.dry_run, args.limit, args.max_duration))


if __name__ == "__main__":
    main()