music on atproto
plyr.fm
1"""tests for copyright moderation integration."""
2
3from unittest.mock import AsyncMock, Mock, patch
4
5import httpx
6import pytest
7from fastapi.testclient import TestClient
8from sqlalchemy import select
9from sqlalchemy.ext.asyncio import AsyncSession
10
11from backend._internal.moderation import (
12 get_active_copyright_labels,
13 scan_track_for_copyright,
14)
15from backend._internal.moderation_client import (
16 ModerationClient,
17 ScanResult,
18 SensitiveImagesResult,
19)
20from backend.models import Artist, CopyrightScan, Track
21
22
@pytest.fixture
def mock_scan_result() -> ScanResult:
    """a flagged scan result carrying one 85-point match, as the client returns it."""
    top_match = {
        "artist": "Test Artist",
        "title": "Test Song",
        "score": 85,
        "isrc": "USRC12345678",
    }
    return ScanResult(
        raw_response={"status": "success", "result": []},
        matches=[top_match],
        highest_score=top_match["score"],
        is_flagged=True,
    )
39
40
@pytest.fixture
def mock_clear_result() -> ScanResult:
    """an unflagged scan result: the service reported no copyright matches."""
    service_payload = {"status": "success", "result": None}
    return ScanResult(
        raw_response=service_payload,
        matches=[],
        highest_score=0,
        is_flagged=False,
    )
50
51
async def test_moderation_client_scan_success() -> None:
    """test ModerationClient.scan() against a successful service response."""
    fake_response = Mock()
    fake_response.raise_for_status.return_value = None
    fake_response.json.return_value = {
        "is_flagged": True,
        "highest_score": 85,
        "matches": [{"artist": "Test", "title": "Song", "score": 85}],
        "raw_response": {"status": "success"},
    }

    client_config = {
        "service_url": "https://test.example.com",
        "labeler_url": "https://labeler.example.com",
        "auth_token": "test-token",
        "timeout_seconds": 30,
        "label_cache_prefix": "test:label:",
        "label_cache_ttl_seconds": 300,
    }
    moderation = ModerationClient(**client_config)

    # stub the outgoing http POST so no network call is made
    with patch.object(
        httpx.AsyncClient, "post", new_callable=AsyncMock, return_value=fake_response
    ) as post_spy:
        outcome = await moderation.scan("https://example.com/audio.mp3")

    assert outcome.is_flagged is True
    assert outcome.highest_score == 85
    assert len(outcome.matches) == 1
    post_spy.assert_called_once()
81
82
async def test_moderation_client_scan_timeout() -> None:
    """a timeout raised by the http layer propagates out of ModerationClient.scan()."""
    moderation = ModerationClient(
        auth_token="test-token",
        service_url="https://test.example.com",
        labeler_url="https://labeler.example.com",
        label_cache_prefix="test:label:",
        label_cache_ttl_seconds=300,
        timeout_seconds=30,
    )
    timeout_error = httpx.TimeoutException("timeout")

    with patch.object(
        httpx.AsyncClient, "post", new_callable=AsyncMock, side_effect=timeout_error
    ), pytest.raises(httpx.TimeoutException):
        await moderation.scan("https://example.com/audio.mp3")
99
100
async def test_scan_track_stores_flagged_result(
    db_session: AsyncSession,
    mock_scan_result: ScanResult,
) -> None:
    """a flagged result from the moderation client is persisted as a CopyrightScan row."""
    owner = Artist(
        did="did:plc:test123",
        handle="test.bsky.social",
        display_name="Test User",
    )
    db_session.add(owner)
    await db_session.commit()

    uploaded = Track(
        title="Test Track",
        file_id="test_file_123",
        file_type="mp3",
        artist_did=owner.did,
        r2_url="https://example.com/audio.mp3",
    )
    db_session.add(uploaded)
    await db_session.commit()

    # client double: every scan() call resolves to the canned flagged result
    fake_client = AsyncMock()
    fake_client.scan.return_value = mock_scan_result

    with patch("backend._internal.moderation.settings") as fake_settings:
        fake_settings.moderation.enabled = True
        fake_settings.moderation.auth_token = "test-token"
        with patch(
            "backend._internal.moderation.get_moderation_client",
            return_value=fake_client,
        ):
            assert uploaded.r2_url is not None
            await scan_track_for_copyright(uploaded.id, uploaded.r2_url)

    rows = await db_session.execute(
        select(CopyrightScan).where(CopyrightScan.track_id == uploaded.id)
    )
    stored = rows.scalar_one()

    assert stored.is_flagged is True
    assert stored.highest_score == 85
    assert len(stored.matches) == 1
    assert stored.matches[0]["artist"] == "Test Artist"
