# at main 16 kB view raw
"""tests for copyright moderation integration."""

from unittest.mock import AsyncMock, Mock, patch

import httpx
import pytest
from fastapi.testclient import TestClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from backend._internal.moderation import (
    get_active_copyright_labels,
    scan_track_for_copyright,
)
from backend._internal.moderation_client import (
    ModerationClient,
    ScanResult,
    SensitiveImagesResult,
)
from backend.models import Artist, CopyrightScan, Track


def _make_client() -> ModerationClient:
    """build a ModerationClient pointed at fake endpoints for unit tests.

    every direct ModerationClient test in this module uses the same settings,
    so they live here instead of being repeated per test. the tests patch
    httpx.AsyncClient.post/get, so these URLs are never actually contacted.
    """
    return ModerationClient(
        service_url="https://test.example.com",
        labeler_url="https://labeler.example.com",
        auth_token="test-token",
        timeout_seconds=30,
        label_cache_prefix="test:label:",
        label_cache_ttl_seconds=300,
    )


@pytest.fixture
def mock_scan_result() -> ScanResult:
    """typical scan result from moderation client."""
    return ScanResult(
        is_flagged=True,
        highest_score=85,
        matches=[
            {
                "artist": "Test Artist",
                "title": "Test Song",
                "score": 85,
                "isrc": "USRC12345678",
            }
        ],
        raw_response={"status": "success", "result": []},
    )


@pytest.fixture
def mock_clear_result() -> ScanResult:
    """scan result when no copyright matches found."""
    return ScanResult(
        is_flagged=False,
        highest_score=0,
        matches=[],
        raw_response={"status": "success", "result": None},
    )


async def test_moderation_client_scan_success() -> None:
    """test ModerationClient.scan() with successful response."""
    # httpx.Response.json()/raise_for_status() are sync, hence a plain Mock
    mock_response = Mock()
    mock_response.json.return_value = {
        "is_flagged": True,
        "highest_score": 85,
        "matches": [{"artist": "Test", "title": "Song", "score": 85}],
        "raw_response": {"status": "success"},
    }
    mock_response.raise_for_status.return_value = None

    client = _make_client()

    with patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post:
        mock_post.return_value = mock_response

        result = await client.scan("https://example.com/audio.mp3")

        assert result.is_flagged is True
        assert result.highest_score == 85
        assert len(result.matches) == 1
        mock_post.assert_called_once()


async def test_moderation_client_scan_timeout() -> None:
    """test ModerationClient.scan() timeout handling."""
    client = _make_client()

    with patch("httpx.AsyncClient.post", new_callable=AsyncMock) as mock_post:
        mock_post.side_effect = httpx.TimeoutException("timeout")

        # the client is expected to propagate timeouts to its caller
        with pytest.raises(httpx.TimeoutException):
            await client.scan("https://example.com/audio.mp3")


async def test_scan_track_stores_flagged_result(
    db_session: AsyncSession,
    mock_scan_result: ScanResult,
) -> None:
    """test storing a flagged scan result."""
    artist = Artist(
        did="did:plc:test123",
        handle="test.bsky.social",
        display_name="Test User",
    )
    db_session.add(artist)
    await db_session.commit()

    track = Track(
        title="Test Track",
        file_id="test_file_123",
        file_type="mp3",
        artist_did=artist.did,
        r2_url="https://example.com/audio.mp3",
    )
    db_session.add(track)
    await db_session.commit()

    with patch("backend._internal.moderation.settings") as mock_settings:
        mock_settings.moderation.enabled = True
        mock_settings.moderation.auth_token = "test-token"

        with patch(
            "backend._internal.moderation.get_moderation_client"
        ) as mock_get_client:
            mock_client = AsyncMock()
            mock_client.scan.return_value = mock_scan_result
            mock_get_client.return_value = mock_client

            # narrows Optional[str] for the type checker
            assert track.r2_url is not None
            await scan_track_for_copyright(track.id, track.r2_url)

    result = await db_session.execute(
        select(CopyrightScan).where(CopyrightScan.track_id == track.id)
    )
    scan = result.scalar_one()

    assert scan.is_flagged is True
    assert scan.highest_score == 85
    assert len(scan.matches) == 1
    assert scan.matches[0]["artist"] == "Test Artist"


async def test_scan_track_stores_clear_result(
    db_session: AsyncSession,
    mock_clear_result: ScanResult,
) -> None:
    """test storing a clear (no matches) scan result."""
    artist = Artist(
        did="did:plc:test456",
        handle="clear.bsky.social",
        display_name="Clear User",
    )
    db_session.add(artist)
    await db_session.commit()

    track = Track(
        title="Original Track",
        file_id="original_file_456",
        file_type="wav",
        artist_did=artist.did,
        r2_url="https://example.com/original.wav",
    )
    db_session.add(track)
    await db_session.commit()

    with patch("backend._internal.moderation.settings") as mock_settings:
        mock_settings.moderation.enabled = True
        mock_settings.moderation.auth_token = "test-token"

        with patch(
            "backend._internal.moderation.get_moderation_client"
        ) as mock_get_client:
            mock_client = AsyncMock()
            mock_client.scan.return_value = mock_clear_result
            mock_get_client.return_value = mock_client

            # narrows Optional[str] for the type checker
            assert track.r2_url is not None
            await scan_track_for_copyright(track.id, track.r2_url)

    result = await db_session.execute(
        select(CopyrightScan).where(CopyrightScan.track_id == track.id)
    )
    scan = result.scalar_one()

    assert scan.is_flagged is False
    assert scan.highest_score == 0
    assert scan.matches == []


