import logging
import random
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import click
import matplotlib.pyplot as plt
import numpy as np
import requests
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from tqdm import tqdm
from transformers import ViTForImageClassification, ViTImageProcessor

from config import CONFIG, Config

logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


class NSFWDataset(Dataset[torch.Tensor]):
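    """Binary image dataset that pairs ViT pixel values with SFW/NSFW labels."""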
    def __init__(
        self,
        image_paths: List[str],
        labels: List[int],
        processor: ViTImageProcessor,
        augment: bool = False,
    ):
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor

        if augment:
            self.transform = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomRotation(15),
                    transforms.ColorJitter(
                        brightness=0.2, contrast=0.2, saturation=0.2
                    ),
                    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                ]
            )
        else:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((224, 224)),
                ]
            )

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
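        """Load one image, apply transforms, and return (pixel_values, label)."""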
        img_path = self.image_paths[idx]

        try:
            image: Image.Image = Image.open(img_path).convert("RGB")
            image = self.transform(image)

            processed = self.processor(images=image, return_tensors="pt")

            pixel_values = processed["pixel_values"].squeeze(0)

            label: int = self.labels[idx]

            return pixel_values, label
        except Exception as e:
            logger.error(f"Error loading image {img_path}: {e}")
            # Fall back to the next sample so the default collate function
            # never receives None for an unreadable image.
            return self[(idx + 1) % len(self)]


def load_image_paths_and_labels(
    nsfw_dir: str,
    sfw_human_dir: str,
    sfw_anime_dir: str,
) -> Tuple[List[str], List[int]]:
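    """Collect image paths and labels, then roughly balance the two classes."""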
    nsfw_paths: List[str] = []
    sfw_paths: List[str] = []

    # false positives for nsfw
    override_paths: List[str] = []

    valid_extensions: Set[str] = {".jpg", ".jpeg", ".png", ".webp"}

    logger.info(f"Loading NSFW images from {nsfw_dir}")
    nsfw_path = Path(nsfw_dir)
    if nsfw_path.exists():
        for f in nsfw_path.iterdir():
            if f.suffix.lower() in valid_extensions:
                nsfw_paths.append(str(f))
    else:
        logger.warning(f"NSFW directory not found: {nsfw_dir}")

    logger.info(f"Loading SFW human images from {sfw_human_dir}")
    sfw_humans_path = Path(sfw_human_dir)
    if sfw_humans_path.exists():
        for f in sfw_humans_path.iterdir():
            if f.suffix.lower() in valid_extensions:
                sfw_paths.append(str(f))
    else:
        logger.warning(f"SFW humans directory not found: {sfw_human_dir}")

    logger.info(f"Loading SFW anime images from {sfw_anime_dir}")
    sfw_anime_path = Path(sfw_anime_dir)
    if sfw_anime_path.exists():
        for f in sfw_anime_path.iterdir():
            if f.suffix.lower() in valid_extensions:
                sfw_paths.append(str(f))
    else:
        logger.warning(f"SFW anime directory not found: {sfw_anime_dir}")

    nsfw_count = len(nsfw_paths)
    sfw_count = len(sfw_paths)

    logger.info("Loading overrides (NSFW false positives)")
    override_path = Path("./dataset_clean/overrides")
    if override_path.exists():
        for f in override_path.iterdir():
            if f.suffix.lower() in valid_extensions:
                override_paths.append(str(f))
    else:
        logger.warning("No override path found")

    logger.info(f"Dataset loaded before balancing: {nsfw_count} NSFW, {sfw_count} SFW")

    if nsfw_count > 0 and sfw_count > 0:
        current_ratio = nsfw_count / sfw_count
        logger.info(f"Current NSFW/SFW ratio: {current_ratio:.2f}")

        needs_balancing = current_ratio < 0.5 or current_ratio > 2.0

        target_ratio = 1.2

        if needs_balancing:
            if current_ratio > target_ratio:
                target_nsfw = int(sfw_count * target_ratio)
                logger.info(f"Downsampling NSFW: {nsfw_count} → {target_nsfw}")

                random.seed(42)
                nsfw_paths = random.sample(nsfw_paths, target_nsfw)
            else:
                target_sfw = int(nsfw_count / target_ratio)
                logger.info(f"Downsampling SFW: {sfw_count} → {target_sfw}")

                random.seed(42)
                sfw_paths = random.sample(sfw_paths, target_sfw)

        nsfw_count = len(nsfw_paths)
        sfw_count = len(sfw_paths)
        new_ratio = nsfw_count / sfw_count

        logger.info(f"Dataset after balancing: {nsfw_count} NSFW, {sfw_count} SFW")
        logger.info(f"New NSFW/SFW ratio: {new_ratio:.2f}")

    image_paths: List[str] = []
    labels: List[int] = []

    for path in nsfw_paths:
        image_paths.append(path)
        labels.append(1)

    for path in sfw_paths:
        image_paths.append(path)
        labels.append(0)

    # Add in any overrides so they will show up even after balancing
    for path in override_paths:
        image_paths.append(path)
        labels.append(0)

    combined = list(zip(image_paths, labels))
    random.seed(42)  # keep the final shuffle reproducible, matching the sampling above
    random.shuffle(combined)
    image_paths, labels = zip(*combined)
    image_paths = list(image_paths)
    labels = list(labels)

    logger.info(f"Final dataset: {len(image_paths)} total images")

    return image_paths, labels


def create_dataloaders(
    image_paths: List[str],
    labels: List[int],
    processor: ViTImageProcessor,
    config: Config,
    augment: bool,
) -> Tuple[
    DataLoader[torch.Tensor], DataLoader[torch.Tensor], DataLoader[torch.Tensor]
]:
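    """Split the dataset into train/val/test and wrap each split in a DataLoader."""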
    total_size = len(image_paths)
    train_size = int(config.train_ratio * total_size)
    val_size = int(config.val_ratio * total_size)
    test_size = total_size - train_size - val_size

    # create our splits without actually augmenting anything
    full_dataset = NSFWDataset(image_paths, labels, processor, augment=False)

    train_dataset_temp, val_dataset_temp, test_dataset_temp = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )

    train_indices = train_dataset_temp.indices
    val_indices = val_dataset_temp.indices
    test_indices = test_dataset_temp.indices

    train_paths = [image_paths[i] for i in train_indices]
    train_labels = [labels[i] for i in train_indices]
    train_dataset = NSFWDataset(train_paths, train_labels, processor, augment)

    val_paths = [image_paths[i] for i in val_indices]
    val_labels = [labels[i] for i in val_indices]
    val_dataset = NSFWDataset(val_paths, val_labels, processor, augment=False)

    test_paths = [image_paths[i] for i in test_indices]
    test_labels = [labels[i] for i in test_indices]
    test_dataset = NSFWDataset(test_paths, test_labels, processor, augment=False)

    pin_memory = config.device.type == "cuda"

    train_loader = DataLoader[torch.Tensor](
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=pin_memory,
    )

    val_loader = DataLoader[torch.Tensor](
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=pin_memory,
    )

    test_loader = DataLoader[torch.Tensor](
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=pin_memory,
    )

    return train_loader, val_loader, test_loader


def create_model(config: Config) -> ViTForImageClassification:
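    """Load the pretrained ViT backbone with a fresh classification head."""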
    logger.info(f"Loading pretrained ViT model: {config.model_name}")

    model = ViTForImageClassification.from_pretrained(
        config.model_name,
        num_labels=config.num_classes,
        ignore_mismatched_sizes=True,
    )

    model = model.to(config.device)

    return model


def train_epoch(
    model: ViTForImageClassification,
    train_loader: DataLoader[torch.Tensor],
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float]:
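    """Run one training epoch and return (mean loss, accuracy %)."""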
    model.train()
    running_loss: float = 0.0
    correct: int = 0
    total: int = 0

    progress_bar = tqdm(train_loader, desc="Training")

    for batch_idx, (images, labels) in enumerate(progress_bar):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()

        outputs = model(pixel_values=images)
        logits: torch.Tensor = outputs.logits

        loss: torch.Tensor = criterion(logits, labels)

        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(logits, 1)
        total += labels.size(0)
        correct += int((predicted == labels).sum().item())

        progress_bar.set_postfix(
            {"loss": running_loss / (batch_idx + 1), "acc": 100.0 * correct / total}
        )

    epoch_loss: float = running_loss / len(train_loader)
    epoch_acc: float = 100.0 * correct / total

    return epoch_loss, epoch_acc


def validate(
    model: ViTForImageClassification,
    val_loader: DataLoader[torch.Tensor],
    criterion: nn.Module,
    device: torch.device,
) -> Tuple[float, float, List[int], List[int]]:
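    """Evaluate on the validation set; return loss, accuracy %, preds, and labels."""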
    model.eval()
    running_loss: float = 0.0
    correct: int = 0
    total: int = 0

    all_preds: List[int] = []
    all_labels: List[int] = []

    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Validation")

        for batch_idx, (images, labels) in enumerate(progress_bar):
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(pixel_values=images)
            logits: torch.Tensor = outputs.logits

            loss: torch.Tensor = criterion(logits, labels)

            running_loss += loss.item()
            _, predicted = torch.max(logits, 1)
            total += labels.size(0)
            correct += int((predicted == labels).sum().item())

            all_preds.extend(predicted.cpu().numpy().tolist())
            all_labels.extend(labels.cpu().numpy().tolist())

            # Average over batches seen so far, matching train_epoch
            progress_bar.set_postfix(
                {
                    "loss": running_loss / (batch_idx + 1),
                    "acc": 100.0 * correct / total,
                }
            )

    epoch_loss: float = running_loss / len(val_loader)
    epoch_acc: float = 100.0 * correct / total

    return epoch_loss, epoch_acc, all_preds, all_labels


def train_model(
    model: ViTForImageClassification,
    train_loader: DataLoader[torch.Tensor],
    val_loader: DataLoader[torch.Tensor],
    config: Config,
) -> Tuple[ViTForImageClassification, Dict[str, List[float]]]:
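    """Full training loop with LR scheduling; keeps the best checkpoint by val accuracy."""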
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=2,
    )

    best_val_acc = 0.0
    best_model_state: Optional[Dict[str, Any]] = None

    history: Dict[str, List[float]] = {
        "train_loss": [],
        "train_acc": [],
        "val_loss": [],
        "val_acc": [],
    }

    for epoch in range(config.num_epochs):
        logger.info(f"Epoch {epoch + 1}/{config.num_epochs}")
        print("-" * 60)

        train_loss, train_acc = train_epoch(
            model,
            train_loader,
            criterion,
            optimizer,
            config.device,
        )

        val_loss, val_acc, _, _ = validate(model, val_loader, criterion, config.device)

        scheduler.step(val_loss)

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}%")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Clone the tensors: a plain dict.copy() would still reference the
            # live weights, which keep changing as training continues.
            best_model_state = {
                k: v.detach().clone() for k, v in model.state_dict().items()
            }
            print(f"New best model! (Val Acc: {val_acc:.2f}%)")

    if best_model_state is not None:
        model.load_state_dict(best_model_state)

    logger.info(f"Training complete. Best validation accuracy: {best_val_acc:.2f}%")

    return model, history


def save_model(
    model: ViTForImageClassification,
    processor: ViTImageProcessor,
    config: Config,
) -> None:
    Path(config.model_save_path).mkdir(parents=True, exist_ok=True)

    model.save_pretrained(config.model_save_path)
    processor.save_pretrained(config.model_save_path)

    print(f"Model saved to {config.model_save_path}")


def load_trained_model(
    model_path: str, device: torch.device
) -> Tuple[ViTForImageClassification, ViTImageProcessor]:
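    """Load a saved checkpoint and its processor, ready for inference."""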
    model = ViTForImageClassification.from_pretrained(model_path)
    processor = ViTImageProcessor.from_pretrained(model_path)
    model = model.to(device)
    model.eval()
    return model, processor


def predict_image(
    image_path: str,
    model: ViTForImageClassification,
    processor: ViTImageProcessor,
    device: torch.device,
    config: Config,
) -> Tuple[str, float]:
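    """Classify a single image file; return (class name, confidence)."""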
    model.eval()

    image: Image.Image = Image.open(image_path).convert("RGB")
    inputs: Dict[str, torch.Tensor] = processor(images=image, return_tensors="pt")
    pixel_values: torch.Tensor = inputs["pixel_values"].to(device)

    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
        logits: torch.Tensor = outputs.logits
        probs: torch.Tensor = torch.softmax(logits, dim=1)
        predicted_class: int = torch.argmax(probs, dim=1).item()
        confidence: float = probs[0, predicted_class].item()

    prediction: str = config.class_names[predicted_class]

    return prediction, confidence


def evaluate_model(
    model: ViTForImageClassification,
    test_loader: DataLoader[torch.Tensor],
    config: Config,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
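    """Run the test set, print a classification report, and save a confusion matrix."""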
    model.eval()
    all_preds: List[int] = []
    all_labels: List[int] = []
    all_probs: List[List[float]] = []

    print("\n" + "=" * 60)
    print("Evaluating on Test Set")
    print("=" * 60)

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images = images.to(config.device)

            outputs = model(pixel_values=images)
            logits: torch.Tensor = outputs.logits

            probs: torch.Tensor = torch.softmax(logits, dim=1)
            _, predicted = torch.max(logits, 1)

            all_preds.extend(predicted.cpu().numpy().tolist())
            all_labels.extend(labels.numpy().tolist())
            all_probs.extend(probs.cpu().numpy().tolist())

    all_preds_array: np.ndarray = np.array(all_preds)
    all_labels_array: np.ndarray = np.array(all_labels)
    all_probs_array: np.ndarray = np.array(all_probs)

    print("\nClassification Report:")
    print(
        classification_report(
            all_labels_array, all_preds_array, target_names=config.class_names, digits=4
        )
    )

    cm: np.ndarray = confusion_matrix(all_labels_array, all_preds_array)

    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=config.class_names,
        yticklabels=config.class_names,
    )
    plt.title("Confusion Matrix")
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    plt.tight_layout()
    plt.savefig("confusion_matrix.png", dpi=300, bbox_inches="tight")
    print("\n✓ Confusion matrix saved as 'confusion_matrix.png'")

    return all_preds_array, all_labels_array, all_probs_array


def plot_training_history(history: Dict[str, List[float]]) -> None:
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    ax1.plot(history["train_loss"], label="Train Loss", marker="o")
    ax1.plot(history["val_loss"], label="Val Loss", marker="s")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.set_title("Training and Validation Loss")
    ax1.legend()
    ax1.grid(True)

    ax2.plot(history["train_acc"], label="Train Acc", marker="o")
    ax2.plot(history["val_acc"], label="Val Acc", marker="s")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Accuracy (%)")
    ax2.set_title("Training and Validation Accuracy")
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.savefig("training_history.png", dpi=300, bbox_inches="tight")
    print("✓ Training history saved as 'training_history.png'")


class NSFWClassifier:
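    """Loads a finetuned checkpoint and exposes simple prediction helpers."""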
    def __init__(self, model_path: str, device: Optional[torch.device] = None) -> None:
        if device is None:
            self.device: torch.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
        else:
            self.device = device

        print(f"Loading model from {model_path}...")
        self.model: ViTForImageClassification = (
            ViTForImageClassification.from_pretrained(model_path)
        )
        self.processor: ViTImageProcessor = ViTImageProcessor.from_pretrained(
            model_path
        )
        self.model = self.model.to(self.device)
        self.model.eval()

        self.class_names: List[str] = ["SFW", "NSFW"]
        print(f"✓ Model loaded on {self.device}")

    def _predict(self, image: Image.Image) -> Dict[str, Union[str, float]]:
        # Shared inference path for both the file and bytes entry points.
        inputs: Dict[str, torch.Tensor] = self.processor(
            images=image, return_tensors="pt"
        )
        pixel_values: torch.Tensor = inputs["pixel_values"].to(self.device)

        with torch.no_grad():
            outputs = self.model(pixel_values=pixel_values)
            logits: torch.Tensor = outputs.logits
            probs: torch.Tensor = torch.softmax(logits, dim=1)[0]
            predicted_class: int = torch.argmax(probs).item()

        return {
            "prediction": self.class_names[predicted_class],
            "confidence": probs[predicted_class].item(),
            "sfw_prob": probs[0].item(),
            "nsfw_prob": probs[1].item(),
        }

    def predict_single(self, image_path: str) -> Dict[str, Union[str, float]]:
        return self._predict(Image.open(image_path).convert("RGB"))

    def predict_single_bytes(self, image_bytes: bytes) -> Dict[str, Union[str, float]]:
        return self._predict(Image.open(BytesIO(image_bytes)).convert("RGB"))


@click.group()
def cli():
    pass


@cli.command()
@click.option(
    "--augment",
    is_flag=True,
    default=False,
    help="Apply data augmentation to the training split.",
)
def train(
    augment: bool,
):
    config = CONFIG

    print("=" * 60)
    print("NSFW Image Classifier - ViT Finetuner")
    print("=" * 60)
    print(f"Device: {config.device}")
    print(f"Model: {config.model_name}")
    print(f"Batch size: {config.batch_size}")
    print(f"Learning rate: {config.learning_rate}")
    print(f"Epochs: {config.num_epochs}")

    logger.info("Loading image processor")
    processor = ViTImageProcessor.from_pretrained(config.model_name)

    logger.info("Creating dataset")
    image_paths, labels = load_image_paths_and_labels(
        config.nsfw_dir,
        config.sfw_humans_dir,
        config.sfw_anime_dir,
    )

    logger.info("Creating dataloaders")
    train_loader, val_loader, test_loader = create_dataloaders(
        image_paths, labels, processor, config, augment=augment
    )

    model = create_model(config)

    model, history = train_model(model, train_loader, val_loader, config)

    logger.info("Evaluating model")
    evaluate_model(model, test_loader, config)

    logger.info("Generating plots")
    plot_training_history(history)

    logger.info("Saving model")
    save_model(model, processor, config)


@cli.command()
@click.argument("image", type=str, required=True)
def predict(
    image: str,
):
    classifier = NSFWClassifier(
        CONFIG.model_save_path,
        CONFIG.device,
    )

    print(classifier.predict_single(image))


@cli.command()
@click.argument("url", type=str, required=True)
def predict_url(
    url: str,
):
    classifier = NSFWClassifier(
        CONFIG.model_save_path,
        CONFIG.device,
    )

    resp = requests.get(url, timeout=30)
    resp.raise_for_status()

    # We could use predict_single_bytes here, but saving the file first lets us
    # do something with the image later, like moving it to a directory for retraining.
    with open("./testimg.jpeg", "wb") as f:
        f.write(resp.content)

    print(classifier.predict_single("./testimg.jpeg"))


if __name__ == "__main__":
    cli()