147
148
async def test_scan_track_stores_clear_result(
    db_session: AsyncSession,
    mock_clear_result: ScanResult,
) -> None:
    """an unflagged scan (no matches) is still persisted, with empty matches."""
    creator = Artist(
        did="did:plc:test456",
        handle="clear.bsky.social",
        display_name="Clear User",
    )
    db_session.add(creator)
    await db_session.commit()

    original_track = Track(
        title="Original Track",
        file_id="original_file_456",
        file_type="wav",
        artist_did=creator.did,
        r2_url="https://example.com/original.wav",
    )
    db_session.add(original_track)
    await db_session.commit()

    stub_client = AsyncMock()
    stub_client.scan.return_value = mock_clear_result

    with patch("backend._internal.moderation.settings") as stub_settings:
        stub_settings.moderation.enabled = True
        stub_settings.moderation.auth_token = "test-token"
        with patch(
            "backend._internal.moderation.get_moderation_client",
            return_value=stub_client,
        ):
            audio_url = original_track.r2_url
            assert audio_url is not None
            await scan_track_for_copyright(original_track.id, audio_url)

    query = select(CopyrightScan).where(CopyrightScan.track_id == original_track.id)
    stored = (await db_session.execute(query)).scalar_one()

    assert stored.is_flagged is False
    assert stored.highest_score == 0
    assert stored.matches == []
194
195
async def test_scan_track_disabled() -> None:
    """with moderation switched off, scanning is skipped before any client is fetched."""
    settings_target = "backend._internal.moderation.settings"
    factory_target = "backend._internal.moderation.get_moderation_client"

    with patch(settings_target) as settings_mock:
        settings_mock.moderation.enabled = False
        with patch(factory_target) as factory:
            await scan_track_for_copyright(1, "https://example.com/audio.mp3")

    # should not even get the client when disabled
    factory.assert_not_called()
208
209
async def test_scan_track_no_auth_token() -> None:
    """an empty auth token means scanning is skipped without ever fetching a client."""
    settings_target = "backend._internal.moderation.settings"
    factory_target = "backend._internal.moderation.get_moderation_client"

    with patch(settings_target) as settings_mock:
        settings_mock.moderation.enabled = True
        settings_mock.moderation.auth_token = ""
        with patch(factory_target) as factory:
            await scan_track_for_copyright(1, "https://example.com/audio.mp3")

    # should not even get the client without auth token
    factory.assert_not_called()
223
224
async def test_scan_track_service_error_stores_as_clear(
    db_session: AsyncSession,
) -> None:
    """test that service errors are stored as clear results.

    a 502 from the moderation service must not propagate out of
    scan_track_for_copyright; instead a non-flagged CopyrightScan row is
    written whose raw_response records the failure.
    """
    artist = Artist(
        did="did:plc:errortest",
        handle="errortest.bsky.social",
        display_name="Error Test User",
    )
    db_session.add(artist)
    await db_session.commit()

    audio_url = "https://example.com/short.mp3"
    track = Track(
        title="Error Test Track",
        file_id="error_test_file",
        file_type="mp3",
        artist_did=artist.did,
        r2_url=audio_url,
    )
    db_session.add(track)
    await db_session.commit()

    # real httpx objects rather than AsyncMock: Request/Response are plain
    # synchronous objects, and an AsyncMock stand-in would turn any method
    # the error handler calls on them (e.g. response.json()) into an
    # un-awaited coroutine and any attribute it reads into a mock object.
    request = httpx.Request("POST", "https://moderation.example.com/scan")
    response = httpx.Response(502, request=request)

    with patch("backend._internal.moderation.settings") as mock_settings:
        mock_settings.moderation.enabled = True
        mock_settings.moderation.auth_token = "test-token"

        with patch(
            "backend._internal.moderation.get_moderation_client"
        ) as mock_get_client:
            mock_client = AsyncMock()
            mock_client.scan.side_effect = httpx.HTTPStatusError(
                "502 error",
                request=request,
                response=response,
            )
            mock_get_client.return_value = mock_client

            # should not raise - stores error as clear
            await scan_track_for_copyright(track.id, audio_url)

    result = await db_session.execute(
        select(CopyrightScan).where(CopyrightScan.track_id == track.id)
    )
    scan = result.scalar_one()

    assert scan.is_flagged is False
    assert scan.highest_score == 0
    assert scan.matches == []
    assert "error" in scan.raw_response
    assert scan.raw_response["status"] == "scan_failed"
