"""Fine-tune a ViT image classifier to distinguish NSFW from SFW images.

Provides dataset loading/balancing, training, validation, evaluation,
plotting, a reusable ``NSFWClassifier`` inference wrapper, and a small
Click CLI with ``train``, ``predict``, and ``predict-url`` commands.
"""

import logging
import random
from io import BytesIO
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import click
import matplotlib.pyplot as plt
import numpy as np
import requests
import seaborn as sns
import torch
import torch.nn as nn
import torch.optim as optim
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
from tqdm import tqdm
from transformers import ViTForImageClassification, ViTImageProcessor

from config import CONFIG, Config

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Extensions accepted when scanning dataset directories.
_VALID_EXTENSIONS: Set[str] = {".jpg", ".jpeg", ".png", ".webp"}


class NSFWDataset(Dataset[Tuple[torch.Tensor, int]]):
    """Dataset yielding ``(pixel_values, label)`` pairs for ViT fine-tuning.

    Labels: 1 = NSFW, 0 = SFW. When ``augment`` is True, light geometric and
    color augmentation is applied before the ViT processor; otherwise images
    are just resized. Unreadable images are replaced with a black placeholder
    so the default collate function never receives ``None``.
    """

    def __init__(
        self,
        image_paths: List[str],
        labels: List[int],
        processor: ViTImageProcessor,
        augment: bool = False,
    ):
        self.image_paths = image_paths
        self.labels = labels
        self.processor = processor
        if augment:
            self.transform = transforms.Compose(
                [
                    transforms.RandomHorizontalFlip(p=0.5),
                    transforms.RandomRotation(15),
                    transforms.ColorJitter(
                        brightness=0.2, contrast=0.2, saturation=0.2
                    ),
                    transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                ]
            )
        else:
            self.transform = transforms.Compose(
                [
                    transforms.Resize((224, 224)),
                ]
            )

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, int]:
        img_path = self.image_paths[idx]
        label: int = self.labels[idx]
        try:
            image: Image.Image = Image.open(img_path).convert("RGB")
        except Exception as e:
            # Returning None here would crash the default collate_fn, so
            # substitute a black image and keep the batch intact.
            logger.error(f"Error loading image {img_path}: {e}")
            image = Image.new("RGB", (224, 224))
        image = self.transform(image)
        processed = self.processor(images=image, return_tensors="pt")
        pixel_values = processed["pixel_values"].squeeze(0)
        return pixel_values, label


def _collect_image_paths(directory: str, description: str) -> List[str]:
    """Return paths of files in *directory* with a supported image extension.

    Logs a warning and returns an empty list when the directory is missing.
    """
    dir_path = Path(directory)
    if not dir_path.exists():
        logger.warning(f"{description} directory not found: {directory}")
        return []
    return [
        str(f)
        for f in dir_path.iterdir()
        if f.suffix.lower() in _VALID_EXTENSIONS
    ]


