music on atproto
plyr.fm
1//! ATProto XRPC endpoints for the labeler protocol.
2
3use std::sync::Arc;
4
5use axum::{
6 extract::{
7 ws::{Message, WebSocket, WebSocketUpgrade},
8 Query, State,
9 },
10 response::Response,
11 Json,
12};
13use futures::StreamExt;
14use serde::{Deserialize, Serialize};
15use tokio::sync::broadcast;
16use tokio_stream::wrappers::BroadcastStream;
17use tracing::error;
18
19use crate::db::LabelDb;
20use crate::labels::Label;
21use crate::state::{AppError, AppState};
22
23// --- types ---
24
25#[derive(Debug, Deserialize)]
26#[serde(rename_all = "camelCase")]
27pub struct QueryLabelsParams {
28 pub uri_patterns: String, // comma-separated
29 pub sources: Option<String>,
30 pub cursor: Option<String>,
31 pub limit: Option<i64>,
32}
33
34#[derive(Debug, Serialize)]
35pub struct QueryLabelsResponse {
36 pub cursor: Option<String>,
37 pub labels: Vec<Label>,
38}
39
40#[derive(Debug, Deserialize)]
41pub struct SubscribeLabelsParams {
42 pub cursor: Option<i64>,
43}
44
45#[derive(Serialize)]
46struct SubscribeLabelsMessage {
47 seq: i64,
48 labels: Vec<Label>,
49}
50
51// --- handlers ---
52
53/// Query labels by URI pattern.
54pub async fn query_labels(
55 State(state): State<AppState>,
56 Query(params): Query<QueryLabelsParams>,
57) -> Result<Json<QueryLabelsResponse>, AppError> {
58 let db = state.db.as_ref().ok_or(AppError::LabelerNotConfigured)?;
59
60 let uri_patterns: Vec<String> = params
61 .uri_patterns
62 .split(',')
63 .map(|s| s.trim().to_string())
64 .collect();
65 let sources: Option<Vec<String>> = params
66 .sources
67 .map(|s| s.split(',').map(|s| s.trim().to_string()).collect());
68 let limit = params.limit.unwrap_or(50).clamp(1, 250);
69
70 let (rows, cursor) = db
71 .query_labels(
72 &uri_patterns,
73 sources.as_deref(),
74 params.cursor.as_deref(),
75 limit,
76 )
77 .await?;
78
79 let labels: Vec<Label> = rows.iter().map(|r| r.to_label()).collect();
80
81 Ok(Json(QueryLabelsResponse { cursor, labels }))
82}
83
84/// WebSocket subscription for real-time label updates.
85pub async fn subscribe_labels(
86 State(state): State<AppState>,
87 Query(params): Query<SubscribeLabelsParams>,
88 ws: WebSocketUpgrade,
89) -> Result<Response, AppError> {
90 let db = state.db.clone().ok_or(AppError::LabelerNotConfigured)?;
91 let label_tx = state
92 .label_tx
93 .clone()
94 .ok_or(AppError::LabelerNotConfigured)?;
95
96 Ok(ws.on_upgrade(move |socket| handle_subscribe(socket, db, label_tx, params.cursor)))
97}
98
99async fn handle_subscribe(
100 mut socket: WebSocket,
101 db: Arc<LabelDb>,
102 label_tx: broadcast::Sender<(i64, Label)>,
103 cursor: Option<i64>,
104) {
105 // If cursor provided, backfill from that point
106 let start_seq = if let Some(c) = cursor {
107 // Send historical labels first
108 match db.get_labels_since(c, 1000).await {
109 Ok(rows) => {
110 for row in &rows {
111 let msg = SubscribeLabelsMessage {
112 seq: row.seq,
113 labels: vec![row.to_label()],
114 };
115 if let Ok(json) = serde_json::to_string(&msg) {
116 if socket.send(Message::Text(json)).await.is_err() {
117 return;
118 }
119 }
120 }
121 rows.last().map(|r| r.seq).unwrap_or(c)
122 }
123 Err(e) => {
124 error!(error = %e, "failed to backfill labels");
125 return;
126 }
127 }
128 } else {
129 // Start from current position
130 db.get_latest_seq().await.unwrap_or(0)
131 };
132
133 // Subscribe to live updates
134 let rx = label_tx.subscribe();
135 let mut stream = BroadcastStream::new(rx);
136
137 let mut last_seq = start_seq;
138
139 loop {
140 tokio::select! {
141 // Receive from broadcast
142 Some(result) = stream.next() => {
143 match result {
144 Ok((seq, label)) => {
145 if seq > last_seq {
146 let msg = SubscribeLabelsMessage {
147 seq,
148 labels: vec![label],
149 };
150 if let Ok(json) = serde_json::to_string(&msg) {
151 if socket.send(Message::Text(json)).await.is_err() {
152 break;
153 }
154 }
155 last_seq = seq;
156 }
157 }
158 Err(_) => continue, // Lagged, skip
159 }
160 }
161 // Check for client disconnect
162 msg = socket.recv() => {
163 match msg {
164 Some(Ok(Message::Close(_))) | None => break,
165 Some(Ok(Message::Ping(data))) => {
166 if socket.send(Message::Pong(data)).await.is_err() {
167 break;
168 }
169 }
170 _ => {}
171 }
172 }
173 }
174 }
175}