at main 5.1 kB view raw
1//! ATProto XRPC endpoints for the labeler protocol. 2 3use std::sync::Arc; 4 5use axum::{ 6 extract::{ 7 ws::{Message, WebSocket, WebSocketUpgrade}, 8 Query, State, 9 }, 10 response::Response, 11 Json, 12}; 13use futures::StreamExt; 14use serde::{Deserialize, Serialize}; 15use tokio::sync::broadcast; 16use tokio_stream::wrappers::BroadcastStream; 17use tracing::error; 18 19use crate::db::LabelDb; 20use crate::labels::Label; 21use crate::state::{AppError, AppState}; 22 23// --- types --- 24 25#[derive(Debug, Deserialize)] 26#[serde(rename_all = "camelCase")] 27pub struct QueryLabelsParams { 28 pub uri_patterns: String, // comma-separated 29 pub sources: Option<String>, 30 pub cursor: Option<String>, 31 pub limit: Option<i64>, 32} 33 34#[derive(Debug, Serialize)] 35pub struct QueryLabelsResponse { 36 pub cursor: Option<String>, 37 pub labels: Vec<Label>, 38} 39 40#[derive(Debug, Deserialize)] 41pub struct SubscribeLabelsParams { 42 pub cursor: Option<i64>, 43} 44 45#[derive(Serialize)] 46struct SubscribeLabelsMessage { 47 seq: i64, 48 labels: Vec<Label>, 49} 50 51// --- handlers --- 52 53/// Query labels by URI pattern. 54pub async fn query_labels( 55 State(state): State<AppState>, 56 Query(params): Query<QueryLabelsParams>, 57) -> Result<Json<QueryLabelsResponse>, AppError> { 58 let db = state.db.as_ref().ok_or(AppError::LabelerNotConfigured)?; 59 60 let uri_patterns: Vec<String> = params 61 .uri_patterns 62 .split(',') 63 .map(|s| s.trim().to_string()) 64 .collect(); 65 let sources: Option<Vec<String>> = params 66 .sources 67 .map(|s| s.split(',').map(|s| s.trim().to_string()).collect()); 68 let limit = params.limit.unwrap_or(50).clamp(1, 250); 69 70 let (rows, cursor) = db 71 .query_labels( 72 &uri_patterns, 73 sources.as_deref(), 74 params.cursor.as_deref(), 75 limit, 76 ) 77 .await?; 78 79 let labels: Vec<Label> = rows.iter().map(|r| r.to_label()).collect(); 80 81 Ok(Json(QueryLabelsResponse { cursor, labels })) 82} 83 84/// WebSocket subscription for real-time label updates. 85pub async fn subscribe_labels( 86 State(state): State<AppState>, 87 Query(params): Query<SubscribeLabelsParams>, 88 ws: WebSocketUpgrade, 89) -> Result<Response, AppError> { 90 let db = state.db.clone().ok_or(AppError::LabelerNotConfigured)?; 91 let label_tx = state 92 .label_tx 93 .clone() 94 .ok_or(AppError::LabelerNotConfigured)?; 95 96 Ok(ws.on_upgrade(move |socket| handle_subscribe(socket, db, label_tx, params.cursor))) 97} 98 99async fn handle_subscribe( 100 mut socket: WebSocket, 101 db: Arc<LabelDb>, 102 label_tx: broadcast::Sender<(i64, Label)>, 103 cursor: Option<i64>, 104) { 105 // If cursor provided, backfill from that point 106 let start_seq = if let Some(c) = cursor { 107 // Send historical labels first 108 match db.get_labels_since(c, 1000).await { 109 Ok(rows) => { 110 for row in &rows { 111 let msg = SubscribeLabelsMessage { 112 seq: row.seq, 113 labels: vec![row.to_label()], 114 }; 115 if let Ok(json) = serde_json::to_string(&msg) { 116 if socket.send(Message::Text(json)).await.is_err() { 117 return; 118 } 119 } 120 } 121 rows.last().map(|r| r.seq).unwrap_or(c) 122 } 123 Err(e) => { 124 error!(error = %e, "failed to backfill labels"); 125 return; 126 } 127 } 128 } else { 129 // Start from current position 130 db.get_latest_seq().await.unwrap_or(0) 131 }; 132 133 // Subscribe to live updates 134 let rx = label_tx.subscribe(); 135 let mut stream = BroadcastStream::new(rx); 136 137 let mut last_seq = start_seq; 138 139 loop { 140 tokio::select! { 141 // Receive from broadcast 142 Some(result) = stream.next() => { 143 match result { 144 Ok((seq, label)) => { 145 if seq > last_seq { 146 let msg = SubscribeLabelsMessage { 147 seq, 148 labels: vec![label], 149 }; 150 if let Ok(json) = serde_json::to_string(&msg) { 151 if socket.send(Message::Text(json)).await.is_err() { 152 break; 153 } 154 } 155 last_seq = seq; 156 } 157 } 158 Err(_) => continue, // Lagged, skip 159 } 160 } 161 // Check for client disconnect 162 msg = socket.recv() => { 163 match msg { 164 Some(Ok(Message::Close(_))) | None => break, 165 Some(Ok(Message::Ping(data))) => { 166 if socket.send(Message::Pong(data)).await.is_err() { 167 break; 168 } 169 } 170 _ => {} 171 } 172 } 173 } 174 } 175}