at main 11 kB view raw
1//! HTTP request handlers for core endpoints. 2 3use axum::{ 4 extract::{Multipart, State}, 5 response::Html, 6 Json, 7}; 8use serde::{Deserialize, Serialize}; 9use tracing::info; 10 11use crate::db::{CopyrightMatch, LabelContext}; 12use crate::labels::Label; 13use crate::state::{AppError, AppState}; 14 15// --- types --- 16 17#[derive(Debug, Serialize)] 18pub struct HealthResponse { 19 pub status: &'static str, 20 pub labeler_enabled: bool, 21} 22 23/// Context info for display in admin UI. 24#[derive(Debug, Deserialize)] 25pub struct EmitLabelContext { 26 pub track_id: Option<i64>, 27 pub track_title: Option<String>, 28 pub artist_handle: Option<String>, 29 pub artist_did: Option<String>, 30 pub highest_score: Option<f64>, 31 pub matches: Option<Vec<CopyrightMatch>>, 32} 33 34#[derive(Debug, Deserialize)] 35pub struct EmitLabelRequest { 36 /// AT URI of the resource to label (e.g., at://did:plc:xxx/fm.plyr.track/abc123) 37 pub uri: String, 38 /// Label value (e.g., "copyright-violation") 39 #[serde(default = "default_label_val")] 40 pub val: String, 41 /// Optional CID of specific version 42 pub cid: Option<String>, 43 /// If true, negate an existing label 44 #[serde(default)] 45 pub neg: bool, 46 /// Optional context for admin UI display 47 pub context: Option<EmitLabelContext>, 48} 49 50fn default_label_val() -> String { 51 "copyright-violation".to_string() 52} 53 54/// Normalize a score from integer (0-100) to float (0.0-1.0) range. 55/// AuDD returns scores as integers like 85 meaning 85%. 56fn normalize_score(score: f64) -> f64 { 57 if score > 1.0 { 58 score / 100.0 59 } else { 60 score 61 } 62} 63 64#[derive(Debug, Serialize)] 65pub struct EmitLabelResponse { 66 pub seq: i64, 67 pub label: Label, 68} 69 70/// Response for sensitive images endpoint. 71#[derive(Debug, Serialize)] 72pub struct SensitiveImagesResponse { 73 /// R2 image IDs (for track/album artwork) 74 pub image_ids: Vec<String>, 75 /// Full URLs (for external images like avatars) 76 pub urls: Vec<String>, 77} 78 79// --- handlers --- 80 81/// Health check endpoint. 82pub async fn health(State(state): State<AppState>) -> Json<HealthResponse> { 83 Json(HealthResponse { 84 status: "ok", 85 labeler_enabled: state.db.is_some(), 86 }) 87} 88 89/// Landing page with service info. 90pub async fn landing(State(state): State<AppState>) -> Html<String> { 91 let labeler_did = state 92 .signer 93 .as_ref() 94 .map(|s| s.did().to_string()) 95 .unwrap_or_else(|| "not configured".to_string()); 96 97 Html(format!( 98 r#"<!DOCTYPE html> 99<html> 100<head> 101 <meta charset="utf-8"> 102 <meta name="viewport" content="width=device-width, initial-scale=1"> 103 <title>plyr.fm moderation</title> 104 <style> 105 body {{ 106 font-family: system-ui, -apple-system, sans-serif; 107 background: #0a0a0a; 108 color: #e5e5e5; 109 max-width: 600px; 110 margin: 80px auto; 111 padding: 20px; 112 line-height: 1.6; 113 }} 114 h1 {{ color: #fff; margin-bottom: 8px; }} 115 .subtitle {{ color: #888; margin-bottom: 32px; }} 116 a {{ color: #3b82f6; }} 117 code {{ 118 background: #1a1a1a; 119 padding: 2px 6px; 120 border-radius: 4px; 121 font-size: 0.9em; 122 }} 123 .endpoint {{ 124 background: #111; 125 border: 1px solid #222; 126 border-radius: 8px; 127 padding: 16px; 128 margin: 12px 0; 129 }} 130 .endpoint-name {{ color: #10b981; font-family: monospace; }} 131 </style> 132</head> 133<body> 134 <h1>plyr.fm moderation</h1> 135 <p class="subtitle">ATProto labeler for audio content moderation</p> 136 137 <p>This service provides content labels for <a href="https://plyr.fm">plyr.fm</a>, 138 the music streaming platform on ATProto.</p> 139 140 <p><strong>Labeler DID:</strong> <code>{}</code></p> 141 142 <h2>Endpoints</h2> 143 144 <div class="endpoint"> 145 <div class="endpoint-name">GET /xrpc/com.atproto.label.queryLabels</div> 146 <p>Query labels by URI pattern</p> 147 </div> 148 149 <div class="endpoint"> 150 <div class="endpoint-name">GET /xrpc/com.atproto.label.subscribeLabels</div> 151 <p>WebSocket subscription for real-time label updates</p> 152 </div> 153 154 <p style="margin-top: 32px; color: #666;"> 155 <a href="https://bsky.app/profile/moderation.plyr.fm">@moderation.plyr.fm</a> 156 </p> 157</body> 158</html>"#, 159 labeler_did 160 )) 161} 162 163/// Emit a new label (internal API). 164pub async fn emit_label( 165 State(state): State<AppState>, 166 Json(request): Json<EmitLabelRequest>, 167) -> Result<Json<EmitLabelResponse>, AppError> { 168 let db = state.db.as_ref().ok_or(AppError::LabelerNotConfigured)?; 169 let signer = state 170 .signer 171 .as_ref() 172 .ok_or(AppError::LabelerNotConfigured)?; 173 174 info!(uri = %request.uri, val = %request.val, neg = request.neg, "emitting label"); 175 176 // Create and sign the label 177 let mut label = Label::new(signer.did(), &request.uri, &request.val); 178 if let Some(cid) = request.cid { 179 label = label.with_cid(cid); 180 } 181 if request.neg { 182 label = label.negated(); 183 } 184 let label = signer.sign_label(label)?; 185 186 // Store in database 187 let seq = db.store_label(&label).await?; 188 info!(seq, uri = %request.uri, "label stored"); 189 190 // Store context if provided (for admin UI) 191 if let Some(ctx) = request.context { 192 let label_ctx = LabelContext { 193 track_id: ctx.track_id, 194 track_title: ctx.track_title, 195 artist_handle: ctx.artist_handle, 196 artist_did: ctx.artist_did, 197 highest_score: ctx.highest_score.map(normalize_score), 198 matches: ctx.matches.map(|matches| { 199 matches 200 .into_iter() 201 .map(|mut m| { 202 m.score = normalize_score(m.score); 203 m 204 }) 205 .collect() 206 }), 207 resolution_reason: None, 208 resolution_notes: None, 209 }; 210 if let Err(e) = db.store_context(&request.uri, &label_ctx).await { 211 // Log but don't fail - context is supplementary 212 tracing::warn!(uri = %request.uri, error = %e, "failed to store label context"); 213 } 214 } 215 216 // Broadcast to subscribers 217 if let Some(tx) = &state.label_tx { 218 let _ = tx.send((seq, label.clone())); 219 } 220 221 Ok(Json(EmitLabelResponse { seq, label })) 222} 223 224/// Get all sensitive images (public endpoint). 225/// 226/// Returns image_ids (R2 storage IDs) and urls (full URLs) for all flagged images. 227/// Clients should check both lists when determining if an image is sensitive. 228pub async fn get_sensitive_images( 229 State(state): State<AppState>, 230) -> Result<Json<SensitiveImagesResponse>, AppError> { 231 let db = state.db.as_ref().ok_or(AppError::LabelerNotConfigured)?; 232 233 let images = db.get_sensitive_images().await?; 234 235 let image_ids: Vec<String> = images.iter().filter_map(|i| i.image_id.clone()).collect(); 236 let urls: Vec<String> = images.iter().filter_map(|i| i.url.clone()).collect(); 237 238 Ok(Json(SensitiveImagesResponse { image_ids, urls })) 239} 240 241// --- image moderation --- 242 243/// Response from image scanning endpoint. 244#[derive(Debug, Serialize)] 245pub struct ScanImageResponse { 246 pub is_safe: bool, 247 pub reason: Option<String>, 248 pub severity: String, 249 pub violated_categories: Vec<String>, 250} 251 252/// Scan an image for policy violations using Claude vision. 253/// 254/// Accepts multipart form with: 255/// - `image`: the image file to scan 256/// - `image_id`: identifier for tracking (e.g., R2 file ID) 257/// 258/// Returns moderation result. If image is not safe, it's automatically 259/// added to the sensitive_images table. 260pub async fn scan_image( 261 State(state): State<AppState>, 262 mut multipart: Multipart, 263) -> Result<Json<ScanImageResponse>, AppError> { 264 let claude = state 265 .claude 266 .as_ref() 267 .ok_or(AppError::ImageModerationNotConfigured)?; 268 let db = state 269 .db 270 .as_ref() 271 .ok_or(AppError::ImageModerationNotConfigured)?; 272 273 let mut image_bytes: Option<Vec<u8>> = None; 274 let mut image_id: Option<String> = None; 275 let mut media_type = "image/png".to_string(); 276 277 // Parse multipart form 278 while let Some(field) = multipart 279 .next_field() 280 .await 281 .map_err(|e| AppError::BadRequest(format!("multipart error: {e}")))? 282 { 283 let name = field.name().unwrap_or_default().to_string(); 284 285 match name.as_str() { 286 "image" => { 287 // Get content type from field 288 if let Some(ct) = field.content_type() { 289 media_type = ct.to_string(); 290 } 291 image_bytes = Some( 292 field 293 .bytes() 294 .await 295 .map_err(|e| AppError::BadRequest(format!("failed to read image: {e}")))? 296 .to_vec(), 297 ); 298 } 299 "image_id" => { 300 image_id = Some( 301 field 302 .text() 303 .await 304 .map_err(|e| AppError::BadRequest(format!("failed to read image_id: {e}")))?, 305 ); 306 } 307 _ => {} 308 } 309 } 310 311 let image_bytes = 312 image_bytes.ok_or_else(|| AppError::BadRequest("missing 'image' field".to_string()))?; 313 let image_id = 314 image_id.ok_or_else(|| AppError::BadRequest("missing 'image_id' field".to_string()))?; 315 316 info!(image_id = %image_id, size = image_bytes.len(), "scanning image"); 317 318 // Call Claude for analysis 319 let result = claude 320 .analyze_image(&image_bytes, &media_type) 321 .await 322 .map_err(|e| AppError::Claude(e.to_string()))?; 323 324 // Store scan result for cost tracking 325 db.store_image_scan( 326 &image_id, 327 result.is_safe, 328 &result.violated_categories, 329 &result.severity, 330 &result.explanation, 331 "claude-sonnet-4-5-20250929", // TODO: get from client 332 ) 333 .await?; 334 335 // If not safe, add to sensitive images 336 if !result.is_safe { 337 info!(image_id = %image_id, severity = %result.severity, "flagging sensitive image"); 338 db.add_sensitive_image( 339 Some(&image_id), 340 None, 341 Some(&result.explanation), 342 Some("claude-auto"), 343 ) 344 .await?; 345 } 346 347 Ok(Json(ScanImageResponse { 348 is_safe: result.is_safe, 349 reason: if result.is_safe { 350 None 351 } else { 352 Some(result.explanation) 353 }, 354 severity: result.severity, 355 violated_categories: result.violated_categories, 356 })) 357} 358 359#[cfg(test)] 360mod tests { 361 use super::*; 362 363 #[test] 364 fn test_normalize_score() { 365 // Integer scores (0-100) should be converted to 0.0-1.0 366 assert!((normalize_score(85.0) - 0.85).abs() < 0.001); 367 assert!((normalize_score(100.0) - 1.0).abs() < 0.001); 368 assert!((normalize_score(50.0) - 0.5).abs() < 0.001); 369 370 // Scores already in 0.0-1.0 range should stay unchanged 371 assert!((normalize_score(0.85) - 0.85).abs() < 0.001); 372 assert!((normalize_score(1.0) - 1.0).abs() < 0.001); 373 assert!((normalize_score(0.5) - 0.5).abs() < 0.001); 374 assert!((normalize_score(0.0) - 0.0).abs() < 0.001); 375 } 376}