at main 6.1 kB view raw
"""tests for aggregation utilities (get_like_counts, get_copyright_info)."""

import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession

from backend.models import Artist, CopyrightScan, Track, TrackLike
from backend.utilities.aggregations import get_copyright_info, get_like_counts


@pytest.fixture
async def test_tracks(db_session: AsyncSession) -> list[Track]:
    """create one artist with three tracks and a fixed set of likes.

    like layout, by index in the returned list:
      - track 0: 2 likes (did:test:user1, did:test:user2)
      - track 1: 1 like (did:test:user1)
      - track 2: 0 likes

    returns the three tracks, refreshed so their ``id`` values are loaded.
    no copyright scans are created for these tracks.
    """
    artist = Artist(
        did="did:plc:artist123",
        handle="artist.bsky.social",
        display_name="Test Artist",
    )
    db_session.add(artist)
    # flush the artist before adding tracks that reference its did
    await db_session.flush()

    tracks: list[Track] = [
        Track(
            title=f"Track {i}",
            artist_did=artist.did,
            file_id=f"file_{i}",
            file_type="mp3",
            atproto_record_uri=f"at://did:plc:artist123/fm.plyr.track/{i}",
            atproto_record_cid=f"cid_{i}",
        )
        for i in range(3)
    ]
    db_session.add_all(tracks)
    await db_session.commit()

    # refresh so track.id is loaded: commit may expire attributes, and an
    # async session cannot lazy-load them implicitly on attribute access
    for track in tracks:
        await db_session.refresh(track)

    db_session.add_all(
        [
            TrackLike(
                track_id=tracks[0].id,
                user_did="did:test:user1",
                atproto_like_uri="at://did:test:user1/fm.plyr.like/1",
            ),
            TrackLike(
                track_id=tracks[0].id,
                user_did="did:test:user2",
                atproto_like_uri="at://did:test:user2/fm.plyr.like/1",
            ),
            TrackLike(
                track_id=tracks[1].id,
                user_did="did:test:user1",
                atproto_like_uri="at://did:test:user1/fm.plyr.like/2",
            ),
        ]
    )
    await db_session.commit()

    return tracks


async def test_get_like_counts_multiple_tracks(
    db_session: AsyncSession, test_tracks: list[Track]
) -> None:
    """test getting like counts for multiple tracks."""
    track_ids = [track.id for track in test_tracks]
    counts = await get_like_counts(db_session, track_ids)

    assert counts[test_tracks[0].id] == 2
    assert counts[test_tracks[1].id] == 1
    # track 2 has no likes, so it won't be in the dict
    assert test_tracks[2].id not in counts


async def test_get_like_counts_empty_list(db_session: AsyncSession) -> None:
    """test that empty track list returns empty dict."""
    counts = await get_like_counts(db_session, [])
    assert counts == {}


async def test_get_like_counts_no_likes(
    db_session: AsyncSession, test_tracks: list[Track]
) -> None:
    """test tracks with no likes return empty dict."""
    # only query track 2 which has no likes
    counts = await get_like_counts(db_session, [test_tracks[2].id])
    assert counts == {}


async def test_get_like_counts_single_track(
    db_session: AsyncSession, test_tracks: list[Track]
) -> None:
    """test getting like count for a single track."""
    counts = await get_like_counts(db_session, [test_tracks[0].id])
    assert counts[test_tracks[0].id] == 2


# tests for get_copyright_info


@pytest.fixture
async def flagged_track(db_session: AsyncSession) -> Track:
    """create a track with a flagged copyright scan.

    the scan has is_flagged=True, highest_score=90 and a single match
    ("Copyrighted Song" by "Famous Artist", score 90).
    """
    artist = Artist(
        did="did:plc:flagged",
        handle="flagged.bsky.social",
        display_name="Flagged Artist",
    )
    db_session.add(artist)
    await db_session.flush()

    track = Track(
        title="Flagged Track",
        artist_did=artist.did,
        file_id="flagged_file",
        file_type="mp3",
        atproto_record_uri="at://did:plc:flagged/fm.plyr.track/abc123",
    )
    db_session.add(track)
    await db_session.commit()
    # load track.id before it is used as the scan's foreign key
    await db_session.refresh(track)

    # add copyright scan with flag
    scan = CopyrightScan(
        track_id=track.id,
        is_flagged=True,
        highest_score=90,
        matches=[{"title": "Copyrighted Song", "artist": "Famous Artist", "score": 90}],
    )
    db_session.add(scan)
    await db_session.commit()

    return track


async def test_get_copyright_info_flagged(
    db_session: AsyncSession, flagged_track: Track
) -> None:
    """test that flagged scans are returned as flagged.

    get_copyright_info is now a pure read - it reads the is_flagged state
    directly from the database. the sync_copyright_resolutions background
    task is responsible for updating is_flagged based on labeler state.
    """
    result = await get_copyright_info(db_session, [flagged_track.id])

    assert flagged_track.id in result
    assert result[flagged_track.id].is_flagged is True
    assert result[flagged_track.id].primary_match == "Copyrighted Song by Famous Artist"


async def test_get_copyright_info_not_flagged(
    db_session: AsyncSession, flagged_track: Track
) -> None:
    """test that resolved scans (is_flagged=False) are returned as not flagged."""
    # update scan to be not flagged (simulates sync_copyright_resolutions running)
    scan = await db_session.scalar(
        select(CopyrightScan).where(CopyrightScan.track_id == flagged_track.id)
    )
    assert scan is not None
    scan.is_flagged = False
    await db_session.commit()

    result = await get_copyright_info(db_session, [flagged_track.id])

    assert flagged_track.id in result
    assert result[flagged_track.id].is_flagged is False
    assert result[flagged_track.id].primary_match is None


async def test_get_copyright_info_empty_list(db_session: AsyncSession) -> None:
    """test that empty track list returns empty dict."""
    result = await get_copyright_info(db_session, [])
    assert result == {}


async def test_get_copyright_info_no_scan(
    db_session: AsyncSession, test_tracks: list[Track]
) -> None:
    """test that tracks without copyright scans are not included."""
    # test_tracks fixture doesn't create copyright scans
    track_ids = [track.id for track in test_tracks]

    result = await get_copyright_info(db_session, track_ids)

    # no tracks should be in result since none have scans
    assert result == {}