def load_image_paths_and_labels(
    nsfw_dir: str,
    sfw_human_dir: str,
    sfw_anime_dir: str,
) -> Tuple[List[str], List[int]]:
    """Scan the dataset directories and return shuffled (paths, labels).

    NSFW images are labeled 1, SFW (human + anime) 0. When the NSFW/SFW
    ratio falls outside [0.5, 2.0], the larger side is downsampled toward a
    1.2 ratio (seeded for reproducibility). Override images (known NSFW
    false positives) are appended as SFW after balancing so they always
    remain in the dataset.
    """
    logger.info(f"Loading NSFW images from {nsfw_dir}")
    nsfw_paths = _collect_image_paths(nsfw_dir, "NSFW")

    logger.info(f"Loading SFW human images from {sfw_human_dir}")
    sfw_paths = _collect_image_paths(sfw_human_dir, "SFW humans")

    logger.info(f"Loading SFW anime images from {sfw_anime_dir}")
    sfw_paths += _collect_image_paths(sfw_anime_dir, "SFW anime")

    nsfw_count = len(nsfw_paths)
    sfw_count = len(sfw_paths)

    logger.info("Loading overrides (NSFW false positives)")
    override_paths = _collect_image_paths("./dataset_clean/overrides", "Override")

    logger.info(f"Dataset loaded before balancing: {nsfw_count} NSFW, {sfw_count} SFW")

    if nsfw_count > 0 and sfw_count > 0:
        current_ratio = nsfw_count / sfw_count
        logger.info(f"Current NSFW/SFW ratio: {current_ratio:.2f}")
        needs_balancing = current_ratio < 0.5 or current_ratio > 2.0
        target_ratio = 1.2
        if needs_balancing:
            if current_ratio > target_ratio:
                # Too much NSFW: downsample it toward the target ratio.
                target_nsfw = int(sfw_count * target_ratio)
                logger.info(f"Downsampling NSFW: {nsfw_count} → {target_nsfw}")
                random.seed(42)
                nsfw_paths = random.sample(nsfw_paths, target_nsfw)
            else:
                # Too much SFW: downsample it instead.
                target_sfw = int(nsfw_count / target_ratio)
                logger.info(f"Downsampling SFW: {sfw_count} → {target_sfw}")
                random.seed(42)
                sfw_paths = random.sample(sfw_paths, target_sfw)
            nsfw_count = len(nsfw_paths)
            sfw_count = len(sfw_paths)
            new_ratio = nsfw_count / sfw_count
            logger.info(f"Dataset after balancing: {nsfw_count} NSFW, {sfw_count} SFW")
            logger.info(f"New NSFW/SFW ratio: {new_ratio:.2f}")

    # Overrides are appended last so they survive balancing.
    image_paths: List[str] = nsfw_paths + sfw_paths + override_paths
    labels: List[int] = (
        [1] * len(nsfw_paths) + [0] * len(sfw_paths) + [0] * len(override_paths)
    )

    combined = list(zip(image_paths, labels))
    random.shuffle(combined)
    if combined:
        unzipped_paths, unzipped_labels = zip(*combined)
        image_paths = list(unzipped_paths)
        labels = list(unzipped_labels)
    else:
        # zip(*[]) would raise; keep empty lists for an empty dataset.
        image_paths, labels = [], []

    logger.info(f"Final dataset: {len(image_paths)} total images")
    return image_paths, labels


def create_dataloaders(
    image_paths: List[str],
    labels: List[int],
    processor: ViTImageProcessor,
    config: Config,
    augment: bool,
) -> Tuple[
    DataLoader[Tuple[torch.Tensor, int]],
    DataLoader[Tuple[torch.Tensor, int]],
    DataLoader[Tuple[torch.Tensor, int]],
]:
    """Split the dataset into train/val/test loaders.

    The split is seeded for reproducibility. Augmentation (when requested)
    is applied only to the training split; val/test always use the plain
    resize transform.
    """
    total_size = len(image_paths)
    train_size = int(config.train_ratio * total_size)
    val_size = int(config.val_ratio * total_size)
    test_size = total_size - train_size - val_size

    # Split indices on an un-augmented dataset, then rebuild per-split
    # datasets so only the training split is augmented.
    full_dataset = NSFWDataset(image_paths, labels, processor, augment=False)
    train_split, val_split, test_split = random_split(
        full_dataset,
        [train_size, val_size, test_size],
        generator=torch.Generator().manual_seed(42),
    )
    train_indices = train_split.indices
    val_indices = val_split.indices
    test_indices = test_split.indices

    train_paths = [image_paths[i] for i in train_indices]
    train_labels = [labels[i] for i in train_indices]
    train_dataset = NSFWDataset(train_paths, train_labels, processor, augment)

    val_paths = [image_paths[i] for i in val_indices]
    val_labels = [labels[i] for i in val_indices]
    val_dataset = NSFWDataset(val_paths, val_labels, processor, augment=False)

    test_paths = [image_paths[i] for i in test_indices]
    test_labels = [labels[i] for i in test_indices]
    test_dataset = NSFWDataset(test_paths, test_labels, processor, augment=False)

    # pin_memory only helps (and is only valid) for CUDA transfers.
    pin_memory = config.device.type == "cuda"
    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=config.num_workers,
        pin_memory=pin_memory,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=pin_memory,
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=config.batch_size,
        shuffle=False,
        num_workers=config.num_workers,
        pin_memory=pin_memory,
    )
    return train_loader, val_loader, test_loader


def create_model(config: Config) -> ViTForImageClassification:
    """Load the pretrained ViT with a fresh classification head on device."""
    logger.info(f"Loading pretrained ViT model: {config.model_name}")
    model = ViTForImageClassification.from_pretrained(
        config.model_name,
        num_labels=config.num_classes,
        # The pretrained head has a different number of classes.
        ignore_mismatched_sizes=True,
    )
    model = model.to(config.device)
    return model


