music on atproto
plyr.fm
1"""tests for aggregation utilities."""
2
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from backend.models import Artist, CopyrightScan, Track, TrackLike
from backend.utilities.aggregations import get_copyright_info, get_like_counts
8
9
@pytest.fixture
async def test_tracks(db_session: AsyncSession) -> list[Track]:
    """create three tracks by one artist with varying like counts.

    like distribution:
      - track 0: 2 likes
      - track 1: 1 like
      - track 2: 0 likes

    returns the tracks in creation order, committed and refreshed.
    """
    # the artist is flushed before the tracks that reference its did are built
    artist = Artist(
        did="did:plc:artist123",
        handle="artist.bsky.social",
        display_name="Test Artist",
    )
    db_session.add(artist)
    await db_session.flush()

    tracks = [
        Track(
            title=f"Track {i}",
            artist_did=artist.did,
            file_id=f"file_{i}",
            file_type="mp3",
            atproto_record_uri=f"at://did:plc:artist123/fm.plyr.track/{i}",
            atproto_record_cid=f"cid_{i}",
        )
        for i in range(3)
    ]
    db_session.add_all(tracks)
    await db_session.commit()

    # refresh to get IDs; the likes below need each track's id
    for track in tracks:
        await db_session.refresh(track)

    db_session.add_all(
        [
            # track 0: two likes from distinct users
            TrackLike(
                track_id=tracks[0].id,
                user_did="did:test:user1",
                atproto_like_uri="at://did:test:user1/fm.plyr.like/1",
            ),
            TrackLike(
                track_id=tracks[0].id,
                user_did="did:test:user2",
                atproto_like_uri="at://did:test:user2/fm.plyr.like/1",
            ),
            # track 1: one like
            TrackLike(
                track_id=tracks[1].id,
                user_did="did:test:user1",
                atproto_like_uri="at://did:test:user1/fm.plyr.like/2",
            ),
            # track 2 intentionally gets no likes
        ]
    )
    await db_session.commit()

    return tracks
70
71
async def test_get_like_counts_multiple_tracks(
    db_session: AsyncSession, test_tracks: list[Track]
):
    """like counts are reported per track when several ids are queried at once."""
    first, second, third = test_tracks
    counts = await get_like_counts(db_session, [t.id for t in test_tracks])

    assert counts[first.id] == 2
    assert counts[second.id] == 1
    # a track with zero likes is omitted from the mapping, not mapped to 0
    assert third.id not in counts
83
84
async def test_get_like_counts_empty_list(db_session: AsyncSession):
    """an empty list of track ids yields an empty mapping."""
    assert await get_like_counts(db_session, []) == {}
89
90
async def test_get_like_counts_no_likes(
    db_session: AsyncSession, test_tracks: list[Track]
):
    """querying only tracks that have no likes yields an empty mapping."""
    unliked = test_tracks[2]  # the fixture gives track 2 zero likes
    counts = await get_like_counts(db_session, [unliked.id])
    assert counts == {}
98
99
async def test_get_like_counts_single_track(
    db_session: AsyncSession, test_tracks: list[Track]
):
    """test getting like count for a single track.

    track 1 also has a like in the database but is not requested, so the
    result must contain exactly one entry: the requested track's count.
    """
    counts = await get_like_counts(db_session, [test_tracks[0].id])
    # exact equality also proves other tracks' likes do not leak into the result
    assert counts == {test_tracks[0].id: 2}
106
107
# tests for get_copyright_info
109
110
@pytest.fixture
async def flagged_track(db_session: AsyncSession) -> Track:
    """create a track whose copyright scan is stored as flagged.

    the scan has highest_score 90 and a single match: "Copyrighted Song"
    by "Famous Artist" with score 90.
    """
    owner = Artist(
        did="did:plc:flagged",
        handle="flagged.bsky.social",
        display_name="Flagged Artist",
    )
    db_session.add(owner)
    await db_session.flush()

    flagged = Track(
        title="Flagged Track",
        artist_did=owner.did,
        file_id="flagged_file",
        file_type="mp3",
        atproto_record_uri="at://did:plc:flagged/fm.plyr.track/abc123",
    )
    db_session.add(flagged)
    await db_session.commit()
    # reload so the scan below can reference the track's id
    await db_session.refresh(flagged)

    match = {"title": "Copyrighted Song", "artist": "Famous Artist", "score": 90}
    db_session.add(
        CopyrightScan(
            track_id=flagged.id,
            is_flagged=True,
            highest_score=90,
            matches=[match],
        )
    )
    await db_session.commit()

    return flagged
144
145
async def test_get_copyright_info_flagged(
    db_session: AsyncSession, flagged_track: Track
) -> None:
    """a scan persisted with is_flagged=True is reported as flagged.

    get_copyright_info is a pure read of the stored is_flagged column.
    keeping that column in step with labeler state is the responsibility of
    the sync_copyright_resolutions background task, not this function.
    """
    track_id = flagged_track.id
    result = await get_copyright_info(db_session, [track_id])

    assert track_id in result
    info = result[track_id]
    assert info.is_flagged is True
    assert info.primary_match == "Copyrighted Song by Famous Artist"
160
161
async def test_get_copyright_info_not_flagged(
    db_session: AsyncSession, flagged_track: Track
) -> None:
    """test that resolved scans (is_flagged=False) are returned as not flagged.

    the fixture's scan is flipped to is_flagged=False and committed, which
    simulates sync_copyright_resolutions having resolved the flag. the track
    must still appear in the result, with no primary match reported.
    """
    # update scan to be not flagged (simulates sync_copyright_resolutions running)
    scan = await db_session.scalar(
        select(CopyrightScan).where(CopyrightScan.track_id == flagged_track.id)
    )
    assert scan is not None
    scan.is_flagged = False
    await db_session.commit()

    result = await get_copyright_info(db_session, [flagged_track.id])

    assert flagged_track.id in result
    assert result[flagged_track.id].is_flagged is False
    assert result[flagged_track.id].primary_match is None
181
182
async def test_get_copyright_info_empty_list(db_session: AsyncSession) -> None:
    """an empty list of track ids yields an empty mapping."""
    assert await get_copyright_info(db_session, []) == {}
187
188
async def test_get_copyright_info_no_scan(
    db_session: AsyncSession, test_tracks: list[Track]
) -> None:
    """tracks that have no copyright scan are omitted from the result."""
    # the test_tracks fixture creates tracks and likes but no CopyrightScan rows
    result = await get_copyright_info(db_session, [t.id for t in test_tracks])
    assert result == {}