at main 8.9 kB view raw
#!/usr/bin/env -S uv run --script --quiet --with-editable=backend
# /// script
# requires-python = ">=3.12"
# dependencies = [
#     "httpx",
#     "pydantic-settings",
# ]
# ///
"""backfill label context from copyright_scans to moderation service.

this script reads flagged tracks from the backend database and populates
the label_context table in the moderation service database. it does NOT
emit new labels - it only adds context to existing labels.

usage:
    uv run scripts/backfill_label_context.py --env prod --dry-run
    uv run scripts/backfill_label_context.py --env prod

environment variables (set in .env or export):
    DEV_DATABASE_URL - dev database connection string
    PROD_DATABASE_URL - production database connection string
    STAGING_DATABASE_URL - staging database connection string
    MODERATION_SERVICE_URL - URL of moderation service (default: https://moderation.plyr.fm)
    MODERATION_AUTH_TOKEN - auth token for moderation service
"""

import asyncio
import os
import sys
from typing import Any, Literal

import httpx
from pydantic import Field
from pydantic_settings import BaseSettings, SettingsConfigDict


Environment = Literal["dev", "staging", "prod"]


class BackfillSettings(BaseSettings):
    """settings for backfill script.

    values are read from the process environment and from a local ``.env``
    file; unknown keys are ignored. every field defaults to an empty string
    so that a missing value is reported by the script itself instead of
    failing at settings construction.
    """

    model_config = SettingsConfigDict(
        env_file=".env",
        case_sensitive=False,
        extra="ignore",
    )

    dev_database_url: str = Field(default="", validation_alias="DEV_DATABASE_URL")
    staging_database_url: str = Field(
        default="", validation_alias="STAGING_DATABASE_URL"
    )
    prod_database_url: str = Field(default="", validation_alias="PROD_DATABASE_URL")

    moderation_service_url: str = Field(
        default="https://moderation.plyr.fm",
        validation_alias="MODERATION_SERVICE_URL",
    )
    moderation_auth_token: str = Field(
        default="", validation_alias="MODERATION_AUTH_TOKEN"
    )

    def get_database_url(self, env: Environment) -> str:
        """get database URL for environment.

        raises:
            ValueError: if no URL is configured for ``env``.
        """
        urls = {
            "dev": self.dev_database_url,
            "staging": self.staging_database_url,
            "prod": self.prod_database_url,
        }
        url = urls.get(env, "")
        if not url:
            raise ValueError(f"no database URL configured for {env}")
        return url


def setup_env(settings: BackfillSettings, env: Environment) -> None:
    """setup environment variables for backend imports.

    exports ``DATABASE_URL`` for the selected environment, rewritten for the
    asyncpg driver. must be called before the ``backend`` package is imported
    (see ``run_backfill``).

    raises:
        ValueError: if no URL is configured for ``env``.
    """
    db_url = settings.get_database_url(env)
    # ensure asyncpg driver is used
    if db_url.startswith("postgresql://"):
        db_url = db_url.replace("postgresql://", "postgresql+asyncpg://", 1)
    # asyncpg uses 'ssl' not 'sslmode' - convert the parameter
    db_url = db_url.replace("sslmode=require", "ssl=require")
    os.environ["DATABASE_URL"] = db_url


def _describe_database(db_url: str) -> str:
    """return the host[:port] part of ``db_url`` for display.

    splits on the LAST '@' so that a password containing '@' is never echoed
    to the terminal. falls back to 'configured' when the URL has no userinfo.
    """
    if "@" not in db_url:
        return "configured"
    return db_url.rsplit("@", 1)[1].split("/", 1)[0]


def _artist_handle(track: Any) -> str | None:
    """return the track's artist handle, or None when the track has no artist."""
    return track.artist.handle if track.artist else None