def train_epoch(
    model: ViTForImageClassification,
    train_loader: DataLoader[Tuple[torch.Tensor, int]],
    criterion: nn.Module,
    optimizer: optim.Optimizer,
    device: torch.device,
) -> Tuple[float, float]:
    """Run one training epoch; return (mean loss, accuracy %)."""
    model.train()
    running_loss: float = 0.0
    correct: int = 0
    total: int = 0
    progress_bar = tqdm(train_loader, desc="Training")
    for batch_idx, (images, labels) in enumerate(progress_bar):
        images = images.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        outputs = model(pixel_values=images)
        logits: torch.Tensor = outputs.logits
        loss: torch.Tensor = criterion(logits, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = torch.max(logits, 1)
        total += labels.size(0)
        correct += int((predicted == labels).sum().item())
        progress_bar.set_postfix(
            {"loss": running_loss / (batch_idx + 1), "acc": 100.0 * correct / total}
        )
    epoch_loss: float = running_loss / len(train_loader)
    epoch_acc: float = 100.0 * correct / total
    return epoch_loss, epoch_acc


def validate(
    model: ViTForImageClassification,
    val_loader: DataLoader[Tuple[torch.Tensor, int]],
    criterion: nn.Module,
    device: torch.device,
) -> Tuple[float, float, List[int], List[int]]:
    """Evaluate on the validation set.

    Returns (mean loss, accuracy %, predictions, true labels).
    """
    model.eval()
    running_loss: float = 0.0
    correct: int = 0
    total: int = 0
    all_preds: List[int] = []
    all_labels: List[int] = []
    with torch.no_grad():
        progress_bar = tqdm(val_loader, desc="Validation")
        for images, labels in progress_bar:
            images = images.to(device)
            labels = labels.to(device)

            outputs = model(pixel_values=images)
            logits: torch.Tensor = outputs.logits
            loss: torch.Tensor = criterion(logits, labels)

            running_loss += loss.item()
            _, predicted = torch.max(logits, 1)
            total += labels.size(0)
            correct += int((predicted == labels).sum().item())
            all_preds.extend(predicted.cpu().numpy().tolist())
            all_labels.extend(labels.cpu().numpy().tolist())
            progress_bar.set_postfix(
                {"loss": running_loss / len(val_loader), "acc": 100.0 * correct / total}
            )
    epoch_loss: float = running_loss / len(val_loader)
    epoch_acc: float = 100.0 * correct / total
    return epoch_loss, epoch_acc, all_preds, all_labels


def train_model(
    model: ViTForImageClassification,
    train_loader: DataLoader[Tuple[torch.Tensor, int]],
    val_loader: DataLoader[Tuple[torch.Tensor, int]],
    config: Config,
) -> Tuple[ViTForImageClassification, Dict[str, List[float]]]:
    """Train for ``config.num_epochs`` epochs, keeping the best-val-acc weights.

    Uses AdamW with ReduceLROnPlateau on validation loss. Returns the model
    (restored to its best checkpoint) and a per-epoch metric history.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
    )
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="min",
        factor=0.5,
        patience=2,
    )
    best_val_acc = 0.0
    best_model_state: Optional[Dict[str, Any]] = None
    history: Dict[str, List[float]] = {
        "train_loss": [],
        "train_acc": [],
        "val_loss": [],
        "val_acc": [],
    }
    for epoch in range(config.num_epochs):
        logger.info(f"Epoch {epoch + 1}/{config.num_epochs}")
        print("-" * 60)
        train_loss, train_acc = train_epoch(
            model,
            train_loader,
            criterion,
            optimizer,
            config.device,
        )
        val_loss, val_acc, _, _ = validate(model, val_loader, criterion, config.device)
        scheduler.step(val_loss)

        history["train_loss"].append(train_loss)
        history["train_acc"].append(train_acc)
        history["val_loss"].append(val_loss)
        history["val_acc"].append(val_acc)

        print(f"\nEpoch {epoch + 1} Summary:")
        print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}")
        print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # Snapshot weights so the final model is the best, not the last.
            best_model_state = model.state_dict().copy()
            print(f"New best model! (Val Acc: {val_acc:.2f}%)")

    if best_model_state is not None:
        model.load_state_dict(best_model_state)
    logger.info(f"Training complete. Best validation accuracy: {best_val_acc:.2f}%")
    return model, history


def save_model(
    model: ViTForImageClassification,
    processor: ViTImageProcessor,
    config: Config,
):
    """Save model weights and processor config to ``config.model_save_path``."""
    Path(config.model_save_path).mkdir(parents=True, exist_ok=True)
    model.save_pretrained(config.model_save_path)
    processor.save_pretrained(config.model_save_path)
    print(f"Model saved to {config.model_save_path}")


def load_trained_model(
    model_path: str, device: torch.device
) -> Tuple[ViTForImageClassification, ViTImageProcessor]:
    """Load a saved model + processor onto *device* in eval mode."""
    model = ViTForImageClassification.from_pretrained(model_path)
    processor = ViTImageProcessor.from_pretrained(model_path)
    model = model.to(device)
    model.eval()
    return model, processor


def predict_image(
    image_path: str,
    model: ViTForImageClassification,
    processor: ViTImageProcessor,
    device: torch.device,
    config: Config,
) -> Tuple[str, float]:
    """Classify a single image file; return (class name, confidence)."""
    model.eval()
    image: Image.Image = Image.open(image_path).convert("RGB")
    inputs: Dict[str, torch.Tensor] = processor(images=image, return_tensors="pt")
    pixel_values: torch.Tensor = inputs["pixel_values"].to(device)
    with torch.no_grad():
        outputs = model(pixel_values=pixel_values)
        logits: torch.Tensor = outputs.logits
        probs: torch.Tensor = torch.softmax(logits, dim=1)
        predicted_class: int = torch.argmax(probs, dim=1).item()
        confidence: float = probs[0, predicted_class].item()
    prediction: str = config.class_names[predicted_class]
    return prediction, confidence


def evaluate_model(
    model: ViTForImageClassification,
    test_loader: DataLoader[Tuple[torch.Tensor, int]],
    config: Config,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Evaluate on the test set; print a report and save a confusion matrix.

    Returns (predictions, true labels, class probabilities) as numpy arrays.
    """
    model.eval()
    all_preds: List[int] = []
    all_labels: List[int] = []
    all_probs: List[List[float]] = []

    print("\n" + "=" * 60)
    print("Evaluating on Test Set")
    print("=" * 60)

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Testing"):
            images = images.to(config.device)
            outputs = model(pixel_values=images)
            logits: torch.Tensor = outputs.logits
            probs: torch.Tensor = torch.softmax(logits, dim=1)
            _, predicted = torch.max(logits, 1)
            all_preds.extend(predicted.cpu().numpy().tolist())
            # labels were never moved to the device; they are already on CPU.
            all_labels.extend(labels.numpy().tolist())
            all_probs.extend(probs.cpu().numpy().tolist())

    all_preds_array: np.ndarray = np.array(all_preds)
    all_labels_array: np.ndarray = np.array(all_labels)
    all_probs_array: np.ndarray = np.array(all_probs)

    print("\nClassification Report:")
    print(
        classification_report(
            all_labels_array, all_preds_array, target_names=config.class_names, digits=4
        )
    )

    cm: np.ndarray = confusion_matrix(all_labels_array, all_preds_array)
    plt.figure(figsize=(8, 6))
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=config.class_names,
        yticklabels=config.class_names,
    )
    plt.title("Confusion Matrix")
    plt.ylabel("True Label")
    plt.xlabel("Predicted Label")
    plt.tight_layout()
    plt.savefig("confusion_matrix.png", dpi=300, bbox_inches="tight")
    plt.close()  # free the figure; we only need the saved file
    print("\n✓ Confusion matrix saved as 'confusion_matrix.png'")

    return all_preds_array, all_labels_array, all_probs_array