275
276
277# tests for get_active_copyright_labels
278
279
async def test_get_active_copyright_labels_empty_list() -> None:
    """no URIs in means an empty set of active labels out."""
    no_uris: list[str] = []
    active = await get_active_copyright_labels(no_uris)
    assert active == set()
284
285
async def test_get_active_copyright_labels_disabled() -> None:
    """with moderation disabled every URI is reported active (fail closed)."""
    uris = [f"at://did:plc:test/fm.plyr.track/{n}" for n in (1, 2)]
    expected = set(uris)

    with patch("backend._internal.moderation.settings") as settings_mock:
        settings_mock.moderation.enabled = False
        active = await get_active_copyright_labels(uris)

    assert active == expected
296
297
async def test_get_active_copyright_labels_no_auth_token() -> None:
    """without an auth token every URI is reported active (fail closed)."""
    track_uri = "at://did:plc:test/fm.plyr.track/1"
    uris = [track_uri]

    with patch("backend._internal.moderation.settings") as settings_mock:
        settings_mock.moderation.enabled = True
        settings_mock.moderation.auth_token = ""
        active = await get_active_copyright_labels(uris)

    assert active == {track_uri}
309
310
async def test_get_active_copyright_labels_success() -> None:
    """the labeler's answer is passed through: only still-labelled URIs come back."""
    uris = [f"at://did:plc:success/fm.plyr.track/{n}" for n in range(1, 4)]
    still_labelled = uris[0]

    fake_client = AsyncMock()
    fake_client.get_active_labels.return_value = {still_labelled}  # only first active

    with patch("backend._internal.moderation.settings") as settings_mock:
        settings_mock.moderation.enabled = True
        settings_mock.moderation.auth_token = "test-token"
        with patch(
            "backend._internal.moderation.get_moderation_client",
            return_value=fake_client,
        ):
            active = await get_active_copyright_labels(uris)

    assert active == {still_labelled}
333
334
async def test_get_active_copyright_labels_service_error() -> None:
    """test that service errors return all URIs as active (fail closed).

    the fail-closed fallback lives inside the client's get_active_labels, so
    no error is raised here: the double just returns every URI as active, and
    this checks that get_active_copyright_labels hands that set back unchanged.
    """
    uris = [f"at://did:plc:error/fm.plyr.track/{n}" for n in (1, 2)]

    fake_client = AsyncMock()
    # client's get_active_labels fails closed internally
    fake_client.get_active_labels.return_value = set(uris)

    with patch("backend._internal.moderation.settings") as settings_mock:
        settings_mock.moderation.enabled = True
        settings_mock.moderation.auth_token = "test-token"
        with patch(
            "backend._internal.moderation.get_moderation_client",
            return_value=fake_client,
        ):
            active = await get_active_copyright_labels(uris)

    assert active == set(uris)
