music on atproto
plyr.fm
1"""tests for platform stats api endpoint."""
2
3import pytest
4from fastapi.testclient import TestClient
5from sqlalchemy.ext.asyncio import AsyncSession
6
7from backend.models import Artist, Track
8
9
@pytest.fixture
async def artist(db_session: AsyncSession) -> Artist:
    """Persist and return one artist that the stats tests attach tracks to."""
    stats_artist = Artist(
        did="did:plc:stats_test_artist",
        handle="stats-test.bsky.social",
        display_name="Stats Test Artist",
    )
    db_session.add(stats_artist)
    # commit so the artist exists before dependent fixtures add tracks by its DID
    await db_session.commit()
    return stats_artist
21
22
@pytest.fixture
async def tracks_with_duration(db_session: AsyncSession, artist: Artist) -> list[Track]:
    """Persist three mp3 tracks by ``artist`` that all carry duration metadata.

    Seeded totals that tests in this module assert against:
    duration 180 + 3600 + 300 = 4080 seconds, plays 10 + 5 + 20 = 35.

    Returns:
        The committed tracks, in the order listed below.
    """
    # (title, file_id, duration in seconds, play_count)
    specs = [
        ("Short Track", "track1", 180, 10),  # 3 minutes
        ("Long Track", "track2", 3600, 5),  # 1 hour
        ("Medium Track", "track3", 300, 20),  # 5 minutes
    ]
    tracks = [
        Track(
            title=title,
            artist_did=artist.did,
            file_id=file_id,
            file_type="mp3",
            extra={"duration": duration},
            play_count=play_count,
        )
        for title, file_id, duration, play_count in specs
    ]
    # stage every track in one call instead of a manual add() loop
    db_session.add_all(tracks)
    await db_session.commit()
    return tracks
56
57
@pytest.fixture
async def track_without_duration(db_session: AsyncSession, artist: Artist) -> Track:
    """Persist one legacy-upload track whose ``extra`` has no ``duration`` key."""
    legacy = Track(
        title="No Duration Track",
        artist_did=artist.did,
        file_id="track_noduration",
        file_type="mp3",
        extra={},  # deliberately empty: no duration recorded
        play_count=3,
    )
    db_session.add(legacy)
    await db_session.commit()
    return legacy
72
73
async def test_get_stats_returns_total_duration(
    client: TestClient,
    tracks_with_duration: list[Track],
) -> None:
    """The /stats payload exposes the summed track duration in seconds."""
    # durations seeded by the tracks_with_duration fixture
    expected_seconds = 180 + 3600 + 300

    resp = client.get("/stats")
    assert resp.status_code == 200

    stats = resp.json()
    assert "total_duration_seconds" in stats
    assert stats["total_duration_seconds"] == expected_seconds
87
88
async def test_get_stats_duration_ignores_null(
    client: TestClient,
    tracks_with_duration: list[Track],
    track_without_duration: Track,
) -> None:
    """A track lacking duration metadata is counted but adds nothing to total duration."""
    resp = client.get("/stats")
    assert resp.status_code == 200

    stats = resp.json()
    # total duration is still that of the three timed tracks (180 + 3600 + 300)...
    assert stats["total_duration_seconds"] == 4080
    # ...while the track count includes all four rows, the legacy one too
    assert stats["total_tracks"] == 4
103
104
async def test_get_stats_empty_database(client: TestClient) -> None:
    """With nothing seeded, every aggregate reported by /stats is zero."""
    resp = client.get("/stats")
    assert resp.status_code == 200

    stats = resp.json()
    zeroed_fields = (
        "total_plays",
        "total_tracks",
        "total_artists",
        "total_duration_seconds",
    )
    for field in zeroed_fields:
        assert stats[field] == 0, field
115
116
async def test_get_stats_aggregates_play_counts(
    client: TestClient,
    tracks_with_duration: list[Track],
) -> None:
    """total_plays is the sum of play_count over every seeded track."""
    # play counts seeded by the tracks_with_duration fixture
    expected_plays = 10 + 5 + 20

    resp = client.get("/stats")
    assert resp.status_code == 200
    assert resp.json()["total_plays"] == expected_plays
128
129
async def test_get_stats_counts_distinct_artists(
    client: TestClient,
    db_session: AsyncSession,
) -> None:
    """Each artist is counted once, however many tracks they own."""
    artist1 = Artist(
        did="did:plc:artist1",
        handle="artist1.bsky.social",
        display_name="Artist 1",
    )
    artist2 = Artist(
        did="did:plc:artist2",
        handle="artist2.bsky.social",
        display_name="Artist 2",
    )
    db_session.add_all([artist1, artist2])
    # flush so the artist rows are emitted before the tracks that reference their DIDs
    await db_session.flush()

    # artist1 owns two tracks, artist2 owns one; every track lasts 100 seconds
    track_specs = [
        ("Track A1", artist1, "a1"),
        ("Track A2", artist1, "a2"),
        ("Track B1", artist2, "b1"),
    ]
    # stage all tracks in one call, as done for the artists above
    db_session.add_all(
        [
            Track(
                title=title,
                artist_did=owner.did,
                file_id=file_id,
                file_type="mp3",
                extra={"duration": 100},
            )
            for title, owner, file_id in track_specs
        ]
    )
    await db_session.commit()

    resp = client.get("/stats")
    assert resp.status_code == 200

    stats = resp.json()
    assert stats["total_tracks"] == 3
    # three tracks, but only two distinct artists
    assert stats["total_artists"] == 2
    assert stats["total_duration_seconds"] == 300
183 assert data["total_duration_seconds"] == 300