def plot_training_history(history: Dict[str, List[float]]) -> None:
    """Plot loss/accuracy curves and save them to ``training_history.png``."""
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

    ax1.plot(history["train_loss"], label="Train Loss", marker="o")
    ax1.plot(history["val_loss"], label="Val Loss", marker="s")
    ax1.set_xlabel("Epoch")
    ax1.set_ylabel("Loss")
    ax1.set_title("Training and Validation Loss")
    ax1.legend()
    ax1.grid(True)

    ax2.plot(history["train_acc"], label="Train Acc", marker="o")
    ax2.plot(history["val_acc"], label="Val Acc", marker="s")
    ax2.set_xlabel("Epoch")
    ax2.set_ylabel("Accuracy (%)")
    ax2.set_title("Training and Validation Accuracy")
    ax2.legend()
    ax2.grid(True)

    plt.tight_layout()
    plt.savefig("training_history.png", dpi=300, bbox_inches="tight")
    plt.close(fig)  # free the figure; we only need the saved file
    print("✓ Training history saved as 'training_history.png'")


class NSFWClassifier:
    """Inference wrapper around a saved fine-tuned ViT NSFW classifier."""

    def __init__(self, model_path: str, device: Optional[torch.device] = None) -> None:
        if device is None:
            self.device: torch.device = torch.device(
                "cuda" if torch.cuda.is_available() else "cpu"
            )
        else:
            self.device = device
        print(f"Loading model from {model_path}...")
        self.model: ViTForImageClassification = (
            ViTForImageClassification.from_pretrained(model_path)
        )
        self.processor: ViTImageProcessor = ViTImageProcessor.from_pretrained(
            model_path
        )
        self.model = self.model.to(self.device)
        self.model.eval()
        # Index 0 = SFW, index 1 = NSFW — must match training label order.
        self.class_names: List[str] = ["SFW", "NSFW"]
        print(f"✓ Model loaded on {self.device}")

    def _predict_pil(self, image: Image.Image) -> Dict[str, Union[str, float]]:
        """Run inference on an already-decoded PIL image."""
        inputs: Dict[str, torch.Tensor] = self.processor(
            images=image, return_tensors="pt"
        )
        pixel_values: torch.Tensor = inputs["pixel_values"].to(self.device)
        with torch.no_grad():
            outputs = self.model(pixel_values=pixel_values)
            logits: torch.Tensor = outputs.logits
            probs: torch.Tensor = torch.softmax(logits, dim=1)[0]
            predicted_class: int = torch.argmax(probs).item()
        return {
            "prediction": self.class_names[predicted_class],
            "confidence": probs[predicted_class].item(),
            "sfw_prob": probs[0].item(),
            "nsfw_prob": probs[1].item(),
        }

    def predict_single(self, image_path: str) -> Dict[str, Union[str, float]]:
        """Classify the image at *image_path*."""
        image: Image.Image = Image.open(image_path).convert("RGB")
        return self._predict_pil(image)

    def predict_single_bytes(self, image_bytes: bytes) -> Dict[str, Union[str, float]]:
        """Classify an image supplied as raw encoded bytes."""
        image = Image.open(BytesIO(image_bytes)).convert("RGB")
        return self._predict_pil(image)