async def test_scan_track_disabled() -> None:
    """test that scanning is skipped when disabled."""
    with patch("backend._internal.moderation.settings") as mock_settings:
        mock_settings.moderation.enabled = False

        with patch(
            "backend._internal.moderation.get_moderation_client"
        ) as mock_get_client:
            await scan_track_for_copyright(1, "https://example.com/audio.mp3")

            # should not even get the client when disabled
            mock_get_client.assert_not_called()


async def test_scan_track_no_auth_token() -> None:
    """test that scanning is skipped when auth token not configured."""
    with patch("backend._internal.moderation.settings") as mock_settings:
        mock_settings.moderation.enabled = True
        mock_settings.moderation.auth_token = ""

        with patch(
            "backend._internal.moderation.get_moderation_client"
        ) as mock_get_client:
            await scan_track_for_copyright(1, "https://example.com/audio.mp3")

            # should not even get the client without auth token
            mock_get_client.assert_not_called()


async def test_scan_track_service_error_stores_as_clear(
    db_session: AsyncSession,
) -> None:
    """test that service errors are stored as clear results."""
    artist = Artist(
        did="did:plc:errortest",
        handle="errortest.bsky.social",
        display_name="Error Test User",
    )
    db_session.add(artist)
    await db_session.commit()

    track = Track(
        title="Error Test Track",
        file_id="error_test_file",
        file_type="mp3",
        artist_did=artist.did,
        r2_url="https://example.com/short.mp3",
    )
    db_session.add(track)
    await db_session.commit()

    with patch("backend._internal.moderation.settings") as mock_settings:
        mock_settings.moderation.enabled = True
        mock_settings.moderation.auth_token = "test-token"

        with patch(
            "backend._internal.moderation.get_moderation_client"
        ) as mock_get_client:
            mock_client = AsyncMock()
            # use real httpx objects: Request/Response are synchronous, so an
            # AsyncMock stand-in would return un-awaited coroutines from any
            # method the error handler happened to call on them
            request = httpx.Request("POST", "https://test.example.com/scan")
            mock_client.scan.side_effect = httpx.HTTPStatusError(
                "502 error",
                request=request,
                response=httpx.Response(502, request=request),
            )
            mock_get_client.return_value = mock_client

            # should not raise - stores error as clear
            await scan_track_for_copyright(track.id, "https://example.com/short.mp3")

    result = await db_session.execute(
        select(CopyrightScan).where(CopyrightScan.track_id == track.id)
    )
    scan = result.scalar_one()

    assert scan.is_flagged is False
    assert scan.highest_score == 0
    assert scan.matches == []
    assert "error" in scan.raw_response
    assert scan.raw_response["status"] == "scan_failed"


# tests for get_active_copyright_labels


async def test_get_active_copyright_labels_empty_list() -> None:
    """test that empty URI list returns empty set."""
    result = await get_active_copyright_labels([])
    assert result == set()


async def test_get_active_copyright_labels_disabled() -> None:
    """test that disabled moderation returns all URIs as active (fail closed)."""
    uris = ["at://did:plc:test/fm.plyr.track/1", "at://did:plc:test/fm.plyr.track/2"]

    with patch("backend._internal.moderation.settings") as mock_settings:
        mock_settings.moderation.enabled = False

        result = await get_active_copyright_labels(uris)

        assert result == set(uris)


async def test_get_active_copyright_labels_no_auth_token() -> None:
    """test that missing auth token returns all URIs as active (fail closed)."""
    uris = ["at://did:plc:test/fm.plyr.track/1"]

    with patch("backend._internal.moderation.settings") as mock_settings:
        mock_settings.moderation.enabled = True
        mock_settings.moderation.auth_token = ""

        result = await get_active_copyright_labels(uris)

        assert result == set(uris)


async def test_get_active_copyright_labels_success() -> None:
    """test successful call to labeler returns active URIs."""
    uris = [
        "at://did:plc:success/fm.plyr.track/1",
        "at://did:plc:success/fm.plyr.track/2",
        "at://did:plc:success/fm.plyr.track/3",
    ]

    with patch("backend._internal.moderation.settings") as mock_settings:
        mock_settings.moderation.enabled = True
        mock_settings.moderation.auth_token = "test-token"

        with patch(
            "backend._internal.moderation.get_moderation_client"
        ) as mock_get_client:
            mock_client = AsyncMock()
            mock_client.get_active_labels.return_value = {uris[0]}  # only first active
            mock_get_client.return_value = mock_client

            result = await get_active_copyright_labels(uris)

            assert result == {uris[0]}