357
358
359# tests for background task
360
361
async def test_sync_copyright_resolutions(db_session: AsyncSession) -> None:
    """sync_copyright_resolutions un-flags scans whose label is no longer active."""
    from backend._internal.background_tasks import sync_copyright_resolutions

    uri_prefix = "at://did:plc:synctest/fm.plyr.track/"

    artist = Artist(
        did="did:plc:synctest",
        handle="synctest.bsky.social",
        display_name="Sync Test User",
    )
    db_session.add(artist)
    await db_session.commit()

    # two flagged tracks: #1 will have its label negated, #2 stays labelled
    resolved_track, still_flagged_track = (
        Track(
            title=f"Flagged Track {n}",
            file_id=f"flagged_{n}",
            file_type="mp3",
            artist_did=artist.did,
            r2_url=f"https://example.com/flagged{n}.mp3",
            atproto_record_uri=f"{uri_prefix}{n}",
        )
        for n in (1, 2)
    )
    db_session.add_all([resolved_track, still_flagged_track])
    await db_session.commit()

    resolved_scan = CopyrightScan(
        track_id=resolved_track.id,
        is_flagged=True,
        highest_score=85,
        matches=[{"artist": "Test", "title": "Song"}],
        raw_response={},
    )
    flagged_scan = CopyrightScan(
        track_id=still_flagged_track.id,
        is_flagged=True,
        highest_score=90,
        matches=[{"artist": "Test", "title": "Song2"}],
        raw_response={},
    )
    db_session.add(resolved_scan)
    db_session.add(flagged_scan)
    await db_session.commit()

    fake_client = AsyncMock()
    # only track #2's URI is still active on the labeler
    fake_client.get_active_labels.return_value = {f"{uri_prefix}2"}

    # NOTE(review): patched on moderation_client itself, which assumes the
    # background task looks get_moderation_client up there at call time
    with patch(
        "backend._internal.moderation_client.get_moderation_client",
        return_value=fake_client,
    ):
        await sync_copyright_resolutions()

    # the task may write through its own session; reload from the db
    for scan in (resolved_scan, flagged_scan):
        await db_session.refresh(scan)

    assert resolved_scan.is_flagged is False  # label was negated
    assert flagged_scan.is_flagged is True
437
438
439# tests for sensitive images
440
441
async def test_moderation_client_get_sensitive_images() -> None:
    """ModerationClient.get_sensitive_images() exposes the service's ids and urls."""
    http_response = Mock()
    http_response.raise_for_status.return_value = None
    http_response.json.return_value = {
        "image_ids": ["abc123", "def456"],
        "urls": ["https://example.com/image.jpg"],
    }

    moderation = ModerationClient(
        auth_token="test-token",
        service_url="https://test.example.com",
        labeler_url="https://labeler.example.com",
        label_cache_prefix="test:label:",
        label_cache_ttl_seconds=300,
        timeout_seconds=30,
    )

    # stub the outgoing http GET so no network call is made
    with patch.object(
        httpx.AsyncClient, "get", new_callable=AsyncMock, return_value=http_response
    ) as get_spy:
        images = await moderation.get_sensitive_images()

    assert images.image_ids == ["abc123", "def456"]
    assert images.urls == ["https://example.com/image.jpg"]
    get_spy.assert_called_once()
468
469
async def test_moderation_client_get_sensitive_images_empty() -> None:
    """an empty service payload yields empty id and url lists."""
    http_response = Mock()
    http_response.raise_for_status.return_value = None
    http_response.json.return_value = {"image_ids": [], "urls": []}

    moderation = ModerationClient(
        auth_token="test-token",
        service_url="https://test.example.com",
        labeler_url="https://labeler.example.com",
        label_cache_prefix="test:label:",
        label_cache_ttl_seconds=300,
        timeout_seconds=30,
    )

    with patch.object(
        httpx.AsyncClient, "get", new_callable=AsyncMock, return_value=http_response
    ):
        images = await moderation.get_sensitive_images()

    assert images.image_ids == []
    assert images.urls == []
492
493
async def test_get_sensitive_images_endpoint(
    client: TestClient,
) -> None:
    """GET /moderation/sensitive-images returns what the moderation client reports."""
    canned = SensitiveImagesResult(
        image_ids=["image1", "image2"],
        urls=["https://example.com/avatar.jpg"],
    )
    fake_moderation = AsyncMock()
    fake_moderation.get_sensitive_images.return_value = canned

    with patch(
        "backend.api.moderation.get_moderation_client", return_value=fake_moderation
    ):
        response = client.get("/moderation/sensitive-images")

    assert response.status_code == 200
    body = response.json()
    assert body["image_ids"] == ["image1", "image2"]
    assert body["urls"] == ["https://example.com/avatar.jpg"]
512 assert data["image_ids"] == ["image1", "image2"]
513 assert data["urls"] == ["https://example.com/avatar.jpg"]