@click.group()
def cli():
    """NSFW classifier command group."""
    pass


@cli.command()
@click.option(
    "--augment",
    type=bool,
    default=False,
    required=False,
)
def train(
    augment: Optional[bool],
):
    """Train the classifier end-to-end: load data, fit, evaluate, save."""
    config = CONFIG
    print("=" * 60)
    print("NSFW Image Classifier - ViT Finetuner")
    print("=" * 60)
    print(f"Device: {config.device}")
    print(f"Model: {config.model_name}")
    print(f"Batch size: {config.batch_size}")
    print(f"Learning rate: {config.learning_rate}")
    print(f"Epochs: {config.num_epochs}")

    logger.info("Loading image processor")
    processor = ViTImageProcessor.from_pretrained(config.model_name)

    logger.info("Creating dataset")
    image_paths, labels = load_image_paths_and_labels(
        config.nsfw_dir,
        config.sfw_humans_dir,
        config.sfw_anime_dir,
    )

    logger.info("Creating dataloaders")
    train_loader, val_loader, test_loader = create_dataloaders(
        image_paths, labels, processor, config, augment=bool(augment)
    )

    model = create_model(config)
    model, history = train_model(model, train_loader, val_loader, config)

    logger.info("Evaluating model")
    evaluate_model(model, test_loader, config)

    logger.info("Generating plots")
    plot_training_history(history)

    logger.info("Saving model")
    save_model(model, processor, config)


@cli.command()
@click.argument("image", type=str, required=True)
def predict(
    image: str,
):
    """Classify a local image file and print the result."""
    classifier = NSFWClassifier(
        CONFIG.model_save_path,
        CONFIG.device,
    )
    print(f"{classifier.predict_single(image)}")


@cli.command()
@click.argument("url", type=str, required=True)
def predict_url(
    url: str,
):
    """Download an image from URL, classify it, and print the result."""
    classifier = NSFWClassifier(
        CONFIG.model_save_path,
        CONFIG.device,
    )
    # Timeout prevents the CLI from hanging forever on a dead host.
    resp = requests.get(url, timeout=30)
    resp.raise_for_status()
    # We could use predict_single_bytes, but we deliberately save the image
    # to disk first so it can be moved to a directory for retraining later.
    with open("./testimg.jpeg", "wb") as f:
        f.write(resp.content)
    print(f"{classifier.predict_single('./testimg.jpeg')}")


if __name__ == "__main__":
    cli()