async def test_get_active_copyright_labels_service_error() -> None:
    """test that service errors return all URIs as active (fail closed)."""
    uris = [
        "at://did:plc:error/fm.plyr.track/1",
        "at://did:plc:error/fm.plyr.track/2",
    ]

    with patch("backend._internal.moderation.settings") as mock_settings:
        mock_settings.moderation.enabled = True
        mock_settings.moderation.auth_token = "test-token"

        with patch(
            "backend._internal.moderation.get_moderation_client"
        ) as mock_get_client:
            mock_client = AsyncMock()
            # client's get_active_labels fails closed internally; this test
            # only checks that its all-active result is passed through as-is
            # (no exception is raised here)
            mock_client.get_active_labels.return_value = set(uris)
            mock_get_client.return_value = mock_client

            result = await get_active_copyright_labels(uris)

            assert result == set(uris)


# tests for background task


async def test_sync_copyright_resolutions(db_session: AsyncSession) -> None:
    """test that sync_copyright_resolutions updates flagged scans."""
    from backend._internal.background_tasks import sync_copyright_resolutions

    # create test artist and tracks
    artist = Artist(
        did="did:plc:synctest",
        handle="synctest.bsky.social",
        display_name="Sync Test User",
    )
    db_session.add(artist)
    await db_session.commit()

    # track 1: flagged, will be resolved
    track1 = Track(
        title="Flagged Track 1",
        file_id="flagged_1",
        file_type="mp3",
        artist_did=artist.did,
        r2_url="https://example.com/flagged1.mp3",
        atproto_record_uri="at://did:plc:synctest/fm.plyr.track/1",
    )
    db_session.add(track1)

    # track 2: flagged, will stay flagged
    track2 = Track(
        title="Flagged Track 2",
        file_id="flagged_2",
        file_type="mp3",
        artist_did=artist.did,
        r2_url="https://example.com/flagged2.mp3",
        atproto_record_uri="at://did:plc:synctest/fm.plyr.track/2",
    )
    db_session.add(track2)
    await db_session.commit()

    # create flagged scans
    scan1 = CopyrightScan(
        track_id=track1.id,
        is_flagged=True,
        highest_score=85,
        matches=[{"artist": "Test", "title": "Song"}],
        raw_response={},
    )
    scan2 = CopyrightScan(
        track_id=track2.id,
        is_flagged=True,
        highest_score=90,
        matches=[{"artist": "Test", "title": "Song2"}],
        raw_response={},
    )
    db_session.add_all([scan1, scan2])
    await db_session.commit()

    # NOTE(review): this patches the attribute on moderation_client; it only
    # takes effect if background_tasks resolves get_moderation_client through
    # that module at call time (e.g. a function-local import) — confirm
    with patch(
        "backend._internal.moderation_client.get_moderation_client"
    ) as mock_get_client:
        mock_client = AsyncMock()
        # only track2's URI is still active
        mock_client.get_active_labels.return_value = {
            "at://did:plc:synctest/fm.plyr.track/2"
        }
        mock_get_client.return_value = mock_client

        await sync_copyright_resolutions()

    # refresh from db
    await db_session.refresh(scan1)
    await db_session.refresh(scan2)

    # scan1 should no longer be flagged (label was negated)
    assert scan1.is_flagged is False

    # scan2 should still be flagged
    assert scan2.is_flagged is True


# tests for sensitive images


async def test_moderation_client_get_sensitive_images() -> None:
    """test ModerationClient.get_sensitive_images() with successful response."""
    mock_response = Mock()
    mock_response.json.return_value = {
        "image_ids": ["abc123", "def456"],
        "urls": ["https://example.com/image.jpg"],
    }
    mock_response.raise_for_status.return_value = None

    client = _make_client()

    with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
        mock_get.return_value = mock_response

        result = await client.get_sensitive_images()

        assert result.image_ids == ["abc123", "def456"]
        assert result.urls == ["https://example.com/image.jpg"]
        mock_get.assert_called_once()


async def test_moderation_client_get_sensitive_images_empty() -> None:
    """test ModerationClient.get_sensitive_images() with empty response."""
    mock_response = Mock()
    mock_response.json.return_value = {"image_ids": [], "urls": []}
    mock_response.raise_for_status.return_value = None

    client = _make_client()

    with patch("httpx.AsyncClient.get", new_callable=AsyncMock) as mock_get:
        mock_get.return_value = mock_response

        result = await client.get_sensitive_images()

        assert result.image_ids == []
        assert result.urls == []


async def test_get_sensitive_images_endpoint(
    client: TestClient,
) -> None:
    """test GET /moderation/sensitive-images endpoint proxies to moderation service."""
    mock_result = SensitiveImagesResult(
        image_ids=["image1", "image2"],
        urls=["https://example.com/avatar.jpg"],
    )

    with patch("backend.api.moderation.get_moderation_client") as mock_get_client:
        mock_client = AsyncMock()
        mock_client.get_sensitive_images.return_value = mock_result
        mock_get_client.return_value = mock_client

        response = client.get("/moderation/sensitive-images")

        assert response.status_code == 200
        data = response.json()
        assert data["image_ids"] == ["image1", "image2"]
        assert data["urls"] == ["https://example.com/avatar.jpg"]