music on atproto
plyr.fm
1//! Database operations for the labeler.
2
3use chrono::{DateTime, Utc};
4use serde::{Deserialize, Serialize};
5use sqlx::{postgres::PgPoolOptions, FromRow, PgPool};
6
7use crate::admin::FlaggedTrack;
8use crate::labels::Label;
9
/// Sensitive image record from the database.
///
/// Maps one row of the `sensitive_images` table created in `LabelDb::migrate`.
/// An entry names the image either by R2 storage ID (`image_id`) or by full
/// URL (`url`); both columns are nullable in the schema, so neither is
/// guaranteed to be present.
#[derive(Debug, Clone, FromRow, Serialize, Deserialize)]
pub struct SensitiveImageRow {
    /// Primary key (`BIGSERIAL`).
    pub id: i64,
    /// R2 storage ID (for track/album artwork)
    pub image_id: Option<String>,
    /// Full URL (for external images like avatars)
    pub url: Option<String>,
    /// Why this image was flagged
    pub reason: Option<String>,
    /// When the image was flagged (schema default `NOW()` on insert).
    pub flagged_at: DateTime<Utc>,
    /// Admin who flagged it
    pub flagged_by: Option<String>,
}
25
/// Review batch for mobile-friendly flag review.
///
/// Maps one row of the `review_batches` table. `LabelDb::create_batch`
/// supplies only `id` and `created_by`; the other columns take their schema
/// defaults.
#[derive(Debug, Clone, FromRow, Serialize, Deserialize)]
pub struct ReviewBatch {
    /// Caller-supplied batch identifier (text primary key).
    pub id: String,
    /// Insertion time (schema default `NOW()`).
    pub created_at: DateTime<Utc>,
    /// Optional expiry; `create_batch` in this module never sets it.
    pub expires_at: Option<DateTime<Utc>>,
    /// Status: pending, completed.
    /// Stored as free text; the schema default is `'pending'`.
    pub status: String,
    /// Who created this batch.
    pub created_by: Option<String>,
}
37
/// A flag within a review batch.
///
/// Maps one row of the `batch_flags` table. `(batch_id, uri)` is unique, and
/// rows are removed together with their parent batch (`ON DELETE CASCADE`).
#[derive(Debug, Clone, FromRow, Serialize, Deserialize)]
pub struct BatchFlag {
    /// Primary key (`BIGSERIAL`).
    pub id: i64,
    /// Owning `review_batches.id`.
    pub batch_id: String,
    /// Subject URI of the flag; joined against `labels.uri` when listing a batch.
    pub uri: String,
    /// Set to `true` by `LabelDb::mark_flag_reviewed`.
    pub reviewed: bool,
    /// Time the flag was marked reviewed (`NOW()` at that moment).
    pub reviewed_at: Option<DateTime<Utc>>,
    /// Decision: approved, rejected, or null.
    pub decision: Option<String>,
}
49
/// Type alias for context row from database query.
///
/// Tuple positions correspond one-to-one to the column order of the `SELECT`
/// in `LabelDb::get_context`; reorder both together or not at all.
type ContextRow = (
    Option<i64>,               // track_id
    Option<String>,            // track_title
    Option<String>,            // artist_handle
    Option<String>,            // artist_did
    Option<f64>,               // highest_score
    Option<serde_json::Value>, // matches (JSONB array of CopyrightMatch)
    Option<String>,            // resolution_reason (text form of ResolutionReason)
    Option<String>,            // resolution_notes
);
61
/// Type alias for flagged track row from database query.
///
/// Tuple positions correspond to the column order of the `labels` LEFT JOIN
/// `label_context` `SELECT` used by `LabelDb::get_pending_flags` and
/// `LabelDb::get_batch_flags`. Everything from `track_id` on comes from the
/// LEFT JOIN and is therefore nullable.
type FlaggedRow = (
    i64,                       // seq
    String,                    // uri
    String,                    // val
    DateTime<Utc>,             // cts
    Option<i64>,               // track_id
    Option<String>,            // track_title
    Option<String>,            // artist_handle
    Option<String>,            // artist_did
    Option<f64>,               // highest_score
    Option<serde_json::Value>, // matches
    Option<String>,            // resolution_reason
    Option<String>,            // resolution_notes
);
77
/// Copyright match info stored alongside labels.
///
/// Serialized as an element of the `matches` JSONB array in `label_context`
/// (see `LabelContext::matches`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CopyrightMatch {
    /// Title reported for the match.
    pub title: String,
    /// Artist reported for the match.
    pub artist: String,
    /// Match score. NOTE(review): scale/direction not visible here —
    /// presumably higher means a stronger match; confirm with the matcher.
    pub score: f64,
}
85
/// Reason for resolving a false positive.
///
/// Serde (de)serializes variants in snake_case (e.g. `OriginalArtist` <->
/// `"original_artist"`), the same spellings `ResolutionReason::from_str`
/// accepts.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum ResolutionReason {
    /// Artist uploaded their own distributed music
    OriginalArtist,
    /// Artist has licensing/permission for the content
    Licensed,
    /// Fingerprint matcher produced a false match
    FingerprintNoise,
    /// Legal cover version or remix
    CoverVersion,
    /// Content was deleted from plyr.fm
    ContentDeleted,
    /// Other reason (see resolution_notes)
    Other,
}
103
104impl ResolutionReason {
105 /// Human-readable label for the reason.
106 pub fn label(&self) -> &'static str {
107 match self {
108 Self::OriginalArtist => "original artist",
109 Self::Licensed => "licensed",
110 Self::FingerprintNoise => "fingerprint noise",
111 Self::CoverVersion => "cover/remix",
112 Self::ContentDeleted => "content deleted",
113 Self::Other => "other",
114 }
115 }
116
117 /// Parse from string.
118 pub fn from_str(s: &str) -> Option<Self> {
119 match s {
120 "original_artist" => Some(Self::OriginalArtist),
121 "licensed" => Some(Self::Licensed),
122 "fingerprint_noise" => Some(Self::FingerprintNoise),
123 "cover_version" => Some(Self::CoverVersion),
124 "content_deleted" => Some(Self::ContentDeleted),
125 "other" => Some(Self::Other),
126 _ => None,
127 }
128 }
129}
130
/// Context stored alongside a label for display in admin UI.
///
/// Persisted in the `label_context` table, one row per URI. Every field is
/// optional: `LabelDb::store_context` overwrites a column only when the
/// corresponding field is `Some` (its upsert uses `COALESCE`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct LabelContext {
    /// Track identifier. NOTE(review): presumably plyr.fm's internal track id — confirm.
    pub track_id: Option<i64>,
    /// Title of the flagged track.
    pub track_title: Option<String>,
    /// Handle of the uploading artist.
    pub artist_handle: Option<String>,
    /// DID of the uploading artist.
    pub artist_did: Option<String>,
    /// Highest score among `matches`. NOTE(review): inferred from the name — confirm.
    pub highest_score: Option<f64>,
    /// Copyright matches, stored as a JSONB array. `LabelDb::get_context`
    /// reads back `None` if the stored JSON fails to deserialize.
    pub matches: Option<Vec<CopyrightMatch>>,
    /// Why the flag was resolved as false positive (set on resolution).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub resolution_reason: Option<ResolutionReason>,
    /// Additional notes about the resolution.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub resolution_notes: Option<String>,
}
147
/// Database connection pool and operations.
///
/// Wraps a sqlx `PgPool`. The derived `Clone` clones the pool handle; sqlx
/// documents `Pool` clones as reference-counted handles to the same pool.
#[derive(Clone)]
pub struct LabelDb {
    /// Shared Postgres connection pool used by every query in this module.
    pool: PgPool,
}
153
/// Stored label row from the database.
///
/// Mirrors the `labels` columns selected by the label queries (`id` and
/// `created_at` are never read). Convert to the wire type with `to_label`.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct LabelRow {
    /// Sequence number assigned by the database (`BIGSERIAL`); used as the
    /// pagination / subscription cursor.
    pub seq: i64,
    /// Label source, stored verbatim from `Label::src`.
    pub src: String,
    /// Subject URI, stored verbatim from `Label::uri`.
    pub uri: String,
    /// Optional subject CID.
    pub cid: Option<String>,
    /// Label value, e.g. `copyright-violation`.
    pub val: String,
    /// `true` for negation labels.
    pub neg: bool,
    /// Label creation timestamp.
    pub cts: DateTime<Utc>,
    /// Optional expiry timestamp.
    pub exp: Option<DateTime<Utc>>,
    /// Signature bytes; empty when the stored `Label` had no `sig`.
    pub sig: Vec<u8>,
}
167
168impl LabelDb {
169 /// Connect to the database.
170 pub async fn connect(database_url: &str) -> Result<Self, sqlx::Error> {
171 let pool = PgPoolOptions::new()
172 .max_connections(5)
173 .connect(database_url)
174 .await?;
175 Ok(Self { pool })
176 }
177
178 /// Run database migrations.
179 pub async fn migrate(&self) -> Result<(), sqlx::Error> {
180 sqlx::query(
181 r#"
182 CREATE TABLE IF NOT EXISTS labels (
183 id BIGSERIAL PRIMARY KEY,
184 seq BIGSERIAL UNIQUE NOT NULL,
185 src TEXT NOT NULL,
186 uri TEXT NOT NULL,
187 cid TEXT,
188 val TEXT NOT NULL,
189 neg BOOLEAN NOT NULL DEFAULT FALSE,
190 cts TIMESTAMPTZ NOT NULL,
191 exp TIMESTAMPTZ,
192 sig BYTEA NOT NULL,
193 created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
194 )
195 "#,
196 )
197 .execute(&self.pool)
198 .await?;
199
200 sqlx::query("CREATE INDEX IF NOT EXISTS idx_labels_uri ON labels(uri)")
201 .execute(&self.pool)
202 .await?;
203 sqlx::query("CREATE INDEX IF NOT EXISTS idx_labels_src ON labels(src)")
204 .execute(&self.pool)
205 .await?;
206 sqlx::query("CREATE INDEX IF NOT EXISTS idx_labels_seq ON labels(seq)")
207 .execute(&self.pool)
208 .await?;
209 sqlx::query("CREATE INDEX IF NOT EXISTS idx_labels_val ON labels(val)")
210 .execute(&self.pool)
211 .await?;
212
213 // Label context table for admin UI display
214 sqlx::query(
215 r#"
216 CREATE TABLE IF NOT EXISTS label_context (
217 id BIGSERIAL PRIMARY KEY,
218 uri TEXT NOT NULL UNIQUE,
219 track_title TEXT,
220 artist_handle TEXT,
221 artist_did TEXT,
222 highest_score DOUBLE PRECISION,
223 matches JSONB,
224 created_at TIMESTAMPTZ NOT NULL DEFAULT NOW()
225 )
226 "#,
227 )
228 .execute(&self.pool)
229 .await?;
230
231 sqlx::query("CREATE INDEX IF NOT EXISTS idx_label_context_uri ON label_context(uri)")
232 .execute(&self.pool)
233 .await?;
234
235 // Add resolution columns (migration-safe: only adds if missing)
236 sqlx::query("ALTER TABLE label_context ADD COLUMN IF NOT EXISTS resolution_reason TEXT")
237 .execute(&self.pool)
238 .await?;
239 sqlx::query("ALTER TABLE label_context ADD COLUMN IF NOT EXISTS resolution_notes TEXT")
240 .execute(&self.pool)
241 .await?;
242
243 // Sensitive images table for content moderation
244 sqlx::query(
245 r#"
246 CREATE TABLE IF NOT EXISTS sensitive_images (
247 id BIGSERIAL PRIMARY KEY,
248 image_id TEXT,
249 url TEXT,
250 reason TEXT,
251 flagged_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
252 flagged_by TEXT
253 )
254 "#,
255 )
256 .execute(&self.pool)
257 .await?;
258
259 sqlx::query("CREATE INDEX IF NOT EXISTS idx_sensitive_images_image_id ON sensitive_images(image_id)")
260 .execute(&self.pool)
261 .await?;
262 sqlx::query("CREATE INDEX IF NOT EXISTS idx_sensitive_images_url ON sensitive_images(url)")
263 .execute(&self.pool)
264 .await?;
265
266 // Image scans table for tracking automated moderation
267 sqlx::query(
268 r#"
269 CREATE TABLE IF NOT EXISTS image_scans (
270 id BIGSERIAL PRIMARY KEY,
271 image_id TEXT NOT NULL,
272 is_safe BOOLEAN NOT NULL,
273 violated_categories JSONB,
274 severity TEXT,
275 explanation TEXT,
276 scanned_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
277 model TEXT
278 )
279 "#,
280 )
281 .execute(&self.pool)
282 .await?;
283
284 sqlx::query("CREATE INDEX IF NOT EXISTS idx_image_scans_image_id ON image_scans(image_id)")
285 .execute(&self.pool)
286 .await?;
287 sqlx::query("CREATE INDEX IF NOT EXISTS idx_image_scans_is_safe ON image_scans(is_safe)")
288 .execute(&self.pool)
289 .await?;
290
291 // Review batches for mobile-friendly flag review
292 sqlx::query(
293 r#"
294 CREATE TABLE IF NOT EXISTS review_batches (
295 id TEXT PRIMARY KEY,
296 created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
297 expires_at TIMESTAMPTZ,
298 status TEXT NOT NULL DEFAULT 'pending',
299 created_by TEXT
300 )
301 "#,
302 )
303 .execute(&self.pool)
304 .await?;
305
306 // Flags within review batches
307 sqlx::query(
308 r#"
309 CREATE TABLE IF NOT EXISTS batch_flags (
310 id BIGSERIAL PRIMARY KEY,
311 batch_id TEXT NOT NULL REFERENCES review_batches(id) ON DELETE CASCADE,
312 uri TEXT NOT NULL,
313 reviewed BOOLEAN NOT NULL DEFAULT FALSE,
314 reviewed_at TIMESTAMPTZ,
315 decision TEXT,
316 UNIQUE(batch_id, uri)
317 )
318 "#,
319 )
320 .execute(&self.pool)
321 .await?;
322
323 sqlx::query("CREATE INDEX IF NOT EXISTS idx_batch_flags_batch_id ON batch_flags(batch_id)")
324 .execute(&self.pool)
325 .await?;
326
327 Ok(())
328 }
329
330 /// Store or update label context for a URI.
331 pub async fn store_context(
332 &self,
333 uri: &str,
334 context: &LabelContext,
335 ) -> Result<(), sqlx::Error> {
336 let matches_json = context
337 .matches
338 .as_ref()
339 .map(|m| serde_json::to_value(m).unwrap_or_default());
340 let reason_str = context
341 .resolution_reason
342 .map(|r| format!("{:?}", r).to_lowercase());
343
344 sqlx::query(
345 r#"
346 INSERT INTO label_context (uri, track_id, track_title, artist_handle, artist_did, highest_score, matches, resolution_reason, resolution_notes)
347 VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)
348 ON CONFLICT (uri) DO UPDATE SET
349 track_id = COALESCE(EXCLUDED.track_id, label_context.track_id),
350 track_title = COALESCE(EXCLUDED.track_title, label_context.track_title),
351 artist_handle = COALESCE(EXCLUDED.artist_handle, label_context.artist_handle),
352 artist_did = COALESCE(EXCLUDED.artist_did, label_context.artist_did),
353 highest_score = COALESCE(EXCLUDED.highest_score, label_context.highest_score),
354 matches = COALESCE(EXCLUDED.matches, label_context.matches),
355 resolution_reason = COALESCE(EXCLUDED.resolution_reason, label_context.resolution_reason),
356 resolution_notes = COALESCE(EXCLUDED.resolution_notes, label_context.resolution_notes)
357 "#,
358 )
359 .bind(uri)
360 .bind(context.track_id)
361 .bind(&context.track_title)
362 .bind(&context.artist_handle)
363 .bind(&context.artist_did)
364 .bind(context.highest_score)
365 .bind(matches_json)
366 .bind(reason_str)
367 .bind(&context.resolution_notes)
368 .execute(&self.pool)
369 .await?;
370
371 Ok(())
372 }
373
374 /// Store resolution reason for a URI (without overwriting other context).
375 pub async fn store_resolution(
376 &self,
377 uri: &str,
378 reason: ResolutionReason,
379 notes: Option<&str>,
380 ) -> Result<(), sqlx::Error> {
381 let reason_str = format!("{:?}", reason).to_lowercase();
382 sqlx::query(
383 r#"
384 INSERT INTO label_context (uri, resolution_reason, resolution_notes)
385 VALUES ($1, $2, $3)
386 ON CONFLICT (uri) DO UPDATE SET
387 resolution_reason = EXCLUDED.resolution_reason,
388 resolution_notes = EXCLUDED.resolution_notes
389 "#,
390 )
391 .bind(uri)
392 .bind(reason_str)
393 .bind(notes)
394 .execute(&self.pool)
395 .await?;
396
397 Ok(())
398 }
399
400 /// Get label context for a URI.
401 pub async fn get_context(&self, uri: &str) -> Result<Option<LabelContext>, sqlx::Error> {
402 let row: Option<ContextRow> = sqlx::query_as(
403 r#"
404 SELECT track_id, track_title, artist_handle, artist_did, highest_score, matches, resolution_reason, resolution_notes
405 FROM label_context
406 WHERE uri = $1
407 "#,
408 )
409 .bind(uri)
410 .fetch_optional(&self.pool)
411 .await?;
412
413 Ok(row.map(
414 |(
415 track_id,
416 track_title,
417 artist_handle,
418 artist_did,
419 highest_score,
420 matches,
421 resolution_reason,
422 resolution_notes,
423 )| {
424 LabelContext {
425 track_id,
426 track_title,
427 artist_handle,
428 artist_did,
429 highest_score,
430 matches: matches.and_then(|v| serde_json::from_value(v).ok()),
431 resolution_reason: resolution_reason
432 .and_then(|s| ResolutionReason::from_str(&s)),
433 resolution_notes,
434 }
435 },
436 ))
437 }
438
439 /// Store a signed label and return its sequence number.
440 pub async fn store_label(&self, label: &Label) -> Result<i64, sqlx::Error> {
441 let sig = label.sig.as_ref().map(|b| b.to_vec()).unwrap_or_default();
442 let cts: DateTime<Utc> = label.cts.parse().unwrap_or_else(|_| Utc::now());
443 let exp: Option<DateTime<Utc>> = label.exp.as_ref().and_then(|e| e.parse().ok());
444
445 let row = sqlx::query_scalar::<_, i64>(
446 r#"
447 INSERT INTO labels (src, uri, cid, val, neg, cts, exp, sig)
448 VALUES ($1, $2, $3, $4, $5, $6, $7, $8)
449 RETURNING seq
450 "#,
451 )
452 .bind(&label.src)
453 .bind(&label.uri)
454 .bind(&label.cid)
455 .bind(&label.val)
456 .bind(label.neg.unwrap_or(false))
457 .bind(cts)
458 .bind(exp)
459 .bind(sig)
460 .fetch_one(&self.pool)
461 .await?;
462
463 Ok(row)
464 }
465
466 /// Query labels matching URI patterns.
467 ///
468 /// Patterns can contain `*` as a wildcard (e.g., `at://did:plc:*`).
469 pub async fn query_labels(
470 &self,
471 uri_patterns: &[String],
472 sources: Option<&[String]>,
473 cursor: Option<&str>,
474 limit: i64,
475 ) -> Result<(Vec<LabelRow>, Option<String>), sqlx::Error> {
476 // Build dynamic query
477 let mut conditions = Vec::new();
478 let mut param_idx = 1;
479
480 // URI pattern matching
481 let uri_conditions: Vec<String> = uri_patterns
482 .iter()
483 .map(|p| {
484 let idx = param_idx;
485 param_idx += 1;
486 if p.contains('*') {
487 format!("uri LIKE ${}", idx)
488 } else {
489 format!("uri = ${}", idx)
490 }
491 })
492 .collect();
493
494 if !uri_conditions.is_empty() {
495 conditions.push(format!("({})", uri_conditions.join(" OR ")));
496 }
497
498 // Source filtering
499 if let Some(srcs) = sources {
500 if !srcs.is_empty() {
501 let placeholders: Vec<String> = srcs
502 .iter()
503 .map(|_| {
504 let idx = param_idx;
505 param_idx += 1;
506 format!("${}", idx)
507 })
508 .collect();
509 conditions.push(format!("src IN ({})", placeholders.join(", ")));
510 }
511 }
512
513 // Cursor for pagination
514 if cursor.is_some() {
515 conditions.push(format!("seq > ${}", param_idx));
516 }
517
518 let where_clause = if conditions.is_empty() {
519 String::new()
520 } else {
521 format!("WHERE {}", conditions.join(" AND "))
522 };
523
524 let query = format!(
525 r#"
526 SELECT seq, src, uri, cid, val, neg, cts, exp, sig
527 FROM labels
528 {}
529 ORDER BY seq ASC
530 LIMIT {}
531 "#,
532 where_clause,
533 limit + 1 // Fetch one extra to determine if there's more
534 );
535
536 // Build query with parameters
537 let mut q = sqlx::query_as::<_, LabelRow>(&query);
538
539 // Bind URI patterns (converting * to %)
540 for pattern in uri_patterns {
541 let sql_pattern = pattern.replace('*', "%");
542 q = q.bind(sql_pattern);
543 }
544
545 // Bind sources
546 if let Some(srcs) = sources {
547 for src in srcs {
548 q = q.bind(src);
549 }
550 }
551
552 // Bind cursor
553 if let Some(c) = cursor {
554 let cursor_seq: i64 = c.parse().unwrap_or(0);
555 q = q.bind(cursor_seq);
556 }
557
558 let mut rows: Vec<LabelRow> = q.fetch_all(&self.pool).await?;
559
560 // Determine next cursor
561 let next_cursor = if rows.len() > limit as usize {
562 rows.pop(); // Remove the extra row
563 rows.last().map(|r| r.seq.to_string())
564 } else {
565 None
566 };
567
568 Ok((rows, next_cursor))
569 }
570
571 /// Get labels since a sequence number (for subscribeLabels).
572 pub async fn get_labels_since(
573 &self,
574 cursor: i64,
575 limit: i64,
576 ) -> Result<Vec<LabelRow>, sqlx::Error> {
577 sqlx::query_as::<_, LabelRow>(
578 r#"
579 SELECT seq, src, uri, cid, val, neg, cts, exp, sig
580 FROM labels
581 WHERE seq > $1
582 ORDER BY seq ASC
583 LIMIT $2
584 "#,
585 )
586 .bind(cursor)
587 .bind(limit)
588 .fetch_all(&self.pool)
589 .await
590 }
591
592 /// Get the latest sequence number.
593 pub async fn get_latest_seq(&self) -> Result<i64, sqlx::Error> {
594 sqlx::query_scalar::<_, Option<i64>>("SELECT MAX(seq) FROM labels")
595 .fetch_one(&self.pool)
596 .await
597 .map(|s| s.unwrap_or(0))
598 }
599
600 /// Get URIs that have active (non-negated) copyright-violation labels.
601 ///
602 /// For each URI, checks if there's a negation label. Returns only those
603 /// that are still actively flagged.
604 pub async fn get_active_labels(&self, uris: &[String]) -> Result<Vec<String>, sqlx::Error> {
605 if uris.is_empty() {
606 return Ok(Vec::new());
607 }
608
609 // Get all negated URIs from our input set
610 let negated_uris: std::collections::HashSet<String> = sqlx::query_scalar::<_, String>(
611 r#"
612 SELECT DISTINCT uri
613 FROM labels
614 WHERE val = 'copyright-violation' AND neg = true AND uri = ANY($1)
615 "#,
616 )
617 .bind(uris)
618 .fetch_all(&self.pool)
619 .await?
620 .into_iter()
621 .collect();
622
623 // Get URIs that have a positive label and are not negated
624 let active_uris: Vec<String> = sqlx::query_scalar::<_, String>(
625 r#"
626 SELECT DISTINCT uri
627 FROM labels
628 WHERE val = 'copyright-violation' AND neg = false AND uri = ANY($1)
629 "#,
630 )
631 .bind(uris)
632 .fetch_all(&self.pool)
633 .await?
634 .into_iter()
635 .filter(|uri| !negated_uris.contains(uri))
636 .collect();
637
638 Ok(active_uris)
639 }
640
641 /// Get all copyright-violation labels with their resolution status and context.
642 ///
643 /// A label is resolved if there's a negation label for the same uri+val.
644 pub async fn get_pending_flags(&self) -> Result<Vec<FlaggedTrack>, sqlx::Error> {
645 // Get all copyright-violation labels with context via LEFT JOIN
646 let rows: Vec<FlaggedRow> = sqlx::query_as(
647 r#"
648 SELECT l.seq, l.uri, l.val, l.cts,
649 c.track_id, c.track_title, c.artist_handle, c.artist_did, c.highest_score, c.matches,
650 c.resolution_reason, c.resolution_notes
651 FROM labels l
652 LEFT JOIN label_context c ON l.uri = c.uri
653 WHERE l.val = 'copyright-violation' AND l.neg = false
654 ORDER BY l.seq DESC
655 "#,
656 )
657 .fetch_all(&self.pool)
658 .await?;
659
660 // Get all negation labels
661 let negated_uris: std::collections::HashSet<String> = sqlx::query_scalar::<_, String>(
662 r#"
663 SELECT DISTINCT uri
664 FROM labels
665 WHERE val = 'copyright-violation' AND neg = true
666 "#,
667 )
668 .fetch_all(&self.pool)
669 .await?
670 .into_iter()
671 .collect();
672
673 let tracks = rows
674 .into_iter()
675 .map(
676 |(
677 seq,
678 uri,
679 val,
680 cts,
681 track_id,
682 track_title,
683 artist_handle,
684 artist_did,
685 highest_score,
686 matches,
687 resolution_reason,
688 resolution_notes,
689 )| {
690 let context = if track_id.is_some()
691 || track_title.is_some()
692 || artist_handle.is_some()
693 || resolution_reason.is_some()
694 {
695 Some(LabelContext {
696 track_id,
697 track_title,
698 artist_handle,
699 artist_did,
700 highest_score,
701 matches: matches.and_then(|v| serde_json::from_value(v).ok()),
702 resolution_reason: resolution_reason
703 .and_then(|s| ResolutionReason::from_str(&s)),
704 resolution_notes,
705 })
706 } else {
707 None
708 };
709
710 FlaggedTrack {
711 seq,
712 uri: uri.clone(),
713 val,
714 created_at: cts.format("%Y-%m-%d %H:%M:%S").to_string(),
715 resolved: negated_uris.contains(&uri),
716 context,
717 }
718 },
719 )
720 .collect();
721
722 Ok(tracks)
723 }
724
725 // -------------------------------------------------------------------------
726 // Review batches
727 // -------------------------------------------------------------------------
728
729 /// Create a review batch with the given flags.
730 pub async fn create_batch(
731 &self,
732 id: &str,
733 uris: &[String],
734 created_by: Option<&str>,
735 ) -> Result<ReviewBatch, sqlx::Error> {
736 let batch = sqlx::query_as::<_, ReviewBatch>(
737 r#"
738 INSERT INTO review_batches (id, created_by)
739 VALUES ($1, $2)
740 RETURNING id, created_at, expires_at, status, created_by
741 "#,
742 )
743 .bind(id)
744 .bind(created_by)
745 .fetch_one(&self.pool)
746 .await?;
747
748 for uri in uris {
749 sqlx::query(
750 r#"
751 INSERT INTO batch_flags (batch_id, uri)
752 VALUES ($1, $2)
753 ON CONFLICT (batch_id, uri) DO NOTHING
754 "#,
755 )
756 .bind(id)
757 .bind(uri)
758 .execute(&self.pool)
759 .await?;
760 }
761
762 Ok(batch)
763 }
764
765 /// Get a batch by ID.
766 pub async fn get_batch(&self, id: &str) -> Result<Option<ReviewBatch>, sqlx::Error> {
767 sqlx::query_as::<_, ReviewBatch>(
768 r#"
769 SELECT id, created_at, expires_at, status, created_by
770 FROM review_batches
771 WHERE id = $1
772 "#,
773 )
774 .bind(id)
775 .fetch_optional(&self.pool)
776 .await
777 }
778
779 /// Get all flags in a batch with their context.
780 pub async fn get_batch_flags(&self, batch_id: &str) -> Result<Vec<FlaggedTrack>, sqlx::Error> {
781 let rows: Vec<FlaggedRow> = sqlx::query_as(
782 r#"
783 SELECT l.seq, l.uri, l.val, l.cts,
784 c.track_id, c.track_title, c.artist_handle, c.artist_did, c.highest_score, c.matches,
785 c.resolution_reason, c.resolution_notes
786 FROM batch_flags bf
787 JOIN labels l ON l.uri = bf.uri AND l.val = 'copyright-violation' AND l.neg = false
788 LEFT JOIN label_context c ON l.uri = c.uri
789 WHERE bf.batch_id = $1
790 ORDER BY l.seq DESC
791 "#,
792 )
793 .bind(batch_id)
794 .fetch_all(&self.pool)
795 .await?;
796
797 let batch_uris: Vec<String> = rows.iter().map(|r| r.1.clone()).collect();
798 let negated_uris: std::collections::HashSet<String> = if !batch_uris.is_empty() {
799 sqlx::query_scalar::<_, String>(
800 r#"
801 SELECT DISTINCT uri
802 FROM labels
803 WHERE val = 'copyright-violation' AND neg = true AND uri = ANY($1)
804 "#,
805 )
806 .bind(&batch_uris)
807 .fetch_all(&self.pool)
808 .await?
809 .into_iter()
810 .collect()
811 } else {
812 std::collections::HashSet::new()
813 };
814
815 let tracks = rows
816 .into_iter()
817 .map(
818 |(
819 seq,
820 uri,
821 val,
822 cts,
823 track_id,
824 track_title,
825 artist_handle,
826 artist_did,
827 highest_score,
828 matches,
829 resolution_reason,
830 resolution_notes,
831 )| {
832 let context = if track_id.is_some()
833 || track_title.is_some()
834 || artist_handle.is_some()
835 || resolution_reason.is_some()
836 {
837 Some(LabelContext {
838 track_id,
839 track_title,
840 artist_handle,
841 artist_did,
842 highest_score,
843 matches: matches.and_then(|v| serde_json::from_value(v).ok()),
844 resolution_reason: resolution_reason
845 .and_then(|s| ResolutionReason::from_str(&s)),
846 resolution_notes,
847 })
848 } else {
849 None
850 };
851
852 FlaggedTrack {
853 seq,
854 uri: uri.clone(),
855 val,
856 created_at: cts.format("%Y-%m-%d %H:%M:%S").to_string(),
857 resolved: negated_uris.contains(&uri),
858 context,
859 }
860 },
861 )
862 .collect();
863
864 Ok(tracks)
865 }
866
867 /// Update batch status.
868 pub async fn update_batch_status(&self, id: &str, status: &str) -> Result<bool, sqlx::Error> {
869 let result = sqlx::query("UPDATE review_batches SET status = $1 WHERE id = $2")
870 .bind(status)
871 .bind(id)
872 .execute(&self.pool)
873 .await?;
874 Ok(result.rows_affected() > 0)
875 }
876
877 /// Mark a flag in a batch as reviewed.
878 pub async fn mark_flag_reviewed(
879 &self,
880 batch_id: &str,
881 uri: &str,
882 decision: &str,
883 ) -> Result<bool, sqlx::Error> {
884 let result = sqlx::query(
885 r#"
886 UPDATE batch_flags
887 SET reviewed = true, reviewed_at = NOW(), decision = $1
888 WHERE batch_id = $2 AND uri = $3
889 "#,
890 )
891 .bind(decision)
892 .bind(batch_id)
893 .bind(uri)
894 .execute(&self.pool)
895 .await?;
896 Ok(result.rows_affected() > 0)
897 }
898
899 /// Get pending (non-reviewed) flags from a batch.
900 pub async fn get_batch_pending_uris(&self, batch_id: &str) -> Result<Vec<String>, sqlx::Error> {
901 sqlx::query_scalar::<_, String>(
902 r#"
903 SELECT uri FROM batch_flags
904 WHERE batch_id = $1 AND reviewed = false
905 "#,
906 )
907 .bind(batch_id)
908 .fetch_all(&self.pool)
909 .await
910 }
911
912 // -------------------------------------------------------------------------
913 // Sensitive images
914 // -------------------------------------------------------------------------
915
916 /// Get all sensitive images.
917 pub async fn get_sensitive_images(&self) -> Result<Vec<SensitiveImageRow>, sqlx::Error> {
918 sqlx::query_as::<_, SensitiveImageRow>(
919 "SELECT id, image_id, url, reason, flagged_at, flagged_by FROM sensitive_images ORDER BY flagged_at DESC",
920 )
921 .fetch_all(&self.pool)
922 .await
923 }
924
925 /// Add a sensitive image entry.
926 pub async fn add_sensitive_image(
927 &self,
928 image_id: Option<&str>,
929 url: Option<&str>,
930 reason: Option<&str>,
931 flagged_by: Option<&str>,
932 ) -> Result<i64, sqlx::Error> {
933 sqlx::query_scalar::<_, i64>(
934 r#"
935 INSERT INTO sensitive_images (image_id, url, reason, flagged_by)
936 VALUES ($1, $2, $3, $4)
937 RETURNING id
938 "#,
939 )
940 .bind(image_id)
941 .bind(url)
942 .bind(reason)
943 .bind(flagged_by)
944 .fetch_one(&self.pool)
945 .await
946 }
947
948 /// Remove a sensitive image entry by ID.
949 pub async fn remove_sensitive_image(&self, id: i64) -> Result<bool, sqlx::Error> {
950 let result = sqlx::query("DELETE FROM sensitive_images WHERE id = $1")
951 .bind(id)
952 .execute(&self.pool)
953 .await?;
954 Ok(result.rows_affected() > 0)
955 }
956
957 // -------------------------------------------------------------------------
958 // Image scans
959 // -------------------------------------------------------------------------
960
961 /// Store an image scan result.
962 pub async fn store_image_scan(
963 &self,
964 image_id: &str,
965 is_safe: bool,
966 violated_categories: &[String],
967 severity: &str,
968 explanation: &str,
969 model: &str,
970 ) -> Result<i64, sqlx::Error> {
971 let categories_json = serde_json::to_value(violated_categories).unwrap_or_default();
972 sqlx::query_scalar::<_, i64>(
973 r#"
974 INSERT INTO image_scans (image_id, is_safe, violated_categories, severity, explanation, model)
975 VALUES ($1, $2, $3, $4, $5, $6)
976 RETURNING id
977 "#,
978 )
979 .bind(image_id)
980 .bind(is_safe)
981 .bind(categories_json)
982 .bind(severity)
983 .bind(explanation)
984 .bind(model)
985 .fetch_one(&self.pool)
986 .await
987 }
988
989 /// Get image scan stats for cost tracking.
990 pub async fn get_image_scan_stats(&self) -> Result<ImageScanStats, sqlx::Error> {
991 let row: (i64, i64, i64) = sqlx::query_as(
992 r#"
993 SELECT
994 COUNT(*) as total,
995 COUNT(*) FILTER (WHERE is_safe = true) as safe,
996 COUNT(*) FILTER (WHERE is_safe = false) as flagged
997 FROM image_scans
998 "#,
999 )
1000 .fetch_one(&self.pool)
1001 .await?;
1002
1003 Ok(ImageScanStats {
1004 total: row.0,
1005 safe: row.1,
1006 flagged: row.2,
1007 })
1008 }
1009}
1010
/// Statistics for image scans.
///
/// Produced by `LabelDb::get_image_scan_stats` from the `image_scans` table.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ImageScanStats {
    /// Total number of scans recorded.
    pub total: i64,
    /// Scans with `is_safe = true`.
    pub safe: i64,
    /// Scans with `is_safe = false`.
    pub flagged: i64,
}
1018
1019impl LabelRow {
1020 /// Convert database row to Label struct.
1021 pub fn to_label(&self) -> Label {
1022 Label {
1023 ver: Some(1),
1024 src: self.src.clone(),
1025 uri: self.uri.clone(),
1026 cid: self.cid.clone(),
1027 val: self.val.clone(),
1028 neg: if self.neg { Some(true) } else { None },
1029 cts: self.cts.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string(),
1030 exp: self
1031 .exp
1032 .map(|e| e.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string()),
1033 sig: Some(bytes::Bytes::from(self.sig.clone())),
1034 }
1035 }
1036}
1037
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ResolutionReason` variant, for exhaustive round-trip checks.
    const ALL_REASONS: [ResolutionReason; 6] = [
        ResolutionReason::OriginalArtist,
        ResolutionReason::Licensed,
        ResolutionReason::FingerprintNoise,
        ResolutionReason::CoverVersion,
        ResolutionReason::ContentDeleted,
        ResolutionReason::Other,
    ];

    #[test]
    fn test_resolution_reason_from_str() {
        assert_eq!(
            ResolutionReason::from_str("original_artist"),
            Some(ResolutionReason::OriginalArtist)
        );
        assert_eq!(
            ResolutionReason::from_str("licensed"),
            Some(ResolutionReason::Licensed)
        );
        assert_eq!(
            ResolutionReason::from_str("fingerprint_noise"),
            Some(ResolutionReason::FingerprintNoise)
        );
        assert_eq!(
            ResolutionReason::from_str("cover_version"),
            Some(ResolutionReason::CoverVersion)
        );
        assert_eq!(
            ResolutionReason::from_str("content_deleted"),
            Some(ResolutionReason::ContentDeleted)
        );
        assert_eq!(
            ResolutionReason::from_str("other"),
            Some(ResolutionReason::Other)
        );
        assert_eq!(ResolutionReason::from_str("invalid"), None);
    }

    #[test]
    fn test_resolution_reason_labels() {
        assert_eq!(ResolutionReason::OriginalArtist.label(), "original artist");
        assert_eq!(ResolutionReason::Licensed.label(), "licensed");
        assert_eq!(
            ResolutionReason::FingerprintNoise.label(),
            "fingerprint noise"
        );
        assert_eq!(ResolutionReason::CoverVersion.label(), "cover/remix");
        assert_eq!(ResolutionReason::ContentDeleted.label(), "content deleted");
        assert_eq!(ResolutionReason::Other.label(), "other");
    }

    /// The serde representation (what the API emits) must parse back through
    /// `from_str` (what the DB layer uses), so both spell reasons identically.
    #[test]
    fn test_resolution_reason_serde_matches_from_str() {
        for reason in ALL_REASONS.iter() {
            let json = serde_json::to_value(reason).expect("unit variant serializes");
            let text = json.as_str().expect("unit variant serializes as a string");
            assert_eq!(ResolutionReason::from_str(text), Some(*reason));
        }
    }

    #[test]
    fn test_label_context_default() {
        let ctx = LabelContext::default();
        assert!(ctx.track_title.is_none());
        assert!(ctx.resolution_reason.is_none());
        assert!(ctx.resolution_notes.is_none());
    }
}