async def store_context(
    client: httpx.AsyncClient,
    settings: BackfillSettings,
    uri: str,
    context: dict[str, Any],
) -> bool:
    """store context for an existing label via the /admin/context endpoint.

    posts ``{"uri": ..., "context": ...}`` to the moderation service,
    authenticated with the ``X-Moderation-Key`` header.

    the emit-label endpoint is deliberately NOT used here: per the original
    author's notes, it creates a new label row on every call, which would
    duplicate labels during a backfill.

    returns:
        True if the service answered with a success status, False on any
        error. errors are printed, never raised, so one bad track does not
        abort the whole backfill.
    """
    # tolerate a trailing slash in MODERATION_SERVICE_URL
    base_url = settings.moderation_service_url.rstrip("/")
    try:
        response = await client.post(
            f"{base_url}/admin/context",
            json={
                "uri": uri,
                "context": context,
            },
            headers={"X-Moderation-Key": settings.moderation_auth_token},
            timeout=30.0,
        )
        response.raise_for_status()
        return True
    except httpx.HTTPStatusError as e:
        print(f" ❌ HTTP error: {e.response.status_code}")
        try:
            print(f" {e.response.json()}")
        except ValueError:
            # body is not JSON - show a short excerpt of the raw text instead
            print(f" {e.response.text[:200]}")
        return False
    except Exception as e:
        # best-effort per track: report and let the caller count the failure
        print(f" ❌ error: {e}")
        return False


async def run_backfill(env: Environment, dry_run: bool = False) -> None:
    """backfill label context from copyright_scans.

    loads every flagged track that has an atproto record URI, then either
    lists them (``dry_run``) or posts one context document per track to the
    moderation service. exits the process with status 1 when the database
    URL or the moderation auth token is missing.
    """
    settings = BackfillSettings()

    # validate settings
    try:
        db_url = settings.get_database_url(env)
        print(f"✓ database: {_describe_database(db_url)}")
    except ValueError as e:
        print(f"{e}")
        print(f"\nset {env.upper()}_DATABASE_URL in .env")
        sys.exit(1)

    if not settings.moderation_auth_token:
        print("❌ MODERATION_AUTH_TOKEN not set")
        sys.exit(1)

    print(f"✓ moderation service: {settings.moderation_service_url}")

    # setup env before backend imports
    setup_env(settings, env)

    # import backend after env setup
    from sqlalchemy import select
    from sqlalchemy.orm import joinedload

    from backend.models import CopyrightScan, Track
    from backend.utilities.database import db_session

    async with db_session() as db:
        # find flagged tracks with atproto URIs and their scan results
        stmt = (
            select(Track, CopyrightScan)
            .options(joinedload(Track.artist))
            .join(CopyrightScan, CopyrightScan.track_id == Track.id)
            .where(CopyrightScan.is_flagged.is_(True))
            .where(Track.atproto_record_uri.isnot(None))
            .order_by(Track.created_at.desc())
        )

        result = await db.execute(stmt)
        rows = result.unique().all()

        if not rows:
            print("\n✅ no flagged tracks to backfill context for")
            return

        print(f"\n📋 found {len(rows)} flagged tracks with context to backfill")

        if dry_run:
            print("\n[DRY RUN] would store context for:")
            for track, scan in rows:
                # guarded lookup: a track without an artist must not crash the listing
                print(f" - {track.id}: {track.title} by @{_artist_handle(track)}")
                print(f" uri: {track.atproto_record_uri}")
                print(
                    f" score: {scan.highest_score}, matches: {len(scan.matches or [])}"
                )
            return

        # store context for each track
        async with httpx.AsyncClient() as client:
            stored = 0
            failed = 0

            for i, (track, scan) in enumerate(rows, 1):
                # resolve once so the log line and the payload always agree
                handle = _artist_handle(track)

                print(f"\n[{i}/{len(rows)}] storing context for: {track.title}")
                print(f" artist: @{handle}")
                print(f" uri: {track.atproto_record_uri}")

                context = {
                    "track_title": track.title,
                    "artist_handle": handle,
                    "artist_did": track.artist_did,
                    "highest_score": scan.highest_score,
                    "matches": scan.matches,
                }

                success = await store_context(
                    client,
                    settings,
                    track.atproto_record_uri,
                    context,
                )

                if success:
                    stored += 1
                    print(" ✓ context stored")
                else:
                    failed += 1

        print(f"\n{'=' * 50}")
        print("✅ backfill complete")
        print(f" stored: {stored}")
        print(f" failed: {failed}")


def main() -> None:
    """main entry point: parse CLI arguments and run the backfill."""
    import argparse

    parser = argparse.ArgumentParser(
        description="backfill label context from copyright_scans"
    )
    parser.add_argument(
        "--env",
        type=str,
        required=True,
        choices=["dev", "staging", "prod"],
        help="environment to backfill",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="show what would be stored without making changes",
    )

    args = parser.parse_args()

    print(f"🏷️ label context backfill - {args.env}")
    print("=" * 50)

    asyncio.run(run_backfill(args.env, args.dry_run))


if __name__ == "__main